easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.get_mha_cost_estimate(shape_dtype)
Provides a Pallas CostEstimate for Multi-Head Attention.
Calculates an approximate cost (FLOPs, transcendental operations, and bytes accessed) for a Multi-Head Attention operation from the shapes and dtypes of its inputs. The estimate is passed to Pallas as a hint to the compiler about the kernel's cost.
- Parameters
shape_dtype – A tuple of JAX ShapeDtypeStruct objects for the query, key, value, and lengths tensors. Expected shapes:
  - query: [batch_size, 1, num_heads, head_dim]
  - key: [batch_size, seq_len, num_heads, head_dim]
  - value: [batch_size, seq_len, num_heads, head_dim]
  - lengths: [batch_size]
- Returns
A pl.CostEstimate object containing the estimated costs.
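The kind of arithmetic behind such an estimate can be sketched in plain Python. `mha_cost_sketch` is a hypothetical helper, not part of EasyDeL, and its counting rules are assumptions: two matmuls (Q·K^T and P·V) at 2 FLOPs per multiply-add, one `exp` per logit, and bytes accessed equal to the four inputs plus an output with the query's shape and dtype. The real function may count differently and returns a `pl.CostEstimate` rather than a dict.

```python
import math
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass(frozen=True)
class ShapeDtype:
    """Minimal stand-in for jax.ShapeDtypeStruct (only .shape and .dtype)."""
    shape: tuple
    dtype: Any


def nbytes(x: ShapeDtype) -> int:
    return math.prod(x.shape) * np.dtype(x.dtype).itemsize


def mha_cost_sketch(shape_dtype) -> dict:
    """Hypothetical cost model for single-token MHA; the formulas are assumptions."""
    query, key, value, lengths = shape_dtype
    batch, _, num_heads, head_dim = query.shape
    seq_len = key.shape[1]
    # Q @ K^T and P @ V each cost 2 * head_dim FLOPs per (batch, head, position).
    flops = 2 * (2 * batch * num_heads * seq_len * head_dim)
    # One exp per attention logit.
    transcendentals = batch * num_heads * seq_len
    # Read q, k, v, lengths; write an output shaped like the query (assumption).
    bytes_accessed = sum(nbytes(x) for x in shape_dtype) + nbytes(query)
    return {"flops": flops, "transcendentals": transcendentals, "bytes_accessed": bytes_accessed}


B, S, H, D = 2, 8, 4, 16
cost = mha_cost_sketch((
    ShapeDtype((B, 1, H, D), np.float32),
    ShapeDtype((B, S, H, D), np.float32),
    ShapeDtype((B, S, H, D), np.float32),
    ShapeDtype((B,), np.int32),
))
print(cost)  # {'flops': 4096, 'transcendentals': 64, 'bytes_accessed': 9224}
```

In a Pallas setting, the three numbers would presumably be wrapped as `pl.CostEstimate(flops=..., transcendentals=..., bytes_accessed=...)`.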
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_flash_attention_kernel(lengths_ref, starts_ref, q_ref, k_ref, v_ref, o_ref, m_ref, l_ref, *, block_size: int, mask_value: float)
Pallas kernel implementing Flash Attention for ragged sequences.
This kernel performs the core Flash Attention computation block by block, handling variable sequence lengths specified by lengths_ref. It updates the output (o_ref), maximum logit (m_ref), and softmax denominator (l_ref) incrementally.
- Parameters
lengths_ref – Reference to the lengths tensor [batch_size].
starts_ref – Reference to the starts tensor [batch_size].
q_ref – Reference to the query tensor block.
k_ref – Reference to the key tensor block.
v_ref – Reference to the value tensor block.
o_ref – Reference to the output tensor block (accumulator).
m_ref – Reference to the maximum logit tensor block (accumulator).
l_ref – Reference to the softmax denominator tensor block (accumulator).
block_size – The size of the blocks to process along the sequence length dimension.
mask_value – The value used for masking padding tokens.
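The incremental update of `o_ref`, `m_ref`, and `l_ref` follows the online-softmax recurrence used by Flash Attention. The NumPy sketch below illustrates that recurrence for a single query vector; it is not the Pallas kernel. `online_softmax_attention` is a hypothetical name, `starts` is ignored, logits are plain dot products (any 1/sqrt(head_dim) scaling is assumed to happen elsewhere), blocks starting at or past `length` are assumed to be skipped, and the output is kept normalized after every block.

```python
import numpy as np

MASK_VALUE = -2.381976426469702e38  # same default as the functions documented here


def online_softmax_attention(q, k, v, length, block_size, mask_value=MASK_VALUE):
    """Hypothetical NumPy sketch of the per-block Flash Attention update for one query.

    q: [head_dim]; k, v: [seq_len, head_dim]; key positions >= length are masked.
    Returns (o, m, l): normalized output, running max logit, softmax denominator.
    """
    m = -np.inf                   # running max logit (role of m_ref)
    l = 0.0                       # running softmax denominator (role of l_ref)
    o = np.zeros(v.shape[-1])     # running normalized output (role of o_ref)
    for start in range(0, k.shape[0], block_size):
        if start >= length:       # assumption: blocks past the valid length are skipped
            break
        kb, vb = k[start:start + block_size], v[start:start + block_size]
        s = kb @ q                                    # logits for this block
        pos = start + np.arange(kb.shape[0])
        s = np.where(pos < length, s, mask_value)     # mask padding positions
        m_curr = s.max()
        p = np.exp(s - m_curr)
        m_next = max(m, m_curr)
        alpha = np.exp(m - m_next)                    # rescales the old accumulators
        beta = np.exp(m_curr - m_next)                # rescales this block's terms
        l_next = alpha * l + beta * p.sum()
        o = (alpha * l * o + beta * (p @ vb)) / l_next
        m, l = m_next, l_next
    return o, m, l


rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal((10, 8))
v = rng.standard_normal((10, 8))
o, m, l = online_softmax_attention(q, k, v, length=7, block_size=3)

# Compare with a one-shot softmax over the first 7 positions.
s_full = k[:7] @ q
w = np.exp(s_full - s_full.max())
print(np.allclose(o, (w / w.sum()) @ v[:7]))  # True
```

Because `alpha` and `beta` rescale everything to the newest maximum, the final `l` equals the sum of `exp(logit - m)` over all valid positions, which is why the max logit and denominator can be returned alongside the output.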
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_gqa(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]
Ragged grouped-query attention.
- Parameters
query – A [batch_size, num_heads_q, head_dim] jax.Array.
key – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.
value – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.
lengths – An i32[batch_size] jax.Array.
starts – An i32[batch_size] jax.Array.
block_size – The Pallas block length along the seq_len dimension.
mask_value – The value used for padding in attention. Defaults to a very large negative floating-point number.
- Returns
The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).
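In grouped-query attention, several query heads share one key/value head. The NumPy sketch below shows one common way to express that grouping: the query heads are reshaped into `[batch, num_heads_kv, group, head_dim]`, so query head `h` attends with KV head `h // group`. It is a plain reference computation, not the Pallas kernel; `gqa_reference_sketch` is a hypothetical name, `starts` is ignored, the mapping of query heads to KV heads is an assumption, and logits are unscaled dot products. The returned shapes follow the description above.

```python
import numpy as np

MASK_VALUE = -2.381976426469702e38


def gqa_reference_sketch(q, k, v, lengths, mask_value=MASK_VALUE):
    """Hypothetical NumPy GQA reference (not EasyDeL's implementation).

    q: [batch, num_heads_q, head_dim]
    k, v: [batch, seq_len, num_heads_kv, head_dim]
    lengths: int[batch]; key positions >= lengths[b] are masked.
    """
    batch, num_heads_q, head_dim = q.shape
    seq_len, num_heads_kv = k.shape[1], k.shape[2]
    group = num_heads_q // num_heads_kv            # query heads per KV head
    # Query head h = kv * group + g, i.e. it uses KV head h // group.
    qg = q.reshape(batch, num_heads_kv, group, head_dim)
    logits = np.einsum("bkgd,bskd->bkgs", qg, k)   # [batch, kv_heads, group, seq_len]
    valid = np.arange(seq_len)[None, :] < lengths[:, None]   # [batch, seq_len]
    logits = np.where(valid[:, None, None, :], logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)         # max logit
    p = np.exp(logits - m)
    l = p.sum(axis=-1, keepdims=True)              # softmax denominator
    o = np.einsum("bkgs,bskd->bkgd", p / l, v)
    # Flatten (kv_head, group) back into query heads.
    return (o.reshape(batch, num_heads_q, head_dim),
            m.reshape(batch, num_heads_q, 1),
            l.reshape(batch, num_heads_q, 1))


rng = np.random.default_rng(1)
B, Hq, Hkv, S, D = 2, 4, 2, 6, 8
q = rng.standard_normal((B, Hq, D))
k = rng.standard_normal((B, S, Hkv, D))
v = rng.standard_normal((B, S, Hkv, D))
lengths = np.array([3, 6])
o, m, l = gqa_reference_sketch(q, k, v, lengths)
print(o.shape, m.shape, l.shape)  # (2, 4, 8) (2, 4, 1) (2, 4, 1)
```

The reshape avoids materializing repeated copies of K and V for each query head in the group.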
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_mha(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]
Ragged multi-head attention.
- Parameters
query – A [batch_size, 1, num_heads, head_dim] jax.Array.
key – A [batch_size, seq_len, num_heads, head_dim] jax.Array.
value – A [batch_size, seq_len, num_heads, head_dim] jax.Array.
lengths – An i32[batch_size] jax.Array.
starts – An i32[batch_size] jax.Array.
block_size – The Pallas block length along the seq_len dimension.
mask_value – The value used for padding in attention. Defaults to a very large negative floating-point number.
- Returns
The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).
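The `block_size` argument sets how the key/value sequence is tiled. If, as an assumption, blocks lying entirely beyond a sequence's `lengths` entry are skipped, then the number of blocks that do work for a batch element is ceil(length / block_size) rather than ceil(seq_len / block_size). The hypothetical helper below only performs that arithmetic; it does not inspect the kernel.

```python
import math


def active_blocks(lengths, seq_len, block_size=256):
    """Hypothetical helper: blocks per sequence, assuming blocks past `length` are skipped."""
    total = math.ceil(seq_len / block_size)
    return [min(math.ceil(n / block_size), total) for n in lengths], total


active, total = active_blocks([100, 700, 2048], seq_len=2048, block_size=256)
print(active, total)  # [1, 3, 8] 8
```

Under that assumption, a batch with mostly short sequences in a long padded buffer performs far fewer block iterations than dense attention over `seq_len`.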
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_mqa(q: Array, k: Array, v: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38, cost_estimate: jax._src.pallas.core.CostEstimate | None = None) → tuple[jax.Array, jax.Array, jax.Array]
Ragged multi-query attention.
- Parameters
q – A [batch_size, 1, head_dim] jax.Array.
k – A [batch_size, seq_len, head_dim] jax.Array.
v – A [batch_size, seq_len, head_dim] jax.Array.
lengths – An i32[batch_size] jax.Array.
starts – An i32[batch_size] jax.Array.
block_size – The Pallas block length along the seq_len dimension.
mask_value – The value used for padding in attention. Defaults to a very large negative floating-point number.
cost_estimate – A Pallas TPU cost estimate based on a reference implementation.
- Returns
The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.reference_gqa(q: Array, k: Array, v: Array, lengths: Array, starts: Array, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]
Reference GQA implementation using vanilla (non-Pallas) attention.
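- Returns
A tuple (o, m, l) containing:
  - o: the attention output, [batch_size, 1, num_q_heads, head_dim]
  - m: the max logit, [batch_size, 1, num_q_heads, 1]
  - l: the softmax denominator, [batch_size, 1, num_q_heads, 1]
- Return type
tuple[jax.Array, jax.Array, jax.Array]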
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.reference_mha(q: Array, k: Array, v: Array, lengths: Array, starts: Array, *, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]
Multi-head attention reference.
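- Returns
A tuple (o, m, l) containing:
  - o: the attention output, [batch_size, 1, num_heads, head_dim]
  - m: the max logit, [batch_size, 1, num_heads, 1]
  - l: the softmax denominator, [batch_size, 1, num_heads, 1]
- Return type
tuple[jax.Array, jax.Array, jax.Array]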
- easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.reference_mqa(q: Array, k: Array, v: Array, lengths: Array, starts: Array, *, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]
Multi-query attention reference (called via vmap in reference_mha).
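The relationship between `reference_mqa` and `reference_mha` can be sketched in NumPy. `mqa_sketch` computes masked attention for query rows that share one key/value head and returns `(o, m, l)`; `mha_sketch` applies it independently to each head and stacks the results on a head axis, which is what a `jax.vmap` over that axis would do. Both names are hypothetical, `starts` is ignored, logits are unscaled dot products, and the `[batch_size, 1, num_heads, ...]` output layout is taken from the `reference_mha` return description.

```python
import numpy as np

MASK_VALUE = -2.381976426469702e38


def mqa_sketch(q, k, v, lengths, mask_value=MASK_VALUE):
    """Hypothetical MQA reference. q: [batch, num_q, head_dim]; k, v: [batch, seq_len, head_dim]."""
    logits = np.einsum("bqd,bsd->bqs", q, k)
    valid = np.arange(k.shape[1])[None, :] < lengths[:, None]   # [batch, seq_len]
    logits = np.where(valid[:, None, :], logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)                      # [batch, num_q, 1]
    p = np.exp(logits - m)
    l = p.sum(axis=-1, keepdims=True)                           # [batch, num_q, 1]
    o = np.einsum("bqs,bsd->bqd", p / l, v)                     # [batch, num_q, head_dim]
    return o, m, l


def mha_sketch(q, k, v, lengths, mask_value=MASK_VALUE):
    """Hypothetical MHA reference. q: [batch, 1, num_heads, head_dim]; k, v: [batch, seq_len, num_heads, head_dim]."""
    # Run MQA once per head (a loop standing in for jax.vmap over the head axis).
    per_head = [mqa_sketch(q[:, :, h], k[:, :, h], v[:, :, h], lengths, mask_value)
                for h in range(q.shape[2])]
    # Stack each of o, m, l along a new head axis at position 2.
    o, m, l = (np.stack(xs, axis=2) for xs in zip(*per_head))
    return o, m, l


rng = np.random.default_rng(2)
B, S, H, D = 2, 5, 3, 4
q = rng.standard_normal((B, 1, H, D))
k = rng.standard_normal((B, S, H, D))
v = rng.standard_normal((B, S, H, D))
lengths = np.array([2, 5])
o, m, l = mha_sketch(q, k, v, lengths)
print(o.shape, m.shape, l.shape)  # (2, 1, 3, 4) (2, 1, 3, 1) (2, 1, 3, 1)
```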