easydel.layers.operations.modules.blocksparse_attention

Splash Attention implementation for TPU acceleration.

This module provides the Splash Attention implementation, a TPU-optimized attention mechanism that leverages the Pallas framework for maximum performance. Splash Attention is specifically designed to take advantage of TPU’s matrix multiplication units and memory hierarchy.

Key features:

  • TPU-specific optimization using Pallas kernels

  • Support for Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)

  • Efficient handling of causal masks

  • Automatic fallback to vanilla attention for unsupported configurations

  • Optimized for sequences with lengths divisible by 128
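MQA and GQA differ only in how many key/value heads the query heads share. The sketch below uses a hypothetical helper, `kv_head_for_query_head` (not part of EasyDeL), to show the usual grouping rule: with `Hq` query heads and `Hkv` key/value heads, each run of `Hq // Hkv` consecutive query heads reads the same KV head, and MQA is the case `Hkv = 1`. It assumes contiguous grouping, which is the common convention but is not confirmed for this kernel.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the KV head a query head reads from (hypothetical helper).

    MQA is num_kv_heads == 1; standard multi-head attention is
    num_kv_heads == num_q_heads; anything in between is GQA.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads sharing one KV head
    return q_head // group_size


# GQA: 8 query heads over 2 KV heads.
print([kv_head_for_query_head(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
# MQA: every query head reads KV head 0.
print([kv_head_for_query_head(h, 8, 1) for h in range(8)])  # [0, 0, 0, 0, 0, 0, 0, 0]
```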

Implementation details:

  • Uses the make_splash_mqa_single_device primitive from jax.experimental

  • Requires specific block sizes for optimal TPU utilization

  • Falls back to vanilla attention for:
    • Single-token generation (seq_len = 1)

    • Non-causal attention patterns

    • Sequences whose length is not divisible by 128
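The fallback conditions listed above can be expressed as a single predicate. This is a minimal sketch with a hypothetical function name; the real implementation may check additional conditions (for example the key/value length) before deciding whether to call the kernel.

```python
def should_fall_back_to_vanilla(q_len: int, causal: bool, multiple: int = 128) -> bool:
    """Mirror the documented fallback rules (hypothetical helper, not EasyDeL's code)."""
    if q_len == 1:  # single-token generation
        return True
    if not causal:  # non-causal patterns are not routed to the kernel
        return True
    if q_len % multiple != 0:  # the kernel needs the query length to tile by 128
        return True
    return False


print(should_fall_back_to_vanilla(1, causal=True))      # True
print(should_fall_back_to_vanilla(2048, causal=False))  # True
print(should_fall_back_to_vanilla(1000, causal=True))   # True
print(should_fall_back_to_vanilla(2048, causal=True))   # False
```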

Example

>>> import math
>>> import jax.numpy as jnp
>>> from easydel.layers.attention_operator import OperationMetadata
>>> from easydel.layers.operations.modules.blocksparse_attention import BlockSparseAttn
>>>
>>> # Configure for TPU execution (head_dim, query, key, value are defined elsewhere)
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.bfloat16,
...     softmax_scale=1.0 / math.sqrt(head_dim),
...     blocksize_q=256,
...     blocksize_k=512,
... )
>>> splash_attn = BlockSparseAttn(metadata)
>>>
>>> # Use with causal=True and sequence lengths divisible by 128
>>> output = splash_attn(query, key, value, causal=True)
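The example requests blocksize_q=256 and blocksize_k=512, while the kernel constraint is stated in terms of multiples of 128. The hypothetical helper below sketches one way to reconcile a requested block size with a sequence length: halve the block until it evenly tiles the sequence, never going below 128. Whether EasyDeL adjusts block sizes like this is an assumption; the sketch only illustrates the divisibility requirement.

```python
from typing import Optional


def fit_block_size(seq_len: int, requested: int, multiple: int = 128) -> Optional[int]:
    """Find a block size that tiles seq_len by repeatedly halving `requested`.

    A valid block is a multiple of `multiple` that divides seq_len exactly.
    Returns None if halving reaches a size below `multiple` without a match.
    (Hypothetical helper for illustration only.)
    """
    block = min(requested, seq_len)
    while block >= multiple:
        if block % multiple == 0 and seq_len % block == 0:
            return block
        block //= 2  # try a smaller tile
    return None


print(fit_block_size(2048, 512))  # 512
print(fit_block_size(640, 256))   # 128
print(fit_block_size(1000, 512))  # None
```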

Note

Splash Attention is only available on TPU devices and will raise NotImplementedError on CPU or GPU backends.


class easydel.layers.operations.modules.blocksparse_attention.BlockSparseAttn(metadata: OperationMetadata)

Bases: OperationImpl

An attention implementation using the Pallas Splash Attention kernel for TPUs.

Splash Attention is an optimized attention mechanism designed for TPUs. This implementation provides a wrapper around the make_splash_mqa_single_device primitive.

Note

  • This implementation is primarily intended for TPUs.

  • It falls back to VanillaAttn under certain conditions:
    • Query sequence length is 1 (generation mode).

    • causal is False.

    • Query sequence length is not divisible by 128 (kernel constraint).

  • Non-TPU forward methods (forward_cpu, forward_gpu, forward_cuda, forward_rocm) are not implemented and will raise NotImplementedError.

Registered under the name “blocksparse”.

forward_cpu(*args, **kwargs) AttentionOutput

CPU forward pass. Not implemented for Splash Attention.

Splash Attention is TPU-specific and has no CPU implementation.

Parameters
  • *args – Ignored arguments.

  • **kwargs – Ignored keyword arguments.

Raises

NotImplementedError – Always raised as CPU execution is not supported.

forward_cuda(*args, **kwargs) AttentionOutput

GPU forward pass. Not implemented for Splash Attention.

Splash Attention is TPU-specific and has no GPU implementation.

Parameters
  • *args – Ignored arguments.

  • **kwargs – Ignored keyword arguments.

Raises

NotImplementedError – Always raised as GPU execution is not supported.

forward_gpu(*args, **kwargs) AttentionOutput

GPU forward pass. Not implemented for Splash Attention.

Splash Attention is TPU-specific and has no GPU implementation.

Parameters
  • *args – Ignored arguments.

  • **kwargs – Ignored keyword arguments.

Raises

NotImplementedError – Always raised as GPU execution is not supported.

forward_native(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], softmax_aux: jaxtyping.Float[Array, 'num_kv_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, mask_info: ejkernel.types.mask.MaskInfo | None = None, logits_soft_cap: float | None = None, softmax_scale: float | None = None, sliding_window: int | tuple[int, int] | None = None, causal: bool = True, fused_backward: bool = False, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | None = None, **ignore) AttentionOutput

Performs Splash Attention on TPU/GPU using the Pallas/Triton kernel.

Handles fallback logic, attention mask processing, block size configuration, and sharding via shard_map. Accepts inputs in BTHD layout (batch, seq_len, num_heads, head_dim) and transposes them to BHTD (batch, num_heads, seq_len, head_dim) for the kernel.
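The BTHD to BHTD conversion described above is a swap of the sequence and head axes. The sketch uses NumPy as a stand-in so it runs without a TPU; `jax.numpy.transpose` takes the same `axes` argument. The helper names are hypothetical.

```python
import numpy as np


def bthd_to_bhtd(x: np.ndarray) -> np.ndarray:
    """(batch, seq_len, heads, head_dim) -> (batch, heads, seq_len, head_dim)."""
    return np.transpose(x, (0, 2, 1, 3))


def bhtd_to_bthd(x: np.ndarray) -> np.ndarray:
    """Inverse layout change; the permutation (0, 2, 1, 3) is its own inverse."""
    return np.transpose(x, (0, 2, 1, 3))


q = np.arange(2 * 256 * 8 * 64, dtype=np.float32).reshape(2, 256, 8, 64)
q_kernel = bthd_to_bhtd(q)  # layout handed to the kernel
print(q_kernel.shape)       # (2, 8, 256, 64)
assert np.array_equal(bhtd_to_bthd(q_kernel), q)  # round trip restores the input
```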

Parameters
  • query – Query tensor (B, T, Hq, D).

  • key – Key tensor (B, S, Hkv, D).

  • value – Value tensor (B, S, Hkv, Dv).

  • attention_mask – Optional boolean attention mask, broadcastable to (B, 1, T, S). Used to generate segment IDs if provided.

  • causal – If True, applies causal masking via the kernel’s mask configuration. If False, falls back to VanillaAttn.

  • **ignore – Ignored keyword arguments.
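Splash-style kernels accept segment IDs alongside their structural mask: a query position may attend to a key position only when both carry the same ID, while causal masking is applied separately through the kernel’s mask configuration. The sketch below assumes a 2-D (B, S) padding mask in which True marks real tokens. The method itself accepts a mask broadcastable to (B, 1, T, S), and how it reduces that to segment IDs is not shown here. Function names are hypothetical, and NumPy stands in for jax.numpy.

```python
import numpy as np


def padding_mask_to_segment_ids(padding_mask: np.ndarray) -> np.ndarray:
    """(B, S) bool padding mask (True = real token) -> (B, S) int32 segment IDs."""
    # Real tokens share segment 1; padding tokens share segment 0.
    return padding_mask.astype(np.int32)


def segment_ids_to_mask(q_seg: np.ndarray, kv_seg: np.ndarray) -> np.ndarray:
    """Dense (B, T, S) mask implied by segment IDs: attend iff the IDs match."""
    return q_seg[:, :, None] == kv_seg[:, None, :]


pad = np.array([[True, True, True, False]])
seg = padding_mask_to_segment_ids(pad)
mask = segment_ids_to_mask(seg, seg)
print(seg.tolist())         # [[1, 1, 1, 0]]
print(mask[0].astype(int))  # real tokens never attend to the padded slot
```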

Returns

An AttentionOutput object containing the attention outputs. Attention weights are not computed or returned by Splash Attention.

forward_rocm(*args, **kwargs) AttentionOutput

GPU forward pass. Not implemented for Splash Attention.

Splash Attention is TPU-specific and has no GPU implementation.

Parameters
  • *args – Ignored arguments.

  • **kwargs – Ignored keyword arguments.

Raises

NotImplementedError – Always raised as GPU execution is not supported.

forward_tpu(*args, **kwargs) AttentionOutput

TPU forward pass.

Splash Attention targets TPUs; see forward_native for the parameters, fallback conditions, and return value of the kernel path.

get_impl_metadata() OperationMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The OperationMetadata provided during initialization.

classmethod get_impl_name() str | tuple[str]

Returns the registered name of this attention implementation.

Returns

The string “blocksparse”.