easydel.layers.operations.modules.scaled_dot_product_attention

Scaled Dot-Product Attention implementation using JAX’s optimized primitives.

This module provides an attention implementation that leverages JAX’s jax.nn.dot_product_attention API, which automatically dispatches to the most efficient implementation available on the current hardware.

Key features:

  • Automatic backend selection (XLA, cuDNN, Flash Attention)

  • Support for multiple hardware backends (TPU, GPU, CPU)

  • Efficient handling through JAX’s SDPA primitive

  • Automatic optimization based on hardware capabilities

  • Compatible with various attention patterns (causal, masked, biased)

Implementation details:

  • On CUDA GPUs: Uses cuDNN’s optimized attention kernels

  • On TPUs/CPUs: Uses XLA’s optimized implementations

  • Automatically selects Flash Attention when available

  • Handles sharding for distributed computation

The implementation is registered under multiple names:

  • “sdpa”: Scaled Dot-Product Attention (generic name)

  • “cudnn”: Specifically for CUDA/cuDNN backend

  • “cuda_flash_attn2”: For Flash Attention v2 on CUDA

Example

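>>> import math
>>> import jax.numpy as jnp
>>> from easydel.layers.attention_operator import OperationMetadata
>>> from easydel.layers.operations.modules.scaled_dot_product_attention import ScaledDotProductAttn
>>>
>>> # Configure for efficient SDPA (head_dim is assumed to be defined by the caller)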
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.float16,
...     softmax_scale=1.0 / math.sqrt(head_dim)
... )
>>> sdpa_attn = ScaledDotProductAttn(metadata)
>>>
>>> # Automatically uses best available implementation
>>> output = sdpa_attn(query, key, value, attention_mask=attention_mask, causal=True)

Note

JAX will automatically select the best implementation based on:

  • Hardware availability (GPU with cuDNN, TPU, CPU)

  • Tensor shapes and dtypes

  • JAX version and installed libraries

  • Specific operation parameters (causal, attention_mask type)

References

  • JAX documentation on dot_product_attention

  • NVIDIA cuDNN documentation

  • Flash Attention papers and implementations

class easydel.layers.operations.modules.scaled_dot_product_attention.ScaledDotProductAttn(metadata: OperationMetadata)

Bases: OperationImpl

An attention implementation that leverages jax.nn.dot_product_attention.

This class utilizes JAX’s optimized SDPA primitive, which can dispatch to different backend implementations (like XLA, cuDNN, or potentially Flash Attention emulation on CUDA depending on JAX version and hardware).

It handles sharding using shard_map and manages backend-specific dispatch (primarily distinguishing between CUDA/GPU and other backends like TPU/CPU).

Registered under the names “sdpa”, “cudnn”, and “cuda_flash_attn2”.

forward_cpu(*args, **kwargs) AttentionOutput

CPU forward pass. Delegates to forward_native (XLA implementation).

Parameters
  • *args – Positional arguments for attention calculation.

  • **kwargs – Keyword arguments for attention calculation.

Returns

Result from XLA-optimized implementation.

Return type

AttentionOutput

forward_cuda(*args, **kwargs) AttentionOutput

CUDA forward pass. Delegates to forward_native (AUTO-DETECT implementation).

Parameters
  • *args – Positional arguments for attention calculation.

  • **kwargs – Keyword arguments for attention calculation.

Returns

Result from the automatically selected implementation.

Return type

AttentionOutput

forward_gpu(*args, **kwargs) AttentionOutput

GPU forward pass. Delegates to the CUDA-specific implementation.

Parameters
  • *args – Positional arguments for attention calculation.

  • **kwargs – Keyword arguments for attention calculation.

Returns

Result from CUDA-optimized implementation.

Return type

AttentionOutput

forward_native(query: Float[Array, 'batch seq_len num_q_heads head_dim'], key: Float[Array, 'batch kv_len num_kv_heads head_dim'], value: Float[Array, 'batch kv_len num_kv_heads head_dim'], mask_info: ejkernel.types.mask.MaskInfo | None = None, bias: jaxtyping.Float[Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: Optional[Callable[[], Float[Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, cum_seqlens_q: jaxtyping.Int[Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[Array, 'batch'] | None = None, **ignore) AttentionOutput

Computes attention using jax.nn.dot_product_attention with the “xla” implementation.

This is typically used for CPU and TPU backends. It applies sharding via shard_map.

Parameters
  • query – Query tensor (B, T, H, D).

  • key – Key tensor (B, S, H_kv, D).

  • value – Value tensor (B, S, H_kv, D_v).

  • mask_info – Optional attention mask, supplied as a MaskInfo object (see the signature above); the boolean mask it describes must be broadcastable to (B, 1, T, S).

  • bias – Optional attention bias tensor (broadcastable to B, H, T, S). Passed directly to the primitive. If bias is provided, causal is forced to False.

  • init_bias – Optional callable that builds the bias when both mask_info and bias are None.

  • causal – If True and bias is None, applies causal masking within the primitive.

  • **ignore – Ignored keyword arguments.
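The interaction between bias, init_bias, and causal described in the parameter list can be sketched as a small resolution step. resolve_bias_and_causal is a hypothetical helper written for this illustration. In particular, treating a bias produced by init_bias the same as a user-supplied bias is an assumption, not something this page states.

```python
import jax.numpy as jnp


def resolve_bias_and_causal(bias, init_bias, causal, has_mask):
    """Hypothetical helper mirroring the documented parameter rules."""
    # Build a bias lazily only when neither a mask nor a bias was supplied.
    if bias is None and not has_mask and init_bias is not None:
        bias = init_bias()
    # Assumption: a bias already encodes the masking pattern, so causal is dropped.
    if bias is not None:
        causal = False
    return bias, causal


T, S = 4, 4  # illustrative query and key lengths


def make_causal_bias():
    # Additive causal bias with shape (1, 1, T, S), broadcastable to (B, H, T, S).
    allowed = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
    return jnp.where(allowed, 0.0, -1e9)[None, None, :, :]


bias, causal = resolve_bias_and_causal(None, make_causal_bias, causal=True, has_mask=False)
no_bias, still_causal = resolve_bias_and_causal(None, None, causal=True, has_mask=False)
print(bias.shape, causal, no_bias, still_causal)
```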

Returns

An AttentionOutput object. Note that jax.nn.dot_product_attention typically does not return attention weights.

forward_rocm(*args, **kwargs) AttentionOutput

ROCm GPU forward pass. Currently delegates to forward_native.

Future versions may include ROCm-specific optimizations.

Parameters
  • *args – Positional arguments for attention calculation.

  • **kwargs – Keyword arguments for attention calculation.

Returns

Result from XLA implementation.

Return type

AttentionOutput

forward_tpu(*args, **kwargs) AttentionOutput

TPU forward pass. Delegates to forward_native (XLA implementation).

Parameters
  • *args – Positional arguments for attention calculation.

  • **kwargs – Keyword arguments for attention calculation.

Returns

Result from XLA-optimized implementation.

Return type

AttentionOutput
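The delegation described for the forward_* methods above can be summarized with a toy class that records which methods run. ToyDispatch and its run method are hypothetical and contain no attention logic. Only the delegation order (forward_gpu to forward_cuda to forward_native, with the CPU, TPU, and ROCm paths going straight to forward_native) is taken from this page.

```python
class ToyDispatch:
    """Hypothetical stand-in that records which forward_* methods run."""

    def __init__(self):
        self.calls = []

    def forward_native(self, *args, **kwargs):
        self.calls.append("forward_native")
        return "attention-output"  # placeholder for an AttentionOutput

    def forward_cuda(self, *args, **kwargs):
        self.calls.append("forward_cuda")
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        self.calls.append("forward_gpu")
        return self.forward_cuda(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        self.calls.append("forward_cpu")
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        self.calls.append("forward_tpu")
        return self.forward_native(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        self.calls.append("forward_rocm")
        return self.forward_native(*args, **kwargs)

    def run(self, backend, *args, **kwargs):
        """Look up forward_<backend> by name and call it."""
        self.calls = []
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


toy = ToyDispatch()
toy.run("gpu")
gpu_chain = list(toy.calls)
toy.run("tpu")
tpu_chain = list(toy.calls)
print(gpu_chain, tpu_chain)
```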

get_impl_metadata() OperationMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The OperationMetadata provided during initialization.

classmethod get_impl_name() str | tuple[str]

Returns the registered name(s) for this implementation.

Returns

(“sdpa”, “cudnn”, “cuda_flash_attn2”).

Return type

A tuple of strings