easydel.layers.attention

class easydel.layers.attention.AttentionMechanisms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enumeration of available attention mechanisms.

AUTO

Automatically selects the best mechanism based on the backend.

FLASH_ATTN2

FlashAttention-2 implementation.

RING

RingAttention implementation.

VANILLA

Standard dot-product attention.

SPLASH

SplashAttention implementation (optimized for TPUs).

CUDNN

cuDNN implementation (GPU-specific).

BLOCKWISE

Blockwise attention computation.

SDPA

Scaled Dot Product Attention (potentially uses JAX native SDPA).

CUDA_FLASH_ATTN2

CUDA-specific FlashAttention-2 implementation.

PAGED_ATTENTION

Paged attention for fast inference.

AUTO = 'auto'
BLOCKWISE = 'blockwise'
CUDA_FLASH_ATTN2 = 'cuda_flash_attn2'
CUDNN = 'cudnn'
FLASH_ATTN2 = 'flash_attn2'
PAGED_ATTENTION = 'paged_attention'
REGRESSIVE_DECODE = 'autoregressive_decodeattn'
RING = 'ring'
SDPA = 'sdpa'
SPLASH = 'splash'
VANILLA = 'vanilla'

class easydel.layers.attention.AttentionModule(*args: Any, **kwargs: Any)

Bases: Module

Base class for Flax attention modules in EasyDeL, providing common utilities.

This class offers helper functions and attributes commonly needed by attention implementations within Flax, such as handling KV caching, sharding, mask manipulation, and head manipulation. Concrete attention implementations often inherit from this class.

config

Configuration object for the attention module.

Type

SC | EasyDeLBaseConfig

cached_key

Flax cache for storing past key states (not used).

Type

nn.Cache[Array] | None

cached_value

Flax cache for storing past value states (not used).

Type

nn.Cache[Array] | None

cache_index

Flax cache for tracking the current index in the cache (not used).

Type

nn.Cache[Array] | None

static apply_complex_rotary(xq: Array, xk: Array, freqs_cis: Array) Tuple[Array, Array]
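No description is given for apply_complex_rotary. Judging only by its name and the freqs_cis argument, it most likely applies rotary position embeddings in the complex-number form popularized by the LLaMA reference code. The sketch below illustrates that technique under stated assumptions: precompute_freqs_cis and apply_complex_rotary_sketch are hypothetical names, and the [batch, seq, heads, dim] layout and the pairing of adjacent features are assumed, not confirmed EasyDeL conventions.

```python
import jax.numpy as jnp


def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0) -> jnp.ndarray:
    """Hypothetical helper: one unit complex rotation per (position, feature pair)."""
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))  # [dim // 2]
    angles = jnp.outer(jnp.arange(seq_len, dtype=jnp.float32), freqs)          # [seq_len, dim // 2]
    return jnp.exp(1j * angles)                                                # complex64


def apply_complex_rotary_sketch(xq, xk, freqs_cis):
    """Rotate adjacent feature pairs of q and k by complex multiplication.

    Assumes xq/xk are [batch, seq, heads, dim] with an even dim and freqs_cis
    is [seq, dim // 2] complex.
    """

    def rotate(x):
        pairs = x.astype(jnp.float32).reshape(*x.shape[:-1], -1, 2)  # [..., dim // 2, 2]
        as_complex = pairs[..., 0] + 1j * pairs[..., 1]              # [..., dim // 2]
        rotated = as_complex * freqs_cis[None, :, None, :]           # broadcast over batch and heads
        out = jnp.stack([rotated.real, rotated.imag], axis=-1)       # back to real pairs
        return out.reshape(x.shape).astype(x.dtype)

    return rotate(xq), rotate(xk)
```

Because each pair is multiplied by a unit-magnitude complex number, the rotation changes only the phase, never the vector norm, and position 0 is left unchanged.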
apply_qk_shardings(q: Array, k: Array) Tuple[Array, Array]
apply_qkv_shardings(q: Array, k: Array, v: Array) Tuple[Array, Array, Array]
static build_cache_pos(attention_mask: Array, cache_view: TransformerCacheView = None) Array

Calculates the position indices within the sequence for cache-aware operations.

Parameters
  • attention_mask (jax.Array) – The attention mask (typically [batch, heads, q_len, k_len]).

  • cache_view (TransformerCacheView, optional) – The current KV cache view. Defaults to None.

Returns

An array representing the position of each token in the sequence, adjusted by the cache index if provided. Shape usually [batch, q_len].

Return type

jax.Array
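A minimal sketch of this position computation, under two assumptions: the input is a per-token [batch, seq] mask (1 for real tokens, 0 for padding) rather than the 4D mask mentioned above, and the cache contributes a scalar offset equal to the number of tokens already stored. Positions are a running count of real tokens, so left padding does not shift them. build_cache_pos_sketch is a hypothetical name, not EasyDeL's code.

```python
import jax.numpy as jnp


def build_cache_pos_sketch(token_mask, cache_index=0):
    """Hypothetical position builder.

    token_mask: [batch, seq], 1 for real tokens and 0 for padding.
    cache_index: tokens already stored in the KV cache (assumed scalar).
    Returns [batch, seq] integer positions; padded slots get only the offset.
    """
    token_mask = token_mask.astype(jnp.int32)
    # Running count of real tokens, shifted so the first real token is position 0.
    positions = jnp.cumsum(token_mask, axis=-1) - 1
    positions = jnp.where(token_mask == 1, positions, 0)
    return positions + cache_index
```

For a left-padded row [0, 1, 1, 1] and an empty cache, the real tokens receive positions 0, 1, 2; with five tokens already cached they receive 5, 6, 7.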

concatenate(*, query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], cache_view: Optional[Union[TransformerCacheView, PagedAttentionCacheView]] = None, cache_metadata: Optional[Union[TransformerMetadata, PagedAttentionMetadata]] = None, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None, fcm_mask: Optional[Union[Array, ndarray, bool, number]] = None, sliding_window: Optional[int] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Callable[[], Union[Array, ndarray, bool, number]]]

Prepares inputs for attention calculation, handling KV caching and mask merging.

This function combines the current query, key, and value with cached states (if applicable), merges various masks (attention, causal, FCM, sliding window), and returns the final key, value, attention mask, and a function to initialize the attention bias.

Parameters
  • query (Array) – Current query states [Batch, q_len, Heads, Dim].

  • key (Array) – Current key states [Batch, kv_len, Heads, Dim].

  • value (Array) – Current value states [Batch, kv_len, Heads, Dim].

  • attention_mask (Array) – Base attention mask (e.g., padding mask) [Batch, kv_len] or compatible.

  • cache_view (tp.Optional[tp.Union[TransformerCacheView, PagedAttentionCacheView]], optional) – View into the KV cache. If None, caching is disabled. Defaults to None.

  • causal_mask (tp.Optional[Array], optional) – Causal mask [1, 1, q_len, kv_len]. Defaults to None.

  • token_type_ids (tp.Optional[Array], optional) – Token type IDs for segment masking [Batch, q_len]. Defaults to None.

  • fcm_mask (tp.Optional[Array], optional) – Fused-Context-Mask (specific use case) [Batch, 1, q_len, kv_len]. Defaults to None.

  • sliding_window (tp.Optional[int], optional) – Size of the sliding attention window. If None, not applied. Defaults to None.

Returns

  • key_states (Array): Final key states (potentially from cache).

  • value_states (Array): Final value states (potentially from cache).

  • attention_mask (Array): The final combined attention mask [Batch, Heads, q_len, kv_len].

  • init_attention_bias (Callable): Function to create the attention bias tensor.

Return type

tp.Tuple[Array, Array, Array, tp.Callable[[], Array]]
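The mask-merging part of concatenate can be illustrated in isolation. The sketch below is hypothetical (merge_masks_sketch is not an EasyDeL function) and ignores caching as well as FCM, token-type, and sliding-window masks. It combines a padding mask and a causal mask with a logical AND and returns a zero-argument callable that builds the additive bias on demand, mirroring the init_attention_bias return value. The head axis is kept at size 1 and relies on broadcasting, whereas the documented mask is [Batch, Heads, q_len, kv_len].

```python
import jax.numpy as jnp


def merge_masks_sketch(attention_mask, causal_mask, dtype=jnp.float32):
    """Combine a padding mask with a causal mask and build a lazy additive bias.

    attention_mask: [batch, kv_len] padding mask (1 = attend).
    causal_mask:    [1, 1, q_len, kv_len] boolean causal mask.
    Returns the merged mask [batch, 1, q_len, kv_len] and a zero-argument
    callable that materializes the matching bias only when invoked.
    """
    padding = attention_mask.astype(bool)[:, None, None, :]  # [batch, 1, 1, kv_len]
    merged = jnp.logical_and(padding, causal_mask.astype(bool))

    def init_attention_bias():
        # 0 where attention is allowed, the most negative finite value elsewhere.
        return jnp.where(merged, jnp.array(0, dtype), jnp.finfo(dtype).min).astype(dtype)

    return merged, init_attention_bias
```

Returning a callable rather than the bias itself means a backend that consumes boolean masks never has to materialize the float bias tensor; that is a plausible motivation for the design, not one the documentation states.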

property default_key_value_sharding

Defines the default JAX sharding for key and value tensors.

Uses the partition specifications defined in the configuration’s partition_axis.

Returns

The default sharding configuration for K/V tensors.

Return type

NamedSharding
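As a standalone illustration of K/V sharding, the sketch below builds a NamedSharding over a device mesh and applies it with jax.lax.with_sharding_constraint, the standard JAX mechanism for pinning a tensor's layout inside jit. Whether apply_qkv_shardings and the related helpers use exactly this call is an assumption, as are the mesh axis names "dp" and "tp" and the [batch, seq, kv_heads, head_dim] layout. A 1x1 mesh keeps it runnable on a single device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed mesh axis names; EasyDeL takes the real ones from config.partition_axis.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("dp", "tp"))

# Assumed [batch, seq, kv_heads, head_dim] layout: batch split over "dp", heads over "tp".
kv_sharding = NamedSharding(mesh, PartitionSpec("dp", None, "tp", None))


@jax.jit
def constrain_kv(key, value):
    # Ask the compiler to lay out K and V according to kv_sharding.
    key = jax.lax.with_sharding_constraint(key, kv_sharding)
    value = jax.lax.with_sharding_constraint(value, kv_sharding)
    return key, value


k_out, v_out = constrain_kv(jnp.zeros((2, 16, 4, 8)), jnp.ones((2, 16, 4, 8)))
```

On a real multi-device mesh, each dimension named in the PartitionSpec must be divisible by the size of the corresponding mesh axis.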

get_sharding_safely(tensor: Array) PartitionSpec

Retrieves the PartitionSpec of a tensor, falling back to the default KV sharding.

Parameters

tensor (jax.Array) – The tensor whose sharding spec is needed.

Returns

The sharding specification of the tensor.

Return type

PartitionSpec
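A minimal sketch of the fallback behaviour, assuming "safely" means "return the default whenever the tensor does not carry a NamedSharding". get_sharding_safely_sketch and DEFAULT_KV_SPEC are hypothetical stand-ins; in EasyDeL the default comes from default_key_value_sharding.

```python
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

DEFAULT_KV_SPEC = PartitionSpec("dp", None, "tp", None)  # assumed default layout


def get_sharding_safely_sketch(tensor, default_spec=DEFAULT_KV_SPEC):
    """Return the tensor's PartitionSpec, or default_spec if it has none.

    Eager arrays on a single device carry a SingleDeviceSharding (no spec),
    and tracers inside jit may not expose a concrete sharding at all.
    """
    sharding = getattr(tensor, "sharding", None)
    if isinstance(sharding, NamedSharding):
        return sharding.spec
    return default_spec
```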

make_flexible_sliding_window(attention_mask: Array, cache_view: TransformerCacheView, sliding_window: int)

Applies a sliding window mask to the attention mask, considering cache state.

Parameters
  • attention_mask (jax.Array) – The original attention mask.

  • cache_view (TransformerCacheView) – The current view of the KV cache.

  • sliding_window (int) – The size of the sliding window.

Returns

  • The attention mask combined with the sliding window mask.

  • A function (init_attention_bias) to create the corresponding attention bias.

Return type

tp.Tuple[jax.Array, tp.Callable[[], jax.Array]]
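A hypothetical sketch of a cache-aware sliding-window mask. It assumes the q_len incoming queries occupy absolute positions starting at the cache index and that key slot j holds absolute position j (no ring-buffer wraparound). Each query may attend only to keys that are not in the future and at most sliding_window - 1 positions behind it. sliding_window_mask_sketch is not EasyDeL's implementation.

```python
import jax.numpy as jnp


def sliding_window_mask_sketch(attention_mask, cache_index, sliding_window, dtype=jnp.float32):
    """Restrict a [batch, 1, q_len, kv_len] mask to a causal sliding window.

    Assumes queries sit at absolute positions cache_index .. cache_index + q_len - 1
    and key slot j holds absolute position j.
    """
    q_len, kv_len = attention_mask.shape[-2:]
    q_pos = cache_index + jnp.arange(q_len)[:, None]  # [q_len, 1]
    k_pos = jnp.arange(kv_len)[None, :]               # [1, kv_len]
    # Keep keys that are not in the future and fall inside the window.
    window = (k_pos <= q_pos) & (q_pos - k_pos < sliding_window)
    merged = jnp.logical_and(attention_mask.astype(bool), window[None, None])

    def init_attention_bias():
        return jnp.where(merged, jnp.array(0, dtype), jnp.finfo(dtype).min).astype(dtype)

    return merged, init_attention_bias
```

With a window of 2 and an empty cache, query 3 sees only keys 2 and 3; during decoding with four cached tokens and a window of 3, the single new query sees keys 2, 3, and 4.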

property quantizer

Provides an EasyQuantizer instance based on the module’s configuration.

Used for quantizing KV cache entries if enabled in the config.

Returns

The quantizer instance.

Return type

EasyQuantizer

static repeat_key_value(key, value, num_reps: int)

Repeats key and value tensors for Grouped Query Attention (GQA).

Expands the head dimension by repeating each KV head num_reps times. Uses einops for concise repetition.

Parameters
  • key (Array) – Key tensor [Batch, Seq, NumKVHeads, Dim].

  • value (Array) – Value tensor [Batch, Seq, NumKVHeads, Dim].

  • num_reps (int) – The number of times to repeat each KV head (num_attention_heads / num_kv_heads).

Returns

Repeated key and value tensors, each with shape [Batch, Seq, NumKVHeads * num_reps, Dim].

Return type

tp.Tuple[Array, Array]
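A minimal sketch of the same expansion using jnp.repeat instead of einops. It assumes the conventional GQA grouping in which query head i reads KV head i // num_reps; repeating along the head axis produces exactly that order, which matches the einops pattern "b s h d -> b s (h r) d". The function name is hypothetical.

```python
import jax.numpy as jnp


def repeat_key_value_sketch(key, value, num_reps: int):
    """Expand [batch, seq, kv_heads, dim] K/V to kv_heads * num_reps heads.

    Head order becomes h0, h0, ..., h1, h1, ..., so query head i maps to
    KV head i // num_reps.
    """
    if num_reps == 1:
        return key, value  # plain multi-head attention: nothing to expand
    return (
        jnp.repeat(key, num_reps, axis=2),
        jnp.repeat(value, num_reps, axis=2),
    )
```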

shard_attention_prod(attn_output: Array) Array

Applies sharding constraints to the attention output tensor.

This is typically done before projecting the attention output back to the hidden dimension size.

Parameters

attn_output (jax.Array) – The output from the attention mechanism, usually with shape [Batch, SeqLen, NumHeads * DimPerHead].

Returns

The input tensor with applied sharding constraints based on the config.

Return type

jax.Array

class easydel.layers.attention.FlexibleAttentionModule(*args: Any, **kwargs: Any)

Bases: Module

Manages different attention mechanisms for efficient computation in EasyDeL models.

This class serves as a central hub for handling various attention mechanisms, including optimized implementations like FlashAttention, SplashAttention, RingAttention, and more traditional approaches like vanilla (dot-product) attention. It provides a unified interface to select and execute the appropriate attention mechanism based on the model’s configuration and hardware platform.

Key Features:

  • Attention Mechanism Selection: Supports a wide range of attention mechanisms, allowing users to choose the most suitable option based on performance and hardware constraints.

  • Sharding and Partitioning: Integrates with JAX’s sharding capabilities, enabling efficient distribution of computations and data across multiple devices.

  • Block-wise Computation: Implements block-wise attention computations for optimized memory usage and speed, particularly beneficial for large models.

  • Performance Optimization: Includes support for highly optimized implementations like FlashAttention, SplashAttention, and RingAttention for TPU and GPU acceleration.

  • Flexibility and Customization: Offers fine-grained control over attention parameters, sharding specifications, and block sizes, providing flexibility for different use cases.

  • Testing and Evaluation: Includes a run_attention_benchmarks method to systematically evaluate different attention mechanisms and help users identify the best-performing option.

FlexibleAttentionModule is the EasyDeL component responsible for managing and optimizing attention computations. It provides a single entry point for selecting and running different attention mechanisms, builds on JAX's sharding capabilities, and improves performance through specialized implementations such as FlashAttention and SplashAttention. Its support for block-wise computation and its customization options make it adaptable to a variety of model architectures and hardware configurations.

impl

The chosen attention implementation backend instance.

Type

AttentionBackend

deterministic

If True, dropout is disabled; if False, dropout is applied. Currently hardcoded to True.

Type

bool

metadata

Metadata derived from the configuration, used by the backend.

Type

AttentionMetadata

forward(query_states: Union[Array, ndarray, bool, number], key_states: Union[Array, ndarray, bool, number], value_states: Union[Array, ndarray, bool, number], mode: Optional[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']], bias: Optional[Union[Array, ndarray, bool, number]] = None, sliding_window: Optional[int] = None, cache_metadata: Optional[Union[TransformerMetadata, PagedAttentionMetadata]] = None, cache_view: Optional[Union[TransformerCacheView, PagedAttentionCacheView]] = None, init_bias: Optional[Callable[[], Union[Array, ndarray, bool, number]]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, segment_ids: Optional[Union[Array, ndarray, bool, number]] = None, causal: bool = True, dropout_rng: Optional[PRNGKey] = None) AttentionOutput

Performs the attention computation using the selected backend implementation.

Parameters
  • query_states (Array) – Query tensor.

  • key_states (Array) – Key tensor.

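  • value_states (Array) – Value tensor.

  • mode (tp.Optional[tp.Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']]) – Execution mode of the call; one of the listed literals.

  • bias (tp.Optional[Array], optional) – Optional attention bias. Defaults to None.

  • sliding_window (tp.Optional[int], optional) – Size of the sliding attention window. If None, not applied. Defaults to None.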

  • init_bias (tp.Optional[tp.Callable[[], Array]], optional) – Optional function to initialize bias. Defaults to None.

  • attention_mask (tp.Optional[Array], optional) – Mask to prevent attention to certain positions. Defaults to None.

  • segment_ids (tp.Optional[Array], optional) – Segment IDs for segment-based attention (RingAttention). Defaults to None.

  • causal (bool, optional) – If True, applies a causal mask. Defaults to True.

  • dropout_rng (tp.Optional[random.PRNGKey], optional) – PRNG key for dropout. Defaults to None.

Returns

An object containing the attention output tensor and, depending on the backend, the attention weights.

Return type

AttentionOutput
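forward delegates to whichever backend is configured. For reference, the sketch below shows what the simplest one, vanilla dot-product attention, computes on the [Batch, SeqLen, Heads, Dim] layout used by concatenate. It is a hypothetical stand-in, not EasyDeL's VANILLA backend: it assumes K/V were already expanded to the query head count, omits dropout and segment IDs, and aligns the causal mask to the bottom-right so it also covers decoding (q_len < kv_len).

```python
import jax
import jax.numpy as jnp


def vanilla_attention_sketch(query, key, value, bias=None, causal=True):
    """softmax(Q K^T / sqrt(d) + bias) V on [batch, seq, heads, dim] inputs.

    Returns output [batch, q_len, heads, dim] and weights [batch, heads, q_len, kv_len].
    """
    q_len, kv_len, head_dim = query.shape[1], key.shape[1], query.shape[-1]
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key) * (head_dim ** -0.5)
    if bias is not None:
        scores = scores + bias  # additive bias, e.g. from init_bias()
    if causal:
        # Align the last query with the last key so decoding steps also work.
        mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool), k=kv_len - q_len)
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    output = jnp.einsum("bhqk,bkhd->bqhd", weights, value)
    return output, weights
```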

easydel.layers.attention.SC

Type variable for configuration objects.

alias of TypeVar('SC')

easydel.layers.attention.get_optimal_config() Tuple[AttentionMechanisms, dtype]

Determines the recommended attention mechanism and dtype for the current JAX backend.

Returns

A tuple containing the recommended AttentionMechanisms enum member and the recommended jnp.dtype.

Return type

tp.Tuple[AttentionMechanisms, jnp.dtype]

easydel.layers.attention.tpu_version_check(version: str = 'v4')

Checks if the local JAX device matches the specified TPU version.

Parameters

version (str, optional) – The TPU version string to check against (e.g., “v4”). Defaults to “v4”.

Returns

True if the device kind of the first local device contains the specified version string (case-insensitive), False otherwise.

Return type

bool
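The check described above reduces to a case-insensitive substring test on device_kind. A minimal sketch that mirrors the description (the function name is hypothetical):

```python
import jax


def tpu_version_check_sketch(version: str = "v4") -> bool:
    """Case-insensitive substring check against the first local device's kind."""
    device_kind = jax.local_devices()[0].device_kind  # e.g. "cpu" on a CPU-only host
    return version.lower() in device_kind.lower()
```

Because this is a substring test, a short version string can match more than one device kind, and only the first local device is inspected.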