easydel.layers.attention_operator.modules.splash

class easydel.layers.attention_operator.modules.splash.SplashAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

An attention implementation using the Pallas Splash Attention kernel for TPUs.

Splash Attention is a TPU-optimized, block-sparse flash-attention kernel written in Pallas. This class wraps the make_splash_mqa_single_device primitive.

Note

  • This implementation is primarily intended for TPUs.

  • It falls back to VanillaAttn when any of the following holds:
    • Query sequence length is 1 (single-token decoding during generation).

    • causal is False.

    • Query sequence length is not divisible by 128 (kernel constraint).

  • Non-TPU forward methods (forward_native, forward_gpu, etc.) are not implemented and will raise NotImplementedError.
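The documented fallback rules reduce to a simple predicate. The sketch below mirrors only the three conditions listed in the note; the name `uses_vanilla_fallback` and its `block_multiple` parameter are hypothetical and not part of EasyDeL, whose actual check may differ in detail.

```python
def uses_vanilla_fallback(q_len: int, causal: bool, block_multiple: int = 128) -> bool:
    """Hypothetical helper: True if the documented rules route the call to VanillaAttn."""
    if q_len == 1:
        # Single-token decoding step during generation.
        return True
    if not causal:
        # Non-causal attention is not sent to the Splash kernel.
        return True
    if q_len % block_multiple != 0:
        # Kernel constraint: query length must be a multiple of 128.
        return True
    return False


print(uses_vanilla_fallback(q_len=2048, causal=True))  # False: Splash kernel is used
print(uses_vanilla_fallback(q_len=1, causal=True))     # True: decoding falls back
```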

Registered under the name “splash”.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Not implemented for Splash Attention.

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass. Not implemented for Splash Attention.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Not implemented for Splash Attention.

forward_native(*args, **kwargs) → AttentionOutput

Native (CPU) forward pass. Not implemented for Splash Attention.

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Not implemented for Splash Attention.
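Because only forward_tpu is implemented, a caller may want to avoid selecting "splash" when no TPU is available. A minimal sketch, assuming the caller chooses an implementation by its registered name; `pick_attention_name` and the fallback name "vanilla" are hypothetical, and only `jax.default_backend()` is a real JAX API here.

```python
import jax


def pick_attention_name(preferred: str = "splash", fallback: str = "vanilla") -> str:
    """Hypothetical helper: keep `preferred` unless it is "splash" off-TPU."""
    # SplashAttn implements only forward_tpu; other backends raise NotImplementedError.
    if preferred == "splash" and jax.default_backend() != "tpu":
        return fallback
    return preferred


print(jax.default_backend(), pick_attention_name())
```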

forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, causal: bool = True, cache_view: Optional[TransformerCacheView] = None, **ignore) → AttentionOutput

Performs Splash Attention on TPU using the Pallas kernel.

Handles the fallback logic, mask processing, block-size configuration, and sharding (via shard_map). Inputs are expected in BTHD layout (batch, sequence, heads, head dim) and are transposed to BHTD for the kernel.
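The layout change between BTHD and BHTD is a swap of the sequence and head axes. A minimal sketch of that permutation in jax.numpy; the helper names are illustrative and not EasyDeL functions.

```python
import jax.numpy as jnp


def bthd_to_bhtd(x):
    # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
    return jnp.transpose(x, (0, 2, 1, 3))


def bhtd_to_bthd(x):
    # Swapping axes 1 and 2 is its own inverse, so the same permutation restores BTHD.
    return jnp.transpose(x, (0, 2, 1, 3))


q = jnp.zeros((2, 256, 8, 64))  # B=2, T=256, Hq=8, D=64
print(bthd_to_bhtd(q).shape)    # (2, 8, 256, 64)
```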

Parameters
  • q – Query tensor (B, T, Hq, D).

  • k – Key tensor (B, S, Hkv, D).

  • v – Value tensor (B, S, Hkv, Dv).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used to generate segment IDs if provided.

  • causal – If True, applies causal masking via the kernel’s mask configuration. If False, falls back to VanillaAttn.

  • **ignore – Ignored keyword arguments.
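The exact way EasyDeL turns the mask into segment IDs is not shown here. One plausible approach for a padding mask is sketched below, assuming the mask has already been broadcast to (B, 1, T, S): real tokens get ID 1 and padding gets ID 0, so real tokens never attend to padding. The resulting query and key/value ID arrays would then be handed to the kernel's segment-ID container (in JAX's splash_attention_kernel module, a `SegmentIds` tuple with `q` and `kv` fields). `mask_to_segment_ids` is a hypothetical helper.

```python
import jax.numpy as jnp


def mask_to_segment_ids(mask):
    """Hypothetical helper: per-token segment IDs from a boolean (B, 1, T, S) mask."""
    m = mask[:, 0]                                   # (B, T, S)
    q_ids = jnp.any(m, axis=-1).astype(jnp.int32)    # query row can see at least one key
    kv_ids = jnp.any(m, axis=-2).astype(jnp.int32)   # key column is visible to some query
    return q_ids, kv_ids


# Batch of 1, sequence length 4, last position is padding.
valid = jnp.array([[True, True, True, False]])
mask = valid[:, None, :, None] & valid[:, None, None, :]  # (1, 1, 4, 4)
q_ids, kv_ids = mask_to_segment_ids(mask)
print(q_ids, kv_ids)
```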

Returns

An AttentionOutput object containing the attention outputs. Attention weights are not computed or returned by Splash Attention.
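To sanity-check Splash outputs, it can help to compare against a dense computation with the same shapes. The sketch below is a plain-JAX grouped-query attention reference, not the Splash kernel or EasyDeL's VanillaAttn; the 1/sqrt(D) scaling and the top-left-aligned causal mask (T == S) are assumptions, and `reference_attention` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, causal=True):
    """Dense attention: q (B,T,Hq,D), k (B,S,Hkv,D), v (B,S,Hkv,Dv) -> (B,T,Hq,Dv)."""
    t, hq, d = q.shape[1], q.shape[2], q.shape[3]
    s, hkv = k.shape[1], k.shape[2]
    if hq % hkv != 0:
        raise ValueError("number of query heads must be a multiple of KV heads")
    # Grouped-query attention: query head h uses KV head h // (Hq // Hkv).
    k = jnp.repeat(k, hq // hkv, axis=2)
    v = jnp.repeat(v, hq // hkv, axis=2)
    scores = jnp.einsum("bthd,bshd->bhts", q, k) * (1.0 / d ** 0.5)
    if causal:
        # Assumes T == S (prefill): query i may attend to keys 0..i.
        keep = jnp.tril(jnp.ones((t, s), dtype=bool))
        scores = jnp.where(keep, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    # Only the outputs are returned, matching Splash, which does not expose weights.
    return jnp.einsum("bhts,bshd->bthd", weights, v)
```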

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name of this attention implementation.

Returns

The string “splash”.