easydel.layers.attention_operator.modules.vanilla
- class easydel.layers.attention_operator.modules.vanilla.VanillaAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
A standard, non-optimized implementation of multi-head attention.
This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.
Registered under the name “vanilla”.
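The einsum-plus-softmax approach described above can be sketched as follows. This is a minimal illustration, not EasyDeL's actual code: the function name `reference_attention` is hypothetical, the query and key/value head counts are assumed equal, and bias, masking, and dropout are left out.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Plain scaled dot-product attention (illustrative only).

    q: (B, T, H, D), k: (B, S, H, D), v: (B, S, H, D_v)
    """
    scale = q.shape[-1] ** -0.5
    # Scores for every query/key pair: (B, H, T, S).
    scores = jnp.einsum("bthd,bshd->bhts", q * scale, k)
    # Softmax in float32 for numerical stability, then cast back.
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(q.dtype)
    # Weighted sum of values: (B, T, H, D_v).
    out = jnp.einsum("bhts,bshd->bthd", weights, v)
    return weights, out


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (2, 5, 4, 8))
k = jax.random.normal(key_k, (2, 7, 4, 8))
v = jax.random.normal(key_v, (2, 7, 4, 16))
weights, out = reference_attention(q, k, v)
print(weights.shape, out.shape)  # (2, 4, 5, 7) (2, 5, 4, 16)
```

Each row of `weights` sums to 1 over the key axis, which is a quick sanity check when comparing an optimized kernel against a reference of this kind.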
- forward_cpu(*args, **kwargs) -> AttentionOutput
CPU forward pass. Delegates to forward_native.
- forward_cuda(*args, **kwargs) -> AttentionOutput
CUDA GPU forward pass. Delegates to forward_native.
- forward_gpu(*args, **kwargs) -> AttentionOutput
GPU forward pass. Delegates to forward_native.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = True, dropout_rng: Optional[PRNGKey] = None, **ignore) -> AttentionOutput
Computes multi-head attention using standard JAX operations.
Supports GQA/MQA by reshaping the query tensor to match the number of key/value heads. Applies scaling, optional bias/mask, softmax (potentially in float32), and optional dropout.
- Parameters
q – Query tensor (B, T, H_q, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used if bias is not provided.
bias – Optional attention bias tensor (broadcastable to B, H_q, T, S). Takes precedence over mask.
init_bias – Optional callable that creates the bias when both mask and bias are None.
deterministic – If True, disables dropout.
dropout_rng – JAX PRNG key for dropout. Required if deterministic is False and dropout_prob > 0.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention weights (if computed) and the final attention outputs.
- Raises
NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
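The query-head grouping and the bias-over-mask precedence documented above might look roughly like the sketch below. It is an assumption-laden illustration rather than the library's implementation: `gqa_attention` is a hypothetical name, dropout and `init_bias` are omitted, and a bias whose head dimension is neither 1 nor H_q is not handled (the documented method raises NotImplementedError when the bias cannot be reshaped).

```python
import jax
import jax.numpy as jnp


def gqa_attention(q, k, v, mask=None, bias=None):
    """Grouped-query attention sketch (illustrative only).

    q: (B, T, H_q, D), k: (B, S, H_kv, D), v: (B, S, H_kv, D_v)
    mask: boolean, broadcastable to (B, 1, T, S)
    bias: broadcastable to (B, H_q, T, S); used instead of mask when given
    """
    b, t, h_q, d = q.shape
    s, h_kv = k.shape[1], k.shape[2]
    assert h_q % h_kv == 0, "H_q must be a multiple of H_kv"
    groups = h_q // h_kv

    # Split query heads into (H_kv, groups) so each group shares one K/V head.
    q = q.reshape(b, t, h_kv, groups, d) * (d ** -0.5)
    scores = jnp.einsum("btkgd,bskd->bkgts", q, k)

    if bias is None and mask is not None:
        # Turn the boolean mask into an additive bias.
        bias = jnp.where(mask, 0.0, jnp.finfo(scores.dtype).min)
    if bias is not None:
        # Expand to (B, H_q, T, S), then split heads the same way as q.
        bias = jnp.broadcast_to(bias, (b, h_q, t, s))
        scores = scores + bias.reshape(b, h_kv, groups, t, s)

    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(v.dtype)
    out = jnp.einsum("bkgts,bskd->btkgd", weights, v)
    # Merge (H_kv, groups) back into H_q.
    return out.reshape(b, t, h_q, v.shape[-1])


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 6, 8, 16))  # 8 query heads
k = jax.random.normal(kk, (2, 6, 2, 16))  # 2 key/value heads
v = jax.random.normal(kv, (2, 6, 2, 16))
causal = jnp.tril(jnp.ones((6, 6), dtype=bool))[None, None]  # (1, 1, T, S)
out = gqa_attention(q, k, v, mask=causal)
print(out.shape)  # (2, 6, 8, 16)
```

With this head layout, consecutive query heads share a key/value head, so the result matches repeating each K/V head `groups` times along the head axis and running ordinary multi-head attention, without materializing the repeated tensors.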
- forward_rocm(*args, **kwargs) -> AttentionOutput
ROCm GPU forward pass. Delegates to forward_native.
- forward_tpu(*args, **kwargs) -> AttentionOutput
TPU forward pass. Delegates to forward_native.
- get_impl_metadata() -> AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.