easydel.layers.attention

class easydel.layers.attention.AttentionMechanisms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

AUTO = 'auto'
BLOCKWISE = 'blockwise'
CUDA_FLASH_ATTN2 = 'cuda_flash_attn2'
CUDNN = 'cudnn'
FLASH_ATTN2 = 'flash_attn2'
RING = 'ring'
SDPA = 'sdpa'
SPLASH = 'splash'
VANILLA = 'vanilla'
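
Because the class derives from both str and Enum, members can be looked up by their string value and compared directly with plain strings. The snippet below is a minimal sketch that uses a local stand-in mirroring the members listed above, so it runs without EasyDeL installed; `parse_mechanism` is a hypothetical helper, not part of the library.

```python
from enum import Enum


class AttentionMechanisms(str, Enum):
    # Local stand-in that mirrors the documented members; not imported from easydel.
    AUTO = "auto"
    BLOCKWISE = "blockwise"
    CUDA_FLASH_ATTN2 = "cuda_flash_attn2"
    CUDNN = "cudnn"
    FLASH_ATTN2 = "flash_attn2"
    RING = "ring"
    SDPA = "sdpa"
    SPLASH = "splash"
    VANILLA = "vanilla"


def parse_mechanism(name):
    """Hypothetical helper: accept an enum member or its string value."""
    # Enum lookup by value raises ValueError for unknown names.
    return AttentionMechanisms(name)


mechanism = parse_mechanism("splash")
same_member = parse_mechanism(AttentionMechanisms.SPLASH)
```

Value lookup means a configuration file can store the mechanism as an ordinary string and still be validated when it is parsed.
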
class easydel.layers.attention.FlaxAttentionModule(*args: Any, **kwargs: Any)

Bases: Module

concatenate(*, query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], cache_view: Optional[TransformerCacheView] = None, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, fcm_mask: Optional[Union[Array, ndarray, bool, number]] = None, sliding_windows: Optional[int] = None) -> Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Callable[[], Union[Array, ndarray, bool, number]]]
property default_key_value_sharding
get_sharding_safely(tensor: Array) -> PartitionSpec
property quantizer
static repeat_key_value(key, value, num_reps: int)
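
The page gives no description for repeat_key_value. A common use of a helper with this signature is grouped-query attention, where each key/value head is repeated so the head count matches the query heads. The sketch below illustrates that idea only; the `(batch, seq_len, num_kv_heads, head_dim)` layout and the name `repeat_kv_sketch` are assumptions, and the library's own implementation may differ.

```python
import jax.numpy as jnp


def repeat_kv_sketch(key, value, num_reps: int):
    """Repeat each key/value head `num_reps` times along the head axis.

    Assumed layout (not confirmed by the documentation):
    (batch, seq_len, num_kv_heads, head_dim).
    """
    key = jnp.repeat(key, num_reps, axis=2)
    value = jnp.repeat(value, num_reps, axis=2)
    return key, value


# 4 key/value heads repeated 3 times to match 12 query heads.
key = jnp.ones((2, 5, 4, 8))
value = jnp.zeros((2, 5, 4, 8))
rep_key, rep_value = repeat_kv_sketch(key, value, num_reps=3)
```
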
shard_attention_prod(attn_output: Array) -> Array

Shards the attention output before passing it to output_proj.

Parameters

attn_output (jax.Array) – Merged output of dot-product attention with three dimensions: (batch, seqlen, hidden_size).

Returns

Sharded version of attn_output.

Return type

jax.Array
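
As an illustration of what sharding a `(batch, seqlen, hidden_size)` tensor can look like in plain JAX, the sketch below applies a sharding constraint inside a jitted function. The mesh axis names `"dp"` and `"tp"`, the partition spec, and the function name are assumptions made for the example; the spec that shard_attention_prod actually applies is not stated on this page.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 2D mesh from all visible devices: (n_devices, 1).
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("dp", "tp"))  # hypothetical axis names

# Assumed spec: batch split over "dp", seqlen replicated, hidden split over "tp".
sharding = NamedSharding(mesh, PartitionSpec("dp", None, "tp"))


@jax.jit
def shard_attention_output(attn_output):
    # Ask the compiler to lay the tensor out according to `sharding`.
    return jax.lax.with_sharding_constraint(attn_output, sharding)


# The batch size is chosen to be divisible by the size of the "dp" axis.
batch = 2 * mesh.shape["dp"]
attn_output = jnp.ones((batch, 16, 64))
sharded = shard_attention_output(attn_output)
```

The constraint changes only how the array is laid out across devices; its shape and values are unchanged.
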

class easydel.layers.attention.FlexibleAttentionModule(*args: Any, **kwargs: Any)

Bases: Module

Manages different attention mechanisms for efficient computation in EasyDeL models.

This class serves as a central hub for handling various attention mechanisms, including optimized implementations like FlashAttention, SplashAttention, RingAttention, and more traditional approaches like vanilla (dot-product) attention. It provides a unified interface to select and execute the appropriate attention mechanism based on the model’s configuration and hardware platform.

Key Features:

  • Attention Mechanism Selection: Supports a wide range of attention mechanisms, allowing users to choose the most suitable option based on performance and hardware constraints.

  • Sharding and Partitioning: Integrates with JAX’s sharding capabilities, enabling efficient distribution of computations and data across multiple devices.

  • Block-wise Computation: Implements block-wise attention computations for optimized memory usage and speed, particularly beneficial for large models.

  • Performance Optimization: Includes support for highly optimized implementations like FlashAttention, SplashAttention, and RingAttention for TPU and GPU acceleration.

  • Flexibility and Customization: Offers fine-grained control over attention parameters, sharding specifications, and block sizes, providing flexibility for different use cases.

  • Testing and Evaluation: Includes a run_attention_benchmarks method to systematically evaluate different attention mechanisms and help users identify the best-performing option.

The FlexibleAttentionModule class manages and optimizes attention computations within EasyDeL. It provides a single interface for selecting and executing different attention mechanisms, uses JAX’s sharding capabilities, and gains performance from specialized implementations such as FlashAttention and SplashAttention. Its support for block-wise computation and its customization options make it adaptable to a variety of model architectures and hardware configurations.
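
To make the block-wise idea concrete, the following is a minimal single-head sketch of attention computed one key/value block at a time with a running (online) softmax, compared against the full computation. It shows the general technique only and is not EasyDeL's implementation; the function names, shapes, and block size are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp


def full_attention(q, k, v):
    # Reference: materializes the whole (q_len, kv_len) score matrix.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def blockwise_attention(q, k, v, block_size):
    # Processes keys/values in blocks, keeping a running max and running sum
    # so only a (q_len, block_size) score tile exists at any time.
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    running_max = jnp.full((q.shape[0], 1), -jnp.inf)
    running_sum = jnp.zeros((q.shape[0], 1))
    acc = jnp.zeros((q.shape[0], v.shape[-1]))
    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale
        new_max = jnp.maximum(running_max, scores.max(axis=-1, keepdims=True))
        # Rescale previously accumulated terms to the new maximum.
        correction = jnp.exp(running_max - new_max)
        probs = jnp.exp(scores - new_max)
        running_sum = running_sum * correction + probs.sum(axis=-1, keepdims=True)
        acc = acc * correction + probs @ v_blk
        running_max = new_max
    return acc / running_sum


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 16))
k = jax.random.normal(kk, (10, 16))
v = jax.random.normal(kv, (10, 16))

out_full = full_attention(q, k, v)
out_block = blockwise_attention(q, k, v, block_size=4)  # last block is partial
```

Both paths produce the same result up to floating-point error; the block-wise path trades a Python-level loop for lower peak memory, which is the motivation behind the block-size parameters mentioned above.
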

forward(query_states: Union[Array, ndarray, bool, number], key_states: Union[Array, ndarray, bool, number], value_states: Union[Array, ndarray, bool, number], bias: Optional[Union[Array, ndarray, bool, number]] = None, init_bias: Optional[Callable[[], Union[Array, ndarray, bool, number]]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, segment_ids: Optional[Union[Array, ndarray, bool, number]] = None, causal: bool = True, dropout_rng: Optional[PRNGKey] = None) -> AttentionOutput
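
For orientation, here is a minimal sketch of what a vanilla (dot-product) path could do with a subset of these arguments: query_states, key_states, value_states, bias, attention_mask, and causal. It is not the library's forward method. The `(batch, seq_len, num_heads, head_dim)` layout, the `(batch, kv_len)` boolean mask convention, and the function name are assumptions, and the sketch returns a plain array rather than an AttentionOutput.

```python
import jax
import jax.numpy as jnp


def vanilla_attention_sketch(query_states, key_states, value_states,
                             bias=None, attention_mask=None, causal=True):
    # Assumed layout: (batch, seq_len, num_heads, head_dim).
    q_len, kv_len = query_states.shape[1], key_states.shape[1]
    scale = query_states.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", query_states, key_states) * scale
    if bias is not None:
        scores = scores + bias  # additive bias, broadcast to (b, h, q, k)

    allowed = jnp.ones((1, 1, q_len, kv_len), dtype=bool)
    if causal:
        # Each query position may attend only to itself and earlier positions.
        allowed = allowed & jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))[None, None]
    if attention_mask is not None:
        # Assumed convention: (batch, kv_len), True means "may attend".
        allowed = allowed & attention_mask[:, None, None, :].astype(bool)

    scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value_states)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query_states = jax.random.normal(kq, (2, 6, 4, 8))
key_states = jax.random.normal(kk, (2, 6, 4, 8))
value_states = jax.random.normal(kv, (2, 6, 4, 8))
# Treat the final position of every sequence as padding.
attention_mask = jnp.ones((2, 6), dtype=bool).at[:, -1].set(False)

out = vanilla_attention_sketch(query_states, key_states, value_states,
                               attention_mask=attention_mask, causal=True)
```

With causal=True the first query position can see only the first key, so its output equals the first value vector, which makes a convenient sanity check.
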
easydel.layers.attention.get_optimal_config() -> Tuple[AttentionMechanisms, dtype]

Returns the optimal attention mechanism and dtype for the current JAX device.

Returns

A tuple of (attention_mechanism, dtype).

easydel.layers.attention.tpu_version_check(version: str = 'v4')