easydel.layers.attention_operator.modules.vanilla

class easydel.layers.attention_operator.modules.vanilla.VanillaAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

A standard, non-optimized implementation of multi-head attention.

This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.
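As a rough illustration of what "basic JAX operations like jnp.einsum and standard softmax" amounts to, here is a minimal sketch of plain scaled dot-product attention for the case where query and key/value head counts are equal. The function name `vanilla_mha` is hypothetical and this is not EasyDeL's code; it only shows the einsum/softmax pattern in the tensor layout used by forward_native.

```python
import jax
import jax.numpy as jnp


def vanilla_mha(q, k, v):
    """Scaled dot-product attention with H_q == H_kv (hypothetical helper).

    q: (B, T, H, D), k: (B, S, H, D), v: (B, S, H, D_v) -> (B, T, H, D_v)
    """
    scale = q.shape[-1] ** -0.5
    # Logits for every (query position, key position) pair, per head: (B, H, T, S)
    logits = jnp.einsum("bthd,bshd->bhts", q * scale, k)
    # Normalize over the key axis S
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values, back to (B, T, H, D_v)
    return jnp.einsum("bhts,bshd->bthd", weights, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 8, 16))
k = jax.random.normal(kk, (2, 6, 8, 16))
v = jax.random.normal(kv, (2, 6, 8, 32))
out = vanilla_mha(q, k, v)
print(out.shape)  # (2, 4, 8, 32)
```

Because every logit is materialized as a (B, H, T, S) array, memory grows quadratically with sequence length, which is the cost that fused kernels such as Flash Attention avoid.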

Registered under the name “vanilla”.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native.

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass. Delegates to forward_native.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Delegates to forward_native.

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = True, dropout_rng: Optional[PRNGKey] = None, **ignore) → AttentionOutput

Computes multi-head attention using standard JAX operations.

Supports GQA/MQA by reshaping the query tensor to match the number of key/value heads. Applies scaling, optional bias/mask, softmax (potentially in float32), and optional dropout.
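One common way to implement the GQA/MQA reshape is sketched below, assuming query head h is served by key/value head h // G, where G = H_q / H_kv. The function `gqa_attention` is hypothetical and the head-to-group mapping is an assumption; the sketch splits the query head axis into (H_kv, G) so keys and values never have to be copied.

```python
import jax
import jax.numpy as jnp


def gqa_attention(q, k, v):
    """Hypothetical GQA/MQA sketch.

    q: (B, T, H_q, D); k: (B, S, H_kv, D); v: (B, S, H_kv, D_v).
    """
    b, t, h_q, d = q.shape
    h_kv = k.shape[2]
    if h_q % h_kv != 0:
        raise ValueError("H_q must be a multiple of H_kv")
    g = h_q // h_kv  # query heads sharing each key/value head
    # Split the query head axis: query head h -> (kv head h // g, slot h % g)
    q = q.reshape(b, t, h_kv, g, d) * d ** -0.5
    logits = jnp.einsum("btkgd,bskd->bkgts", q, k)  # (B, H_kv, G, T, S)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bkgts,bskd->btkgd", weights, v)  # (B, T, H_kv, G, D_v)
    # Merge (H_kv, G) back into H_q
    return out.reshape(b, t, h_q, v.shape[-1])


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 3, 8, 16))  # H_q = 8
k = jax.random.normal(kk, (1, 5, 2, 16))  # H_kv = 2, so G = 4 (GQA)
v = jax.random.normal(kv, (1, 5, 2, 16))
out = gqa_attention(q, k, v)
print(out.shape)  # (1, 3, 8, 16)
```

MQA is the special case H_kv = 1, where all query heads share a single key/value head.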

Parameters
  • q – Query tensor (B, T, H_q, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D_v).

  • mask – Optional boolean attention mask (broadcastable to (B, 1, T, S)). Used only if bias is not provided.

  • bias – Optional attention bias tensor (broadcastable to (B, H_q, T, S)). Takes precedence over mask.

  • init_bias – Optional callable that produces a bias when both mask and bias are None.

  • deterministic – If True, disables dropout.

  • dropout_rng – JAX PRNG key for dropout. Required if deterministic is False and the dropout probability (dropout_prob) is greater than 0.

  • **ignore – Ignored keyword arguments.
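The interplay of mask, bias, init_bias, and dropout described by these parameters can be sketched as follows. This is a hypothetical `attend` function, not the library's code. It assumes equal query and key/value head counts, takes `dropout_prob` as an explicit argument (in VanillaAttn it presumably comes from the metadata), computes the softmax in float32, and uses inverted dropout.

```python
import jax
import jax.numpy as jnp


def attend(q, k, v, mask=None, bias=None, init_bias=None,
           deterministic=True, dropout_rng=None, dropout_prob=0.0):
    """Hypothetical sketch. q: (B, T, H, D); k: (B, S, H, D); v: (B, S, H, D_v)."""
    logits = jnp.einsum("bthd,bshd->bhts", q * q.shape[-1] ** -0.5, k)
    if bias is None and mask is None and init_bias is not None:
        bias = init_bias()
    if bias is not None:
        # An explicit bias takes precedence; the mask is ignored.
        logits = logits + bias
    elif mask is not None:
        # Masked-out positions get the most negative finite logit.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    # Softmax in float32 for numerical stability, then cast back.
    weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(q.dtype)
    if not deterministic and dropout_prob > 0.0:
        if dropout_rng is None:
            raise ValueError("dropout_rng is required when dropout is active")
        keep = jax.random.bernoulli(dropout_rng, 1.0 - dropout_prob, weights.shape)
        # Inverted dropout: rescale survivors so the expected weight is unchanged.
        weights = jnp.where(keep, weights / (1.0 - dropout_prob), 0.0)
    return jnp.einsum("bhts,bshd->bthd", weights, v), weights


B, T, H, D = 1, 4, 2, 8
q = k = v = jax.random.normal(jax.random.PRNGKey(0), (B, T, H, D))
causal = jnp.tril(jnp.ones((T, T), dtype=bool))[None, None]  # (1, 1, T, S)
out, w = attend(q, k, v, mask=causal)
```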

Returns

An AttentionOutput object containing the attention weights (if computed) and the final attention outputs.

Raises

NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
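A plausible reading of this error condition is sketched below, assuming the query heads are grouped as (H_kv, G) and that only a bias shared across heads (head dimension 1) or a per-query-head bias (head dimension H_q) can be mapped onto the grouped logits. `reshape_bias_for_gqa` is a hypothetical helper; the exact set of shapes the real implementation accepts may differ.

```python
import jax.numpy as jnp


def reshape_bias_for_gqa(bias, num_kv_heads, group_size):
    """Hypothetical: map a (B, H_b, T, S) bias onto grouped logits (B, H_kv, G, T, S)."""
    b, h_b, t, s = bias.shape
    if h_b == 1:
        # Shared across all heads: broadcast over both head axes.
        return bias.reshape(b, 1, 1, t, s)
    if h_b == num_kv_heads * group_size:
        # One bias per query head: split H_q into (H_kv, G) like the query reshape.
        return bias.reshape(b, num_kv_heads, group_size, t, s)
    raise NotImplementedError(
        f"bias has {h_b} heads; expected 1 or {num_kv_heads * group_size}"
    )


grouped = reshape_bias_for_gqa(jnp.zeros((2, 8, 4, 6)), num_kv_heads=2, group_size=4)
print(grouped.shape)  # (2, 2, 4, 4, 6)
```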

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Delegates to forward_native.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU forward pass. Delegates to forward_native.
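Every backend-specific method of VanillaAttn delegates to forward_native. The sketch below shows that delegation pattern with a hypothetical `ToyAttentionImpl` class and a hypothetical `__call__` that picks `forward_<platform>` by name; how AttentionImpl actually selects a backend method is not documented here.

```python
from typing import Any


class ToyAttentionImpl:
    """Hypothetical stand-in for a per-backend dispatching base class."""

    def forward_native(self, *args: Any, **kwargs: Any) -> str:
        # The portable implementation; real code would compute attention here.
        return "native"

    def forward_cpu(self, *args: Any, **kwargs: Any) -> str:
        return self.forward_native(*args, **kwargs)

    # Every other backend entry point reuses the same delegating function.
    forward_gpu = forward_cuda = forward_rocm = forward_tpu = forward_cpu

    def __call__(self, platform: str, *args: Any, **kwargs: Any) -> str:
        # Resolve e.g. "tpu" -> self.forward_tpu, falling back to forward_native.
        method = getattr(self, f"forward_{platform}", self.forward_native)
        return method(*args, **kwargs)


impl = ToyAttentionImpl()
print(impl("tpu"))  # native
```

A caller could pass the string returned by `jax.default_backend()` (for example "cpu", "gpu", or "tpu") as the platform.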

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name of this attention implementation.

Returns

The string “vanilla”.