easydel.layers.operations.modules.scaled_dot_product_attention
Scaled Dot-Product Attention implementation using JAX’s optimized primitives.
This module provides an attention implementation that leverages JAX’s jax.nn.dot_product_attention API, which automatically dispatches to the most efficient implementation available on the current hardware.
Key features:
- Automatic backend selection (XLA, cuDNN, Flash Attention)
- Support for multiple hardware backends (TPU, GPU, CPU)
- Efficient handling through JAX’s SDPA primitive
- Automatic optimization based on hardware capabilities
- Compatible with various attention patterns (causal, masked, biased)
Implementation details:
- On CUDA GPUs: Uses cuDNN’s optimized attention kernels
- On TPUs/CPUs: Uses XLA’s optimized implementations
- Automatically selects Flash Attention when available
- Handles sharding for distributed computation
The implementation is registered under multiple names:
- “sdpa”: Scaled Dot-Product Attention (generic name)
- “cudnn”: Specifically for the CUDA/cuDNN backend
- “cuda_flash_attn2”: For Flash Attention v2 on CUDA
Example
>>> import math
>>> import jax.numpy as jnp
>>> from easydel.layers.operations import OperationMetadata
>>> from easydel.layers.operations.modules.scaled_dot_product_attention import ScaledDotProductAttn
>>>
>>> # query: (batch, seq_len, num_q_heads, head_dim); key/value: (batch, kv_len, num_kv_heads, head_dim)
>>> head_dim = query.shape[-1]
>>>
>>> # Configure for efficient SDPA
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.float16,
...     softmax_scale=1.0 / math.sqrt(head_dim),
... )
>>> sdpa_attn = ScaledDotProductAttn(metadata)
>>>
>>> # Automatically uses the best available implementation; mask_info is an optional MaskInfo (or None)
>>> output = sdpa_attn(query, key, value, mask_info=mask_info, causal=True)
Note
When the implementation is left to auto-detection, JAX selects the best available backend based on:
- Hardware availability (GPU with cuDNN, TPU, CPU)
- Tensor shapes and dtypes
- JAX version and installed libraries
- Specific operation parameters (causal, mask type)
References
JAX documentation on dot_product_attention
NVIDIA cuDNN documentation
Flash Attention papers and implementations
- class easydel.layers.operations.modules.scaled_dot_product_attention.ScaledDotProductAttn(metadata: OperationMetadata)
Bases: OperationImpl
An attention implementation that leverages jax.nn.dot_product_attention.
This class uses JAX’s optimized SDPA primitive, which can dispatch to different backend implementations, such as XLA or cuDNN (whose fused attention kernel follows the Flash Attention approach), depending on the JAX version and hardware.
It handles sharding using shard_map and manages backend-specific dispatch (primarily distinguishing between CUDA/GPU and other backends like TPU/CPU).
Registered under the names “sdpa”, “cudnn”, and “cuda_flash_attn2”.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native (XLA implementation).
- Parameters
*args – Positional arguments for attention calculation.
**kwargs – Keyword arguments for attention calculation.
- Returns
Result from XLA-optimized implementation.
- Return type
AttentionOutput
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Delegates to forward_native (AUTO-DETECT implementation).
- Parameters
*args – Positional arguments for attention calculation.
**kwargs – Keyword arguments for attention calculation.
- Returns
Result from the implementation selected by JAX (for example, cuDNN or XLA).
- Return type
AttentionOutput
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to the CUDA-specific implementation.
- Parameters
*args – Positional arguments for attention calculation.
**kwargs – Keyword arguments for attention calculation.
- Returns
Result from CUDA-optimized implementation.
- Return type
AttentionOutput
- forward_native(query: Float[Array, 'batch seq_len num_q_heads head_dim'], key: Float[Array, 'batch kv_len num_kv_heads head_dim'], value: Float[Array, 'batch kv_len num_kv_heads head_dim'], mask_info: ejkernel.types.mask.MaskInfo | None = None, bias: jaxtyping.Float[Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: Optional[Callable[[], Float[Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, cum_seqlens_q: jaxtyping.Int[Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[Array, 'batch'] | None = None, **ignore) → AttentionOutput
Computes attention using jax.nn.dot_product_attention with the “xla” implementation.
This is typically used for CPU and TPU backends. It applies sharding via shard_map.
- Parameters
query – Query tensor (B, T, H, D).
key – Key tensor (B, S, H_kv, D).
value – Value tensor (B, S, H_kv, D_v).
mask_info – Optional ejkernel MaskInfo describing the attention mask (the resulting mask is broadcastable to B, 1, T, S).
bias – Optional attention bias tensor (broadcastable to B, H, T, S). Passed directly to the primitive. If bias is provided, causal is forced to False.
init_bias – Optional callable used to create the bias when mask_info and bias are both None.
softmax_scale – Optional scale applied to the query-key logits before the softmax.
causal – If True and bias is None, applies causal masking within the primitive.
sliding_window – Optional local-attention window, given as a single int or a (left, right) tuple.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object. jax.nn.dot_product_attention does not return attention weights, so none are included in the output.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Currently delegates to forward_native.
Future versions may include ROCm-specific optimizations.
- Parameters
*args – Positional arguments for attention calculation.
**kwargs – Keyword arguments for attention calculation.
- Returns
Result from XLA implementation.
- Return type
AttentionOutput
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass. Delegates to forward_native (XLA implementation).
- Parameters
*args – Positional arguments for attention calculation.
**kwargs – Keyword arguments for attention calculation.
- Returns
Result from XLA-optimized implementation.
- Return type
AttentionOutput
- get_impl_metadata() → OperationMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The OperationMetadata provided during initialization.