easydel.layers.attention_operator.modules.splash
- class easydel.layers.attention_operator.modules.splash.SplashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An attention implementation using the Pallas Splash Attention kernel for TPUs.
Splash Attention is an optimized attention mechanism designed for TPUs. This implementation provides a wrapper around the make_splash_mqa_single_device primitive.
Note
- This implementation is primarily intended for TPUs.
- It falls back to VanillaAttn when any of the following holds:
  - Query sequence length is 1 (generation mode).
  - causal is False.
  - Query sequence length is not divisible by 128 (kernel constraint).
- Non-TPU forward methods (forward_native, forward_gpu, etc.) are not implemented and raise NotImplementedError.
Registered under the name “splash”.
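The fallback conditions in the note can be read as a single predicate. The following is a minimal sketch of that reading, not the library's code: the function name `should_fall_back` and the constant `SPLASH_BLOCK` are hypothetical, and the sketch assumes the conditions combine with a logical OR.

```python
# Hypothetical constant: the divisibility constraint quoted in the note.
SPLASH_BLOCK = 128


def should_fall_back(q_seq_len: int, causal: bool) -> bool:
    """Return True if any documented fallback condition applies.

    Illustrative helper only; it is not part of easydel.
    """
    if q_seq_len == 1:
        # Single-token query (generation mode).
        return True
    if not causal:
        # Non-causal attention is routed to the fallback.
        return True
    if q_seq_len % SPLASH_BLOCK != 0:
        # Query length must be a multiple of 128 for the kernel.
        return True
    return False


# A 256-token causal query satisfies every constraint.
prefill_uses_fallback = should_fall_back(256, causal=True)
# A single-token decode step does not.
decode_uses_fallback = should_fall_back(1, causal=True)
```

Under this reading, a prefill step with a padded length such as 256 would reach the Splash kernel, while every single-token decode step would use VanillaAttn.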
- forward_cpu(*args, **kwargs) -> AttentionOutput
CPU forward pass. Not implemented for Splash Attention.
- forward_cuda(*args, **kwargs) -> AttentionOutput
CUDA GPU forward pass. Not implemented for Splash Attention.
- forward_gpu(*args, **kwargs) -> AttentionOutput
GPU forward pass. Not implemented for Splash Attention.
- forward_native(*args, **kwargs) -> AttentionOutput
Native (CPU) forward pass. Not implemented for Splash Attention.
- forward_rocm(*args, **kwargs) -> AttentionOutput
ROCm GPU forward pass. Not implemented for Splash Attention.
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, causal: bool = True, **ignore) -> AttentionOutput
Performs Splash Attention on TPU using the Pallas kernel.
Handles the fallback logic, mask processing, block-size configuration, and sharding via shard_map. Inputs are expected in BTHD layout (batch, sequence, heads, head dimension) and are transposed to BHTD for the kernel.
- Parameters
q – Query tensor (B, T, Hq, D).
k – Key tensor (B, S, Hkv, D).
v – Value tensor (B, S, Hkv, Dv).
mask – Optional boolean attention mask, broadcastable to (B, 1, T, S). If provided, it is used to generate segment IDs.
causal – If True, applies causal masking via the kernel’s mask configuration. If False, falls back to VanillaAttn.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are not computed or returned by Splash Attention.
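The BTHD-to-BHTD transposition described above swaps the sequence and head axes. The sketch below shows only the shape bookkeeping, in plain Python. The helper `permute_shape` and the example sizes are hypothetical. The axis order `(0, 2, 1, 3)` is the permutation one would pass to `jnp.transpose` to do this on real arrays, and it is assumed here rather than taken from the implementation.

```python
# Axis permutation that maps (B, T, H, D) to (B, H, T, D).
BTHD_TO_BHTD = (0, 2, 1, 3)


def permute_shape(shape, perm=BTHD_TO_BHTD):
    """Return the shape an array would have after transposing its axes by perm.

    Hypothetical helper for illustration; it works on shape tuples, not arrays.
    """
    if len(shape) != len(perm):
        raise ValueError(f"expected rank {len(perm)}, got rank {len(shape)}")
    return tuple(shape[axis] for axis in perm)


# Illustrative sizes only: B=2, T=S=256, Hq=8, Hkv=2, D=Dv=64.
q_shape = (2, 256, 8, 64)  # (B, T, Hq, D)
k_shape = (2, 256, 2, 64)  # (B, S, Hkv, D)
v_shape = (2, 256, 2, 64)  # (B, S, Hkv, Dv)

q_bhtd = permute_shape(q_shape)  # (B, Hq, T, D)
k_bhtd = permute_shape(k_shape)  # (B, Hkv, S, D)
v_bhtd = permute_shape(v_shape)  # (B, Hkv, S, Dv)
```

This permutation is its own inverse, so applying it a second time restores the BTHD layout. Whether the implementation transposes the kernel output back before returning it is not stated in this reference.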
- get_impl_metadata() -> AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.