easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.get_mha_cost_estimate(shape_dtype)

Provides a Pallas CostEstimate for Multi-Head Attention.

Calculates an approximate cost estimate (FLOPs, transcendentals, bytes accessed) for a Multi-Head Attention operation based on the input shapes and dtypes. This is used to inform the Pallas compiler.

Parameters

shape_dtype – A tuple of JAX ShapeDtypeStruct objects for the query, key, value, and lengths tensors. Expected shapes:

  • query: [batch_size, 1, num_heads, head_dim]

  • key: [batch_size, seq_len, num_heads, head_dim]

  • value: [batch_size, seq_len, num_heads, head_dim]

  • lengths: [batch_size]

Returns

A pl.CostEstimate object containing the estimated costs.
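
To make the three cost fields concrete, here is a minimal sketch of how such an estimate could be derived from static shapes alone. It is not the library's formula: the function name `estimate_decode_mha_cost`, the `CostSketch` stand-in for `pl.CostEstimate`, and the counting rules (two FLOPs per multiply-add, one exponential per logit, every input read once) are all assumptions made for illustration.

```python
from dataclasses import dataclass
from math import prod


@dataclass
class CostSketch:
    """Hypothetical stand-in for pl.CostEstimate (same three fields)."""
    flops: int
    transcendentals: int
    bytes_accessed: int


def estimate_decode_mha_cost(shapes, itemsizes):
    """Rough cost of single-token MHA.

    shapes: shapes of (query, key, value, lengths), laid out as documented above.
    itemsizes: bytes per element for each of the four tensors.
    """
    batch_size, _, num_heads, head_dim = shapes[0]
    seq_len = shapes[1][1]
    logits = batch_size * num_heads * seq_len  # one logit per (batch, head, position)
    # Assumed counting: 2*head_dim FLOPs per q.k dot product,
    # plus 2*head_dim FLOPs per position for the weighted sum of v.
    flops = logits * (2 * head_dim + 2 * head_dim)
    transcendentals = logits  # one exp per logit
    bytes_accessed = sum(prod(s) * n for s, n in zip(shapes, itemsizes))
    return CostSketch(flops, transcendentals, bytes_accessed)


# bf16 query/key/value (2 bytes each) and int32 lengths (4 bytes).
shapes = ((2, 1, 3, 8), (2, 4, 3, 8), (2, 4, 3, 8), (2,))
cost = estimate_decode_mha_cost(shapes, itemsizes=(2, 2, 2, 4))
```

The numbers only need to be roughly right, since the description above calls the estimate approximate.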

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_flash_attention_kernel(lengths_ref, starts_ref, q_ref, k_ref, v_ref, o_ref, m_ref, l_ref, *, block_size: int, mask_value: float)

Pallas kernel implementing Flash Attention for ragged sequences.

This kernel performs the core Flash Attention computation block by block, handling variable sequence lengths specified by lengths_ref. It updates the output (o_ref), maximum logit (m_ref), and softmax denominator (l_ref) incrementally.

Parameters
  • lengths_ref – Reference to the lengths tensor [batch_size].

  • starts_ref – Reference to the starts tensor [batch_size].

  • q_ref – Reference to the query tensor block.

  • k_ref – Reference to the key tensor block.

  • v_ref – Reference to the value tensor block.

  • o_ref – Reference to the output tensor block (accumulator).

  • m_ref – Reference to the maximum logit tensor block (accumulator).

  • l_ref – Reference to the softmax denominator tensor block (accumulator).

  • block_size – The size of the blocks to process along the sequence length dimension.

  • mask_value – The value used for masking padding tokens.
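
The incremental update of the output, maximum logit, and softmax denominator described above can be sketched in plain NumPy for a single query vector. This is an illustration of the online-softmax technique, not the kernel itself: the function name `blockwise_masked_attention` is hypothetical, `starts` handling is left out because its semantics are not described here, no 1/sqrt(head_dim) scaling is applied, and `length >= 1` is assumed.

```python
import numpy as np


def blockwise_masked_attention(q, k, v, length, block_size, mask_value=-1e30):
    """Attention of one query over k[:length], one key block at a time.

    q: [head_dim]; k, v: [seq_len, head_dim]. Returns (o, m, l).
    """
    seq_len, head_dim = k.shape
    m = -np.inf             # running maximum logit
    l = 0.0                 # running softmax denominator
    o = np.zeros(head_dim)  # running unnormalized output
    for start in range(0, seq_len, block_size):
        if start >= length:  # remaining blocks are all padding
            break
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        logits = k_blk @ q
        pos = start + np.arange(k_blk.shape[0])
        logits = np.where(pos < length, logits, mask_value)  # mask padding
        m_new = max(m, float(logits.max()))
        p = np.exp(logits - m_new)
        scale = np.exp(m - m_new)  # rescale what was accumulated so far
        l = scale * l + p.sum()
        o = scale * o + p @ v_blk
        m = m_new
    return o / l, m, l
```

Because earlier accumulators are rescaled whenever the running maximum changes, the result matches a softmax over the first `length` positions without ever materializing the full row of logits.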

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_gqa(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Ragged group query attention.

Parameters
  • query – A [batch_size, num_heads_q, head_dim] jax.Array.

  • key – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.

  • value – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.

  • lengths – An i32[batch_size] jax.Array.

  • starts – An i32[batch_size] jax.Array.

  • block_size – The Pallas block length along the seq_len dimension.

  • mask_value – The value used for padding in attention. By default it is a very negative floating point number.

Returns

The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).
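
In grouped-query attention, several query heads share one key/value head. The sketch below shows one common way to expose that grouping with a reshape. The helper name `group_query_heads` and the assumption that consecutive query heads belong to the same key/value head are illustrative; this page does not state the layout `ragged_gqa` uses.

```python
import numpy as np


def group_query_heads(q, num_kv_heads):
    """Reshape q from [batch, num_heads_q, head_dim] to
    [batch, num_kv_heads, group_size, head_dim].

    Assumes consecutive query heads share a key/value head.
    """
    batch, num_q_heads, head_dim = q.shape
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_heads_q must be a multiple of num_heads_kv")
    group_size = num_q_heads // num_kv_heads
    return q.reshape(batch, num_kv_heads, group_size, head_dim)
```

After this reshape, each `[group_size, head_dim]` slice can be treated as a multi-query attention problem against a single key/value head.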

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_mha(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Ragged multi head attention.

Parameters
  • query – A [batch_size, 1, num_heads, head_dim] jax.Array.

  • key – A [batch_size, seq_len, num_heads, head_dim] jax.Array.

  • value – A [batch_size, seq_len, num_heads, head_dim] jax.Array.

  • lengths – An i32[batch_size] jax.Array.

  • starts – An i32[batch_size] jax.Array.

  • block_size – The Pallas block length along the seq_len dimension.

  • mask_value – The value used for padding in attention. By default it is a very negative floating point number.

Returns

The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.ragged_mqa(q: Array, k: Array, v: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38, cost_estimate: jax._src.pallas.core.CostEstimate | None = None) → tuple[jax.Array, jax.Array, jax.Array]

Ragged multi query attention.

Parameters
  • q – A [batch_size, 1, head_dim] jax.Array.

  • k – A [batch_size, seq_len, head_dim] jax.Array.

  • v – A [batch_size, seq_len, head_dim] jax.Array.

  • lengths – An i32[batch_size] jax.Array.

  • starts – An i32[batch_size] jax.Array.

  • block_size – The Pallas block length along the seq_len dimension.

  • mask_value – The value used for padding in attention. By default it is a very negative floating point number.

  • cost_estimate – A Pallas TPU cost estimate based on a reference implementation.

Returns

The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.reference_gqa(q: Array, k: Array, v: Array, lengths: Array, starts: Array, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Vanilla attention GQA implementation for reference.

Returns

  • o – [batch_size, 1, num_q_heads, head_dim]

  • m – [batch_size, 1, num_q_heads, 1]

  • l – [batch_size, 1, num_q_heads, 1]

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.reference_mha(q: Array, k: Array, v: Array, lengths: Array, starts: Array, *, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Multi head attention reference.

Returns

  • o – [batch_size, 1, num_heads, head_dim]

  • m – [batch_size, 1, num_heads, 1]

  • l – [batch_size, 1, num_heads, 1]

easydel.kernels.tpu_ops.ragged_attention_pallas._forward_pallas.reference_mqa(q: Array, k: Array, v: Array, lengths: Array, starts: Array, *, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Multi query attention reference (called via vmap in reference_mha).
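
For orientation, a non-blocked multi-query reference with length masking might look like the NumPy sketch below. It is an assumption-laden illustration rather than this function's code: the name `reference_mqa_sketch` is hypothetical, `q` is assumed to be laid out as [batch, heads, head_dim], `starts` is omitted because its semantics are not described on this page, and every entry of `lengths` is assumed to be at least 1.

```python
import numpy as np


def reference_mqa_sketch(q, k, v, lengths, mask_value=-1e30):
    """q: [batch, heads, head_dim]; k, v: [batch, seq_len, head_dim];
    lengths: int[batch]. Returns (o, m, l)."""
    logits = np.einsum("bhd,btd->bht", q, k)
    # Positions at or beyond each sequence's length are padding.
    valid = np.arange(k.shape[1])[None, :] < lengths[:, None]  # [batch, seq_len]
    logits = np.where(valid[:, None, :], logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)  # max logit
    p = np.exp(logits - m)
    l = p.sum(axis=-1, keepdims=True)       # softmax denominator
    o = np.einsum("bht,btd->bhd", p / l, v)
    return o, m, l
```

A reference like this is mainly useful for checking a blocked kernel numerically: padded key/value entries should have no effect on `o`.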