easydel.layers.attention
- class easydel.layers.attention.AttentionMechanisms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
  Bases: str, Enum
  - AUTO = 'auto'
  - BLOCKWISE = 'blockwise'
  - CUDA_FLASH_ATTN2 = 'cuda_flash_attn2'
  - CUDNN = 'cudnn'
  - FLASH_ATTN2 = 'flash_attn2'
  - RING = 'ring'
  - SDPA = 'sdpa'
  - SPLASH = 'splash'
  - VANILLA = 'vanilla'
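Because `AttentionMechanisms` subclasses both `str` and `Enum`, its members compare equal to their string values and can be looked up from a plain config string. The sketch below uses a local copy of the members listed above so it runs without EasyDeL installed (in real code you would import the class from `easydel.layers.attention`); the `parse_mechanism` helper is hypothetical and not part of the library.

```python
from enum import Enum


# Local stand-in mirroring the documented members (illustration only;
# import the real class from easydel.layers.attention in practice).
class AttentionMechanisms(str, Enum):
    AUTO = "auto"
    BLOCKWISE = "blockwise"
    CUDA_FLASH_ATTN2 = "cuda_flash_attn2"
    CUDNN = "cudnn"
    FLASH_ATTN2 = "flash_attn2"
    RING = "ring"
    SDPA = "sdpa"
    SPLASH = "splash"
    VANILLA = "vanilla"


def parse_mechanism(name: str) -> AttentionMechanisms:
    """Hypothetical helper: convert a user-supplied string to an enum member."""
    try:
        # Enum lookup by value; lower() makes the match case-insensitive.
        return AttentionMechanisms(name.lower())
    except ValueError:
        valid = ", ".join(m.value for m in AttentionMechanisms)
        raise ValueError(
            f"unknown attention mechanism {name!r}; expected one of: {valid}"
        ) from None


mech = parse_mechanism("Flash_Attn2")
print(mech is AttentionMechanisms.FLASH_ATTN2)  # True
print(mech == "flash_attn2")  # True: the str mixin compares by value
```

The `str` mixin is what lets a member be passed wherever a plain string such as `"splash"` is expected, and vice versa.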
- class easydel.layers.attention.FlaxAttentionModule(*args: Any, **kwargs: Any)
  Bases: Module
  - static build_cache_pos(attention_mask: Array, cache_view: TransformerCacheView = None) -> Array
  - concatenate(*, query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], cache_view: Optional[TransformerCacheView] = None, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None, fcm_mask: Optional[Union[Array, ndarray, bool, number]] = None, sliding_windows: Optional[int] = None) -> Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Callable[[], Union[Array, ndarray, bool, number]]]
  - property default_key_value_sharding
  - get_sharding_safely(tensor: Array) -> PartitionSpec
  - make_flexible_sliding_window(attention_mask: Array, cache_view: TransformerCacheView, sliding_window: int)
  - property quantizer
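The `sliding_windows` and `sliding_window` parameters above are not documented on this page. Assuming they follow the common sliding-window attention convention (query position q may attend to key position k only when k <= q and q - k < window), a boolean mask can be built as sketched below. This is a generic NumPy illustration of that technique, not EasyDeL's `make_flexible_sliding_window` (which also takes a cache view), and the function name is hypothetical; `jax.numpy` provides the same calls.

```python
import numpy as np


def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Hypothetical helper: (seq_len, seq_len) boolean mask, True = may attend.

    Query q sees key k when k <= q (causal) and q - k < window
    (at most `window` most recent positions, including q itself).
    """
    q = np.arange(seq_len)[:, None]  # column vector of query positions
    k = np.arange(seq_len)[None, :]  # row vector of key positions
    return (k <= q) & (q - k < window)


mask = sliding_window_causal_mask(5, window=2)
print(mask.astype(int))  # each row keeps at most two adjacent ones on/below the diagonal
```

With a window at least as long as the sequence, the mask reduces to the ordinary lower-triangular causal mask.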
- class easydel.layers.attention.FlexibleAttentionModule(*args: Any, **kwargs: Any)
  Bases: Module
  Manages the different attention mechanisms available to EasyDeL models.
  This class is the central hub for attention computation. It covers optimized implementations such as FlashAttention, SplashAttention and RingAttention as well as plain vanilla (dot-product) attention, and exposes a single interface that selects and runs the appropriate mechanism for the model's configuration and hardware platform.
  Key Features:
  - Attention mechanism selection: supports a wide range of mechanisms, so users can choose the best option for their performance and hardware constraints.
  - Sharding and partitioning: integrates with JAX sharding to distribute computation and data across multiple devices.
  - Block-wise computation: computes attention block by block to reduce memory usage and improve speed, which matters most for large models.
  - Performance optimization: supports highly optimized implementations such as FlashAttention, SplashAttention and RingAttention for TPU and GPU acceleration.
  - Flexibility and customization: offers fine-grained control over attention parameters, sharding specifications and block sizes.
  - Testing and evaluation: includes a run_attention_benchmarks method that systematically compares mechanisms to help users identify the best-performing one.
  In short, FlexibleAttentionModule lets users choose among attention implementations through one interface while relying on JAX sharding, block-wise computation and specialized kernels such as FlashAttention and SplashAttention, which makes it adaptable to a variety of model architectures and hardware configurations.
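Block-wise attention is commonly implemented with an online (streaming) softmax: keys and values are processed one block at a time while a running maximum, normalizer and weighted sum are kept per query, so the full score matrix is never materialized. The NumPy sketch below illustrates that general technique for a single head without masking. It is not EasyDeL's blockwise kernel; the function name, the `block_size` parameter and the 2-D `(seq_len, head_dim)` shapes are assumptions for illustration.

```python
import numpy as np


def blockwise_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, block_size: int) -> np.ndarray:
    """Hypothetical single-head attention over key/value blocks (online softmax).

    q: (q_len, d), k: (kv_len, d), v: (kv_len, d_v). No masking is applied.
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    q_len = q.shape[0]
    running_max = np.full((q_len, 1), -np.inf)  # largest score seen so far per query
    running_sum = np.zeros((q_len, 1))          # softmax denominator accumulated so far
    acc = np.zeros((q_len, v.shape[-1]))        # unnormalized weighted sum of values

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale  # (q_len, block)

        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)  # rescales earlier partial results
        p = np.exp(scores - new_max)

        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_blk
        running_max = new_max

    return acc / running_sum


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((10, 8))
v = rng.standard_normal((10, 8))
out = blockwise_attention(q, k, v, block_size=3)  # shape (4, 8)
```

The result matches ordinary softmax attention for any block size; only peak memory changes, since each step holds a `(q_len, block_size)` score tile instead of the full `(q_len, kv_len)` matrix.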
  - forward(query_states: Union[Array, ndarray, bool, number], key_states: Union[Array, ndarray, bool, number], value_states: Union[Array, ndarray, bool, number], bias: Optional[Union[Array, ndarray, bool, number]] = None, init_bias: Optional[Callable[[], Union[Array, ndarray, bool, number]]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, segment_ids: Optional[Union[Array, ndarray, bool, number]] = None, causal: bool = True, dropout_rng: Optional[PRNGKey] = None) -> AttentionOutput
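For orientation, the sketch below shows plain scaled dot-product attention of the kind the vanilla option refers to, taking the same kinds of inputs as `forward`: query, key and value states, an optional additive bias, an optional boolean attention mask and a `causal` flag. It is a NumPy reference written for this page, not EasyDeL's implementation. The function name `vanilla_attention`, the `(batch, seq_len, num_heads, head_dim)` layout and the mask shape broadcastable to `(batch, heads, q_len, kv_len)` are assumptions; `init_bias`, `segment_ids` and dropout are omitted.

```python
import numpy as np


def vanilla_attention(query, key, value, bias=None, attention_mask=None, causal=True):
    """Hypothetical reference: scaled dot-product attention.

    query: (batch, q_len, heads, head_dim); key/value: (batch, kv_len, heads, head_dim)
    bias: optional additive term broadcastable to (batch, heads, q_len, kv_len)
    attention_mask: optional boolean array (True = attend), same broadcast shape
    """
    head_dim = query.shape[-1]
    scores = np.einsum("bqhd,bkhd->bhqk", query, key) / np.sqrt(head_dim)
    if bias is not None:
        scores = scores + bias

    q_len, kv_len = query.shape[1], key.shape[1]
    allowed = np.ones((q_len, kv_len), dtype=bool)
    if causal:
        # Align the causal diagonal to the last key so a cached prefix stays visible.
        allowed = np.tril(allowed, k=kv_len - q_len)
    if attention_mask is not None:
        allowed = allowed & attention_mask
    # Masked positions get the most negative finite value, so they get ~0 weight.
    scores = np.where(allowed, scores, np.finfo(scores.dtype).min)

    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, value)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 5, 4, 8))
k = rng.standard_normal((2, 5, 4, 8))
v = rng.standard_normal((2, 5, 4, 8))
out = vanilla_attention(q, k, v)  # causal by default; shape (2, 5, 4, 8)
```

Using a large finite negative value instead of `-inf` keeps a fully masked row from turning into NaNs; such a row simply gets uniform weights.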
- easydel.layers.attention.get_optimal_config() -> Tuple[AttentionMechanisms, dtype]
  Returns the optimal attention mechanism and dtype for the current JAX device.
  Returns: a tuple of (attention_mechanism, dtype).