easydel.layers.attention_operator.modules.flash
- class easydel.layers.attention_operator.modules.flash.FlashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An implementation of Flash Attention V2 using specialized JAX primitives.
This class uses jax.experimental.pallas.ops.tpu.flash_attention on TPUs and a Triton kernel (triton_flash_attention) on GPUs (CUDA). It is registered under the name “flash_attn2”. CPU and ROCm execution are not supported and raise NotImplementedError.
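The method entries below describe a simple delegation scheme: GPU routes to the CUDA path, while CPU and ROCm fall back to forward_native, which raises. The following is a minimal, hypothetical sketch of that routing in plain Python. The class name FlashAttnDispatchSketch and the platform-string `__call__` are assumptions for illustration, not EasyDeL's actual dispatch API, and the kernel calls are replaced by placeholder strings.

```python
# Hypothetical sketch of the backend delegation documented for FlashAttn.
# The class name and the platform-string dispatch are illustrative assumptions;
# the real kernels are replaced by placeholder return values.


class FlashAttnDispatchSketch:
    """Routes a forward call to a backend-specific method by platform name."""

    def forward_native(self, *args, **kwargs):
        # Documented behaviour: no CPU/native Flash Attention path exists.
        raise NotImplementedError(
            "Flash Attention is not supported on CPU via this implementation."
        )

    def forward_cpu(self, *args, **kwargs):
        # CPU delegates to forward_native, which raises.
        return self.forward_native(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        # ROCm has no dedicated kernel, so it also falls back to forward_native.
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return "triton_flash_attention"  # placeholder for the Triton kernel call

    def forward_gpu(self, *args, **kwargs):
        # Generic GPU delegates to the CUDA-specific implementation.
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return "pallas_flash_attention"  # placeholder for the Pallas kernel call

    def __call__(self, platform, *args, **kwargs):
        # Look up forward_<platform> and forward all arguments to it.
        return getattr(self, f"forward_{platform}")(*args, **kwargs)


attn = FlashAttnDispatchSketch()
print(attn("gpu"))  # triton_flash_attention
try:
    attn("cpu")
except NotImplementedError as err:
    print("cpu:", err)
```

The point of the pattern is that unsupported backends fail loudly through a single method (forward_native) rather than silently computing something different.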
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native, which raises an error.
- Raises
NotImplementedError – Via forward_native.
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Performs Flash Attention V2 on GPU (CUDA) using a Triton kernel.
Handles optional mask/bias, KV head repetition (either inside the kernel or beforehand), and sharding based on the metadata. It assumes that triton_flash_attention either repeats KV heads internally when needed or accepts broadcastable KV tensors, and that the kernel's inputs and outputs use the BTHD layout (batch, seq_len, num_heads, head_dim).
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Used by the kernel if bias is not provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added by the kernel. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask within the kernel logic if both are passed.
init_bias – Optional callable that initializes the bias when both mask and bias are None.
causal – If True, instructs the kernel to apply a causal mask.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
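To make the documented semantics concrete (BTHD layout, KV head repetition, bias taking precedence over mask, causal masking), here is an unfused NumPy reference of the result the kernel is expected to produce. It is a sketch under stated assumptions, not the Triton kernel: it materializes the full score matrix, which Flash Attention specifically avoids, and the helper names repeat_kv_heads and reference_attention are hypothetical. Aligning the causal diagonal to the end of the KV sequence is also an assumption of this sketch.

```python
import numpy as np


def repeat_kv_heads(x, num_q_heads):
    """Repeat KV heads along axis 2 (BTHD) so they match the query head count."""
    num_kv_heads = x.shape[2]
    assert num_q_heads % num_kv_heads == 0, "q heads must be a multiple of kv heads"
    return np.repeat(x, num_q_heads // num_kv_heads, axis=2)


def reference_attention(q, k, v, mask=None, bias=None, causal=False):
    """Unfused attention on BTHD inputs; returns a BTHD output.

    Hypothetical reference: `bias` is added to the scores and takes precedence
    over the boolean keep-`mask`; `causal` hides future key positions.
    """
    _, tq, hq, d = q.shape
    tk = k.shape[1]
    k = repeat_kv_heads(k, hq)
    v = repeat_kv_heads(v, hq)
    # Scores have shape (batch, heads, q_len, kv_len).
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) / np.sqrt(d)
    if bias is not None:
        scores = scores + bias
    elif mask is not None:
        scores = np.where(mask, scores, -1e9)
    if causal:
        # Assumption: the queries are the last tq positions of the KV sequence.
        keep = np.tril(np.ones((tq, tk), dtype=bool), k=tk - tq)
        scores = np.where(keep, scores, -1e9)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8, 16))  # 8 query heads
k = rng.standard_normal((2, 4, 2, 16))  # 2 KV heads (grouped-query attention)
v = rng.standard_normal((2, 4, 2, 16))
out = reference_attention(q, k, v, causal=True)
print(out.shape)  # (2, 4, 8, 16)
```

A reference like this is useful mainly for checking a fused kernel's output on small inputs; its memory cost grows with q_seq_len × kv_seq_len.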
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to the CUDA-specific implementation.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Native (CPU) forward pass for Flash Attention. Not implemented.
- Raises
NotImplementedError – Flash Attention is not supported on CPU via this implementation.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Not implemented; falls back to forward_native, which raises an error.
- Raises
NotImplementedError – Via forward_native. ROCm support requires a specific kernel.
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Performs Flash Attention V2 on TPU using jax.experimental.pallas.ops.tpu.flash_attention.
Handles optional mask/bias, KV head repetition, and sharding based on the metadata. Note: the Pallas kernel expects inputs in BHTD format (batch, num_heads, seq_len, head_dim), whereas the tensors documented below use BTHD.
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Bias will be generated from this if provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added directly to attention scores. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask.
init_bias – Optional callable that initializes the bias when both mask and bias are None.
causal – If True, applies a causal mask. Ignored if q_seq_len is 1 (generation).
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
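The TPU entry above mentions three preparation steps before the Pallas kernel runs: deriving a bias from the mask (unless a bias is already given), dropping causal masking for single-token generation, and moving from BTHD to BHTD. The sketch below illustrates those steps in NumPy. The function names mask_to_bias and prepare_tpu_inputs are hypothetical, the use of the dtype's most negative finite value as the "masked" bias is an assumption, and the real implementation may order or fuse these steps differently.

```python
import numpy as np


def mask_to_bias(mask, dtype=np.float32):
    """Turn a boolean keep-mask into an additive bias (0 = keep, very negative = drop)."""
    big_neg = np.finfo(dtype).min  # assumption: most negative finite value marks "drop"
    return np.where(mask, 0.0, big_neg).astype(dtype)


def prepare_tpu_inputs(q, k, v, mask=None, bias=None, causal=False):
    """Hypothetical pre-processing before a BHTD flash-attention kernel."""
    if bias is None and mask is not None:
        bias = mask_to_bias(mask)  # an explicit bias takes precedence over mask
    if q.shape[1] == 1:
        causal = False  # single-token generation: no future positions to hide
    # BTHD (batch, seq, heads, dim) -> BHTD (batch, heads, seq, dim)
    q, k, v = (np.transpose(x, (0, 2, 1, 3)) for x in (q, k, v))
    return q, k, v, bias, causal
```

A kernel output in BHTD would be transposed back with the same axis permutation, `np.transpose(out, (0, 2, 1, 3))`, since swapping axes 1 and 2 is its own inverse.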
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.