easydel.layers.attention

Flexible attention module for various attention mechanisms.

Provides a unified interface for different attention implementations, automatically selecting the optimal mechanism based on hardware and configuration. Supports Flash Attention, Ring Attention, Splash Attention, and other optimized implementations.

Classes:

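  • AttentionMechanisms: Enum of available attention mechanisms

  • AttentionModule: Base class with shared attention utilities (KV caching, sharding, masks, head manipulation)

  • FlexibleAttentionModule: Main attention module with automatic optimization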

Functions:

  • tpu_version_check: Check TPU version for optimization

  • get_optimal_config: Determine best attention mechanism for hardware

  • _get_jax_dtype_from_string: Convert string to JAX dtype

Constants:

DEFAULT_ATTENTION_MECHANISM: Default attention mechanism (“auto”)

Example

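>>> import jax.numpy as jnp
>>> from easydel.layers.attention import FlexibleAttentionModule
>>> # `config`, `query`, `key`, `value` and `mask_info` are assumed to be defined.
>>> attn = FlexibleAttentionModule(
...     config=config,
...     dtype=jnp.bfloat16,
...     attention_mechanism="flash_attn2",
... )
>>> output = attn(
...     query, key, value,
...     mode="__train__",
...     mask_info=mask_info,
... )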
class easydel.layers.attention.AttentionMechanisms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Available attention mechanism implementations.

Enumeration of different attention computation strategies, each optimized for specific hardware or use cases.

AUTO

Automatically selects the best mechanism for the current hardware.

Type

str

FLASH_ATTN2

FlashAttention-2 kernel (recommended by get_optimal_config on GPUs and TPU v3).

Type

str

RING

RingAttention for sequence parallelism.

Type

str

VANILLA

Standard dot-product attention.

Type

str

SPLASH

SplashAttention optimized for TPUs.

Type

str

CUDNN

cuDNN implementation for NVIDIA GPUs.

Type

str

BLOCKWISE

Blockwise computation for memory efficiency.

Type

str

SDPA

Scaled Dot Product Attention (JAX native).

Type

str

CUDA_FLASH_ATTN2

CUDA-specific FlashAttention-2.

Type

str

RAGGED_PAGE_ATTENTION_V3

Ragged paged attention (v3) for efficient inference.

Type

str

RAGGED_PAGE_ATTENTION_V2

Ragged paged attention (v2) for efficient inference.

Type

str

REGRESSIVE_DECODE

Attention specialized for autoregressive decoding.

Type

str

AUTO: str = 'auto'
BLOCKSPARSE: str = 'blocksparse'
BLOCKWISE: str = 'blockwise'
CUDA_FLASH_ATTN2: str = 'cuda_flash_attn2'
CUDNN: str = 'cudnn'
FLASH_ATTN2: str = 'flash_attn2'
PAGED_ATTENTION: str = 'page_attention'
RAGGED_PAGE_ATTENTION_V2: str = 'ragged_page_attention_v2'
RAGGED_PAGE_ATTENTION_V3: str = 'ragged_page_attention_v3'
REGRESSIVE_DECODE: str = 'autoregressive_decodeattn'
RING: str = 'ring'
SDPA: str = 'sdpa'
SPLASH: str = 'blocksparse'
VANILLA: str = 'vanilla'

Note that SPLASH and BLOCKSPARSE share the value 'blocksparse'. Under Python Enum rules, the name defined later is an alias of the one defined first, so both names refer to the same member.
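Because AttentionMechanisms mixes in str and two of its names share a value, the standard Python Enum rules apply. The minimal sketch below uses a hypothetical mirror class, Mechanisms, holding a few of the members listed above (it is not the EasyDeL class itself), to show string comparison, lookup by value, and aliasing:

```python
from enum import Enum


class Mechanisms(str, Enum):
    """Hypothetical mirror of a few AttentionMechanisms members."""

    AUTO = "auto"
    BLOCKSPARSE = "blocksparse"
    FLASH_ATTN2 = "flash_attn2"
    SPLASH = "blocksparse"  # same value as BLOCKSPARSE, so it becomes an alias
    VANILLA = "vanilla"


# With the `str` mixin, members compare equal to their plain string values.
assert Mechanisms.FLASH_ATTN2 == "flash_attn2"

# Lookup by value returns the canonical member.
mech = Mechanisms("flash_attn2")

# SPLASH is the very same object as BLOCKSPARSE; the canonical name is the
# one defined first in the class body.
alias_is_same = Mechanisms.SPLASH is Mechanisms.BLOCKSPARSE
canonical_name = Mechanisms.SPLASH.name

# Iteration skips aliases, so only four members are listed.
names = [m.name for m in Mechanisms]
print(mech.name, alias_is_same, canonical_name, names)
```

In this sketch BLOCKSPARSE is defined first, so it is the canonical name. Which name wins in the real class depends on its definition order, which this page does not show.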
class easydel.layers.attention.AttentionModule(*args: Any, **kwargs: Any)

Bases: Module, Generic[Cfg]

Base class for Flax attention modules in EasyDeL, providing common utilities.

This class offers helper functions and attributes commonly needed by attention implementations within Flax, such as handling KV caching, sharding, mask manipulation, and head manipulation. Concrete attention implementations often inherit from this class.

config

Configuration object for the attention module.

Type

Cfg | EasyDeLBaseConfig

cached_key

Flax Cache for storing past key states (not used).

Type

nn.Cache[Array] | None

cached_value

Flax Cache for storing past value states (not used).

Type

nn.Cache[Array] | None

cache_index

Flax Cache for tracking the current index in the cache (not used).

Type

nn.Cache[Array] | None

static apply_complex_rotary(xq: Float[Array, '... seq heads dim'], xk: Float[Array, '... seq heads dim'], freqs_cis: Complex[Array, 'batch seq 1 dim_2']) tuple[jaxtyping.Float[Array, '... seq heads dim'], jaxtyping.Float[Array, '... seq heads dim']]
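apply_complex_rotary has no docstring on this page. Below is a minimal NumPy sketch of complex-valued rotary position embeddings. It assumes the interleaved layout of the original LLaMA reference code, where adjacent feature pairs are treated as one complex number, and a freqs_cis of shape [batch, seq, 1, dim // 2] as in the signature. The helper names precompute_freqs_cis and rotary_complex_sketch are hypothetical, and EasyDeL's actual layout may differ.

```python
import numpy as np


def precompute_freqs_cis(seq_len, dim, theta=10000.0):
    """Complex rotations exp(i * pos * freq), shape [seq_len, dim // 2]."""
    freqs = 1.0 / theta ** (np.arange(0, dim, 2) / dim)  # [dim // 2]
    angles = np.outer(np.arange(seq_len), freqs)          # [seq_len, dim // 2]
    return np.exp(1j * angles)


def rotary_complex_sketch(xq, xk, freqs_cis):
    """Rotate adjacent (even, odd) feature pairs of q and k by freqs_cis.

    xq, xk: [batch, seq, heads, dim]; freqs_cis: [batch, seq, 1, dim // 2].
    """
    def rotate(x):
        # View features (0, 1), (2, 3), ... as complex numbers.
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        as_complex = pairs[..., 0] + 1j * pairs[..., 1]
        rotated = as_complex * freqs_cis  # broadcasts over the heads axis
        # Return to the interleaved real layout and the input dtype.
        out = np.stack([rotated.real, rotated.imag], axis=-1)
        return out.reshape(x.shape).astype(x.dtype)

    return rotate(xq), rotate(xk)


batch, seq, heads, dim = 2, 5, 4, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, seq, heads, dim)).astype(np.float32)
k = rng.standard_normal((batch, seq, heads, dim)).astype(np.float32)
freqs = precompute_freqs_cis(seq, dim)[None, :, None, :]  # [1, seq, 1, dim // 2]
freqs = np.broadcast_to(freqs, (batch, seq, 1, dim // 2))
q_rot, k_rot = rotary_complex_sketch(q, k, freqs)
```

Multiplying by a unit-modulus complex number is a pure rotation. Each feature pair keeps its length, and position 0 (angle 0) is left unchanged.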
apply_qk_shardings(q: Float[Array, 'batch seq heads dim'], k: Float[Array, 'batch seq heads dim']) tuple[jaxtyping.Float[Array, 'batch seq heads dim'], jaxtyping.Float[Array, 'batch seq heads dim']]
apply_qkv_shardings(q: Float[Array, 'batch seq heads dim'], k: Float[Array, 'batch seq heads dim'], v: Float[Array, 'batch seq heads dim']) tuple[jaxtyping.Float[Array, 'batch seq heads dim'], jaxtyping.Float[Array, 'batch seq heads dim'], jaxtyping.Float[Array, 'batch seq heads dim']]
static build_cache_pos(attention_mask: Bool[Array, 'batch heads seq_q seq_k'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], cache_view: TransformerCacheView = None) Int[Array, 'batch seq']

Calculates the position indices within the sequence for cache-aware operations.

Parameters
  • attention_mask (jax.Array) – The attention mask (typically [batch, heads, q_len, k_len]).

  • mode (common_types.RUNTIME_MODE_TYPES) – The runtime mode.

  • cache_view (TransformerCacheView, optional) – The current KV cache view. Defaults to None.

Returns

An array representing the position of each token in the sequence, adjusted by the cache index if provided. Shape usually [batch, q_len].

Return type

jax.Array
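This page does not spell out the exact rule build_cache_pos uses. As an illustration only, the hypothetical helper below derives 0-based positions from a 2-D padding mask (1 = real token) and offsets them by a cache index. It assumes left padding and does not reproduce the 4-D mask or mode handling of the real method.

```python
import numpy as np


def positions_from_padding_mask(padding_mask, cache_index=0):
    """Hypothetical helper: positions for a [batch, seq] 0/1 padding mask.

    Real tokens get their 0-based index among real tokens, offset by
    `cache_index` (tokens already held in a KV cache); padded slots get 0.
    """
    mask = np.asarray(padding_mask, dtype=np.int32)
    running = np.cumsum(mask, axis=-1) - 1  # index among real tokens
    return np.where(mask == 1, running + cache_index, 0)


# Prefill with left padding in the first row.
left_padded = np.array([[0, 0, 1, 1, 1],
                        [1, 1, 1, 1, 1]])
prefill_pos = positions_from_padding_mask(left_padded)

# One decode step after 3 tokens are already cached.
decode_pos = positions_from_padding_mask(np.array([[1]]), cache_index=3)
```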

concatenate(*, query: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], key: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], value: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], mask_info: ejkernel.types.mask.MaskInfo, mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], eformer.common_types._Empty] = <eformer.common_types._Empty object>, cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None, sliding_window: int | None = None) tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Callable[[], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]], easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None, easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None]

Prepares inputs for attention calculation, handling KV caching and mask merging.

This function combines the current query, key, and value with cached states (if applicable), merges various masks (attention, causal, FCM, sliding window), and returns the final key, value, attention mask, and a function to initialize the attention bias.

Parameters
  • query (Array) – Current query states [Batch, q_len, Heads, Dim].

  • key (Array) – Current key states [Batch, kv_len, Heads, Dim].

  • value (Array) – Current value states [Batch, kv_len, Heads, Dim].

  • mask_info (MaskInfo) – Container for the base attention mask (e.g., padding mask) and segment IDs. Required.

  • mode (common_types.RUNTIME_MODE_TYPES) – The runtime mode (TRAIN, PREFILL, DECODE). Required.

  • cache_view (tp.Optional[TransformerCacheView | RaggedPagesCacheView], optional) – View into the KV cache. If None, caching is disabled. Defaults to None.

  • cache_metadata (tp.Optional[TransformerMetadata | RaggedPagesMetadata], optional) – Cache metadata. Defaults to None.

  • sliding_window (tp.Optional[int], optional) – Size of the sliding attention window. If None, not applied. Defaults to None.

Returns

  • key_states (Array): Final key states (potentially from cache).

  • value_states (Array): Final value states (potentially from cache).

  • attention_mask (Array): The final combined attention mask [Batch, Heads, q_len, kv_len].

  • init_attention_bias (Callable): Function to create the attention bias tensor.

  • updated_cache_view: The updated cache view (or None if no cache).

  • updated_cache_metadata: The updated cache metadata (or None if no metadata).

Return type

tp.Tuple[Array, Array, Array, tp.Callable[[], Array], tp.Optional[tp.Union[TransformerCacheView, RaggedPagesCacheView]], tp.Optional[tp.Union[TransformerMetadata, RaggedPagesMetadata]]]

Raises

ValueError – If shapes are mismatched.
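NumPy can make the mask merging described for concatenate concrete. This sketch is hypothetical (combine_masks is not an EasyDeL function). It combines a [batch, kv_len] padding mask with a causal mask and an optional sliding window. It assumes the queries occupy the last q_len key positions and that a window of size w keeps keys with q_pos - k_pos < w; other implementations may use a different window convention.

```python
import numpy as np


def combine_masks(padding_mask, q_len, sliding_window=None):
    """Hypothetical helper merging padding, causal and sliding-window masks.

    padding_mask: [batch, kv_len] bool, True for real tokens.
    Returns a [batch, 1, q_len, kv_len] bool mask (True = may attend).
    """
    kv_len = padding_mask.shape[-1]
    q_pos = np.arange(kv_len - q_len, kv_len)[:, None]  # [q_len, 1]
    k_pos = np.arange(kv_len)[None, :]                  # [1, kv_len]
    allowed = k_pos <= q_pos                            # causal
    if sliding_window is not None:
        allowed &= (q_pos - k_pos) < sliding_window     # keep recent keys only
    return padding_mask[:, None, None, :] & allowed[None, None, :, :]


# Training-style call: 4 queries over 4 keys with a window of 2.
full = combine_masks(np.ones((1, 4), dtype=bool), q_len=4, sliding_window=2)

# Decode-style call: 1 new query over 4 keys, the first key is padding.
decode = combine_masks(np.array([[False, True, True, True]]), q_len=1)
```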

property default_key_value_sharding: NamedSharding

Defines the default JAX sharding for key and value tensors.

Uses the partition specifications defined in the configuration’s partition_axis.

Returns

The default sharding configuration for K/V tensors.

Return type

NamedSharding

get_sharding_safely(tensor: Float[Array, '...']) PartitionSpec

Retrieves the PartitionSpec of a tensor, falling back to the default KV sharding.

Parameters

tensor (jax.Array) – The tensor whose sharding spec is needed.

Returns

The sharding specification of the tensor.

Return type

PartitionSpec

property quantizer: EasyQuantizer

Provides an EasyQuantizer instance based on the module’s configuration.

Used for quantizing KV cache entries if enabled in the config.

Returns

The quantizer instance.

Return type

EasyQuantizer

static repeat_key_value(key: Float[Array, 'batch seq kv_heads dim'], value: Float[Array, 'batch seq kv_heads dim'], num_reps: int) tuple[jaxtyping.Float[Array, 'batch seq heads dim'], jaxtyping.Float[Array, 'batch seq heads dim']]

Repeats key and value tensors for Grouped Query Attention (GQA).

Expands the head dimension by repeating each KV head num_reps times (implemented with einops).

Parameters
  • key (Array) – Key tensor [Batch, Seq, NumKVHeads, Dim].

  • value (Array) – Value tensor [Batch, Seq, NumKVHeads, Dim].

  • num_reps (int) – The number of times to repeat each KV head (num_attention_heads / num_kv_heads).

Returns

Repeated key and value tensors, each with shape [Batch, Seq, NumKVHeads * num_reps, Dim].

Return type

tp.Tuple[Array, Array]
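For GQA, num_reps = num_attention_heads // num_kv_heads. The NumPy sketch below performs the repetition. It assumes the common einops pattern "b s h d -> b s (h r) d", in which each KV head is repeated consecutively, and repeat_kv_sketch is a hypothetical name.

```python
import numpy as np


def repeat_kv_sketch(key, value, num_reps):
    """Repeat each KV head `num_reps` times along the head axis (axis 2).

    Same result as einops.repeat(x, "b s h d -> b s (h r) d", r=num_reps):
    heads come out as [h0] * num_reps + [h1] * num_reps + ...
    """
    return np.repeat(key, num_reps, axis=2), np.repeat(value, num_reps, axis=2)


num_attention_heads, num_kv_heads = 8, 2
num_reps = num_attention_heads // num_kv_heads  # 4 query heads per KV head
key = np.arange(1 * 3 * num_kv_heads * 4, dtype=np.float32).reshape(1, 3, num_kv_heads, 4)
value = key + 100.0
key_rep, value_rep = repeat_kv_sketch(key, value, num_reps)
```

Query heads 0 to 3 then share KV head 0, and query heads 4 to 7 share KV head 1.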

shard_attention_prod(attn_output: Float[Array, 'batch seq heads dim']) Float[Array, 'batch seq heads dim']

Applies sharding constraints to the attention output tensor.

This is typically done before projecting the attention output back to the hidden dimension size.

Parameters

attn_output (jax.Array) – The output from the attention mechanism, usually with shape [Batch, SeqLen, NumHeads * DimPerHead].

Returns

The input tensor with applied sharding constraints based on the config.

Return type

jax.Array

easydel.layers.attention.Cfg

Type variable for configuration objects.

alias of TypeVar('Cfg', bound=EasyDeLBaseConfig)
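The Generic[Cfg] base lets a subclass pin the config type it expects, so type checkers can see model-specific config fields. This minimal sketch uses stand-in classes (BaseConfig, LlamaLikeConfig, AttentionBase and LlamaLikeAttention are all hypothetical, not EasyDeL names) to show the pattern:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass
class BaseConfig:  # stands in for EasyDeLBaseConfig
    hidden_size: int = 64


@dataclass
class LlamaLikeConfig(BaseConfig):  # hypothetical model-specific config
    rope_theta: float = 10000.0


# Like Cfg: any type argument must be BaseConfig or a subclass of it.
Cfg = TypeVar("Cfg", bound=BaseConfig)


class AttentionBase(Generic[Cfg]):  # stands in for AttentionModule
    def __init__(self, config: Cfg) -> None:
        self.config = config


class LlamaLikeAttention(AttentionBase[LlamaLikeConfig]):
    def theta(self) -> float:
        # A type checker knows self.config is a LlamaLikeConfig here.
        return self.config.rope_theta


attn = LlamaLikeAttention(LlamaLikeConfig())
```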

class easydel.layers.attention.FlexibleAttentionModule(*args: Any, **kwargs: Any)

Bases: Module

Unified interface for various attention mechanisms.

Central hub for managing different attention implementations, automatically selecting and executing the optimal mechanism based on hardware, configuration, and runtime requirements.

Supports optimized implementations like FlashAttention, SplashAttention, RingAttention, and standard dot-product attention. Provides automatic hardware detection and optimization selection.

config

Model configuration with attention parameters.

dtype

Data type for computations.

param_dtype

Data type for parameters.

precision

Precision setting for operations.

attention_mechanism

Selected attention mechanism.

mesh

JAX mesh for distributed computation.

implementation

Concrete attention implementation.

Key Features:

  • Attention Mechanism Selection: Supports a wide range of attention mechanisms, allowing users to choose the most suitable option based on performance and hardware constraints.

  • Sharding and Partitioning: Integrates with JAX’s sharding capabilities, enabling efficient distribution of computations and data across multiple devices.

  • Block-wise Computation: Implements block-wise attention computations for optimized memory usage and speed, particularly beneficial for large models.

  • Performance Optimization: Includes support for highly optimized implementations like FlashAttention, SplashAttention, and RingAttention for TPU and GPU acceleration.

  • Flexibility and Customization: Offers fine-grained control over attention parameters, sharding specifications, and block sizes, providing flexibility for different use cases.

  • Testing and Evaluation: Includes a run_attention_benchmarks method to systematically evaluate different attention mechanisms and help users identify the best-performing option.

FlexibleAttentionModule is the EasyDeL component responsible for managing and optimizing attention computations. It lets users select an attention mechanism, distributes the work with JAX sharding, and provides specialized implementations such as FlashAttention and SplashAttention for better performance. Block-wise computation and its configuration options let it adapt to a variety of model architectures and hardware configurations.

impl

The chosen attention implementation backend instance.

Type

AttentionBackend

deterministic

If True, dropout is disabled; if False, dropout is applied. Currently hardcoded to True.

Type

bool

metadata

Metadata derived from the configuration, used by the backend.

Type

OperationMetadata

forward(query_states: Float[Array, 'batch seq_q heads dim'], key_states: Float[Array, 'batch seq_k heads dim'], value_states: Float[Array, 'batch seq_v heads dim'], mode: Optional[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']], mask_info: ejkernel.types.mask.MaskInfo | None = None, bias: jaxtyping.Float[Array, 'batch heads seq_q seq_k'] | None = None, sliding_window: int | tuple[int, int] | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None, cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, init_bias: Optional[Callable[[], Float[Array, 'batch heads seq_q seq_k']]] = None, causal: bool = True, softmax_aux: jaxtyping.Float[Array, '...'] | None = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dropout_prob: float | None = None, dropout_rng: Any | None = None, deterministic: bool | None = None, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset] = None, prevent_cse: bool = True, cum_seqlens_q: jaxtyping.Int[Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[Array, 'batch_plus_one'] | None = None, normalize_output: bool = True, fused_backward: bool = False, compute_dtype: numpy.dtype | None = None, optimized: bool = False, mask_value: float | None = None, vmem_limit_bytes: int | None = None, policy: Any | None = None) AttentionOutput

Performs the attention computation using the selected backend implementation.

Parameters
  • query_states (Array) – Query tensor [batch, seq_q, heads, dim].

  • key_states (Array) – Key tensor [batch, seq_k, heads, dim].

  • value_states (Array) – Value tensor [batch, seq_v, heads, dim].

  • mode (common_types.RUNTIME_MODE_TYPES) – Runtime mode (TRAIN, PREFILL, DECODE).

  • mask_info (MaskInfo, optional) – Container for attention masks and segment IDs. Defaults to None. If provided, it contains:

      - attention_mask: Boolean mask [batch, 1, seq_q, seq_k]
      - q_segment_ids: Query segment IDs [batch, seq_q]
      - kv_segment_ids: Key/Value segment IDs [batch, seq_k]
      - q_positions: Query position indices [batch, seq_q]
      - kv_positions: Key/Value position indices [batch, seq_k]

  • bias (Array, optional) – Optional attention bias [batch, heads, seq_q, seq_k]. Defaults to None.

  • sliding_window (int | tuple[int, int], optional) – Sliding window size for attention. Defaults to None.

  • cache_metadata (TransformerMetadata | RaggedPagesMetadata, optional) – Cache metadata. Defaults to None.

  • cache_view (TransformerCacheView | RaggedPagesCacheView, optional) – View into KV cache. Defaults to None.

  • init_bias (Callable, optional) – Function to initialize bias tensor. Defaults to None.

  • causal (bool, optional) – Apply causal masking. Defaults to True.

  • softmax_aux (Array, optional) – Auxiliary tensor for softmax (e.g., sink tokens). Defaults to None.

  • softmax_scale (float, optional) – Scaling factor for attention logits. Defaults to None.

  • logits_soft_cap (float, optional) – Soft capping value for attention logits. Defaults to None.

  • dropout_prob (float, optional) – Dropout probability for attention weights. Defaults to None.

  • dropout_rng (PRNGKey, optional) – PRNG key for dropout. Defaults to None.

  • deterministic (bool, optional) – If True, disables dropout. Defaults to None (uses self.deterministic).

  • precision (lax.PrecisionLike, optional) – JAX precision setting. Defaults to None.

  • prevent_cse (bool, optional) – Prevent common subexpression elimination. Defaults to True.

  • cum_seqlens_q (Array, optional) – Cumulative sequence lengths for queries. Defaults to None.

  • cum_seqlens_k (Array, optional) – Cumulative sequence lengths for keys. Defaults to None.

  • normalize_output (bool, optional) – Normalize attention output. Defaults to True.

  • fused_backward (bool, optional) – Use fused backward pass. Defaults to False.

  • compute_dtype (jnp.dtype, optional) – Computation dtype. Defaults to None.

  • optimized (bool, optional) – Use optimized kernel variant. Defaults to False.

  • mask_value (float, optional) – Value for masked positions. Defaults to None.

  • vmem_limit_bytes (int, optional) – VMEM limit in bytes for paged attention. Defaults to None.

  • policy (Any, optional) – Checkpoint policy for gradients. Defaults to None.

Returns

An object containing the attention output tensor and potentially attention weights (depending on the backend).

Return type

AttentionOutput
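A plain reference for the math is useful for checking a backend on small inputs. The NumPy function below is a sketch, not EasyDeL code, and uses the same [batch, seq, heads, dim] layout as forward. It rests on three assumptions: softmax_scale defaults to 1/sqrt(dim), logits_soft_cap means cap * tanh(logits / cap), and cum_seqlens_q / cum_seqlens_k follow the FlashAttention variable-length convention [0, l0, l0 + l1, ...], which matches their 'batch_plus_one' shape. reference_attention and cumulative_seqlens are hypothetical names.

```python
import numpy as np


def reference_attention(q, k, v, *, causal=True, softmax_scale=None,
                        logits_soft_cap=None, bias=None, mask=None,
                        mask_value=None):
    """Dot-product attention in the [batch, seq, heads, dim] layout.

    q: [batch, seq_q, heads, dim]; k, v: [batch, seq_k, heads, dim].
    bias: optional, broadcastable to [batch, heads, seq_q, seq_k].
    mask: optional bool (True = attend), broadcastable to the same shape.
    """
    seq_q, seq_k, dim = q.shape[1], k.shape[1], q.shape[-1]
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(dim)  # assumed default
    if mask_value is None:
        mask_value = np.finfo(np.float32).min
    logits = np.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    if logits_soft_cap is not None:
        # Assumed soft-capping convention: cap * tanh(logits / cap).
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
    if bias is not None:
        logits = logits + bias
    if causal:
        # Queries are aligned with the last seq_q key positions.
        q_pos = np.arange(seq_k - seq_q, seq_k)[:, None]
        logits = np.where(np.arange(seq_k)[None, :] <= q_pos, logits, mask_value)
    if mask is not None:
        logits = np.where(mask, logits, mask_value)
    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


def cumulative_seqlens(lengths):
    """[l0, l1, ...] -> [0, l0, l0 + l1, ...], shape [batch + 1]."""
    return np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 3, 8))
k = rng.standard_normal((2, 4, 3, 8))
v = rng.standard_normal((2, 4, 3, 8))
out = reference_attention(q, k, v, causal=True)
```

With causal masking, the first query can only see the first key, so its output equals the first value row. That makes a quick sanity check for any backend.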

easydel.layers.attention.get_optimal_config() tuple[easydel.layers.attention.AttentionMechanisms, numpy.dtype]

Determine optimal attention configuration for hardware.

Analyzes the current JAX backend and hardware to recommend the best attention mechanism and data type for performance.

Returns

A tuple (attention_mechanism, dtype) optimized for the current hardware:

  • TPU v3: (FLASH_ATTN2, float32)

  • TPU v4+: (SPLASH, bfloat16)

  • GPU: (FLASH_ATTN2, float16)

  • CPU/other: (VANILLA, bfloat16)

Return type

tuple[AttentionMechanisms, numpy.dtype]

Example

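>>> from easydel.layers.attention import FlexibleAttentionModule, get_optimal_config
>>> mechanism, dtype = get_optimal_config()
>>> attn = FlexibleAttentionModule(
...     config=config,  # assumed to be defined
...     attention_mechanism=mechanism,
...     dtype=dtype,
... )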
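The hardware table above amounts to a small lookup. The pure-Python sketch below restates it. optimal_config_sketch is hypothetical: rather than probing the hardware, it takes the backend name and TPU generation as explicit arguments. It returns enum string values, and note that SPLASH's value is 'blocksparse'. In a real program the backend name could come from jax.default_backend().

```python
from typing import Optional, Tuple


def optimal_config_sketch(platform: str, tpu_version: Optional[int] = None) -> Tuple[str, str]:
    """Hypothetical restatement of the documented hardware table.

    platform: backend name such as "tpu", "gpu" or "cpu".
    tpu_version: TPU generation number, used only when platform == "tpu".
    Returns (AttentionMechanisms value, dtype name).
    """
    if platform == "tpu":
        if tpu_version == 3:
            return "flash_attn2", "float32"  # FLASH_ATTN2
        # v4+ per the table; generations not listed also land here (sketch choice).
        return "blocksparse", "bfloat16"  # SPLASH (value 'blocksparse')
    if platform == "gpu":
        return "flash_attn2", "float16"  # FLASH_ATTN2
    return "vanilla", "bfloat16"  # VANILLA for CPU and other backends
```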
easydel.layers.attention.tpu_version_check(version: str = 'v4') bool

Check if running on specified TPU version.

Verifies if the current JAX device matches the specified TPU version for hardware-specific optimizations.

Parameters

version – TPU version string to check (e.g., “v4”, “v5”). Defaults to “v4”.

Returns

True if running on specified TPU version, False otherwise.

Example

>>> from easydel.layers.attention import tpu_version_check
>>> if tpu_version_check("v5"):
...     # Use TPU v5 optimizations
...     pass
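This page does not document how tpu_version_check inspects the device. One plausible approach, shown here as a hypothetical helper, is to match the version against the device-kind string that JAX reports (for example jax.devices()[0].device_kind, which reads like "TPU v4" on TPU v4 hosts). The exact strings differ across TPU generations, and a substring match also accepts variants such as "TPU v5 lite", so treat this as a sketch.

```python
def is_tpu_version(device_kind: str, version: str = "v4") -> bool:
    """Hypothetical check: does a device-kind string name the given TPU version?"""
    kind = device_kind.lower()
    # Require a TPU device, then look for the version tag anywhere in the name.
    return kind.startswith("tpu") and version.lower() in kind


# With JAX installed, the string could come from:
#   import jax
#   is_tpu_version(jax.devices()[0].device_kind, "v5")
```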