easydel.kernels.tpu_ops.ragged_attention_pallas._caller

easydel.kernels.tpu_ops.ragged_attention_pallas._caller.ragged_attention(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Dispatches to ragged Multi-Head or Grouped-Query Attention.

This function dispatches to either ragged Multi-Head Attention (MHA) or ragged Grouped-Query Attention (GQA), depending on the number of query heads and key/value heads. It uses Pallas kernels optimized for TPUs to handle ragged inputs efficiently.
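The head-count comparison behind this dispatch can be illustrated with a minimal sketch. The helpers `pick_variant` and `group_query_heads` are hypothetical names used only for illustration and are not part of EasyDeL; the sketch assumes MHA means equal head counts, GQA means the query head count is a multiple of the key/value head count, and consecutive query heads share one key/value head.

```python
import jax.numpy as jnp


def pick_variant(query, key):
    """Hypothetical helper: choose "mha" or "gqa" from the head counts."""
    num_q_heads = query.shape[1]   # query: [batch_size, num_q_heads, head_dim]
    num_kv_heads = key.shape[2]    # key:   [batch_size, seq_len, num_kv_heads, head_dim]
    if num_q_heads == num_kv_heads:
        return "mha"
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    return "gqa"


def group_query_heads(query, num_kv_heads):
    """Hypothetical helper: view the query heads as groups, one group per KV head.

    Assumes consecutive query heads share a KV head; the real kernel may
    use a different layout.
    """
    batch, num_q_heads, head_dim = query.shape
    group_size = num_q_heads // num_kv_heads
    return query.reshape(batch, num_kv_heads, group_size, head_dim)


q = jnp.zeros((2, 8, 64))           # 8 query heads
k_mha = jnp.zeros((2, 128, 8, 64))  # 8 KV heads: same count as the query
k_gqa = jnp.zeros((2, 128, 2, 64))  # 2 KV heads: 4 query heads per KV head

print(pick_variant(q, k_mha))
print(pick_variant(q, k_gqa))
print(group_query_heads(q, 2).shape)
```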

Parameters
  • query – Query tensor with shape [batch_size, num_q_heads, head_dim].

  • key – Key tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].

  • value – Value tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].

  • lengths – Integer tensor with shape [batch_size] indicating the true sequence length for each item in the batch.

  • starts – Integer tensor indicating start indices.

  • block_size – The block size to use for Pallas kernels. Defaults to 256.

  • mask_value – The value to use for masking attention logits. Defaults to -2.381976426469702e+38, a large negative number (about 0.7 times the most negative finite float32 value).

Returns

  A tuple containing:

  • The attention output tensor.

  • The maximum logit values.

  • The softmax denominator values.

  The exact shapes depend on whether MHA or GQA is called.

Return type

  tuple[jax.Array, jax.Array, jax.Array]
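One common use of a maximum logit and a softmax denominator returned alongside the output is to merge attention results that were computed over separate chunks of keys. The sketch below shows that merge in plain JAX. It is not taken from EasyDeL: `partial_attention` and `merge` are hypothetical helpers, the shapes are chosen for illustration, and the output is assumed to be already normalized by the denominator.

```python
import jax
import jax.numpy as jnp


def partial_attention(q, k, v):
    """Unmasked single-query attention over one chunk of keys.

    q: [heads, head_dim]; k, v: [chunk_len, heads, head_dim].
    Returns (out, m, l): normalized output, max logit, softmax denominator.
    """
    logits = jnp.einsum("hd,shd->hs", q, k)
    m = logits.max(axis=-1, keepdims=True)   # [heads, 1]
    p = jnp.exp(logits - m)                  # numerically stable exponentials
    l = p.sum(axis=-1, keepdims=True)        # [heads, 1]
    out = jnp.einsum("hs,shd->hd", p, v) / l
    return out, m, l


def merge(part_a, part_b):
    """Combine two partial results into the result over the union of their keys."""
    out_a, m_a, l_a = part_a
    out_b, m_b, l_b = part_b
    m = jnp.maximum(m_a, m_b)
    w_a = l_a * jnp.exp(m_a - m)   # chunk A denominator rescaled to the shared max
    w_b = l_b * jnp.exp(m_b - m)   # chunk B denominator rescaled to the shared max
    l = w_a + w_b
    out = (out_a * w_a + out_b * w_b) / l
    return out, m, l


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (4, 16))
k = jax.random.normal(kk, (32, 4, 16))
v = jax.random.normal(kv, (32, 4, 16))

full = partial_attention(q, k, v)
merged = merge(
    partial_attention(q, k[:20], v[:20]),
    partial_attention(q, k[20:], v[20:]),
)
print(bool(jnp.allclose(full[0], merged[0], atol=1e-4)))
```

The merge never needs the raw logits again, which is why carrying the two statistics next to the output is enough.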

easydel.kernels.tpu_ops.ragged_attention_pallas._caller.reference_ragged_attention(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Dispatches to ragged Multi-Head or Grouped-Query Attention.

This function dispatches to either ragged Multi-Head Attention (MHA) or ragged Grouped-Query Attention (GQA), depending on the number of query heads and key/value heads. It uses Pallas kernels optimized for TPUs to handle ragged inputs efficiently.

Parameters
  • query – Query tensor with shape [batch_size, num_q_heads, head_dim].

  • key – Key tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].

  • value – Value tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].

  • lengths – Integer tensor with shape [batch_size] indicating the true sequence length for each item in the batch.

  • starts – Integer tensor indicating start indices.

  • mask_value – The value to use for masking attention logits. Defaults to -2.381976426469702e+38, a large negative number (about 0.7 times the most negative finite float32 value).

Returns

  A tuple containing:

  • The attention output tensor.

  • The maximum logit values.

  • The softmax denominator values.

  The exact shapes depend on whether MHA or GQA is called.

Return type

  tuple[jax.Array, jax.Array, jax.Array]
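To make the role of `lengths` and `mask_value` concrete, here is a minimal plain-JAX sketch of length-masked single-query attention for the MHA case. It is an illustration under stated assumptions, not the library's reference implementation: `masked_mha` is a hypothetical name, key positions at or beyond `lengths` are treated as padding, `starts` is left out because its convention is not described here, and the shapes of the returned statistics are assumed.

```python
import jax
import jax.numpy as jnp

MASK_VALUE = -2.381976426469702e38  # default mask_value from the documented signature


def masked_mha(query, key, value, lengths, mask_value=MASK_VALUE):
    """Hypothetical sketch of ragged single-query MHA.

    query: [batch, heads, head_dim]
    key, value: [batch, seq_len, heads, head_dim]
    lengths: [batch] integer true lengths; later positions are masked out.
    Returns (out, m, l) with assumed shapes [batch, heads, head_dim],
    [batch, heads, 1] and [batch, heads, 1].
    """
    seq_len = key.shape[1]
    logits = jnp.einsum("bhd,bshd->bhs", query, key).astype(jnp.float32)
    valid = jnp.arange(seq_len)[None, :] < lengths[:, None]   # [batch, seq_len]
    logits = jnp.where(valid[:, None, :], logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)    # maximum logit per head
    p = jnp.exp(logits - m)                   # masked positions underflow to 0
    l = p.sum(axis=-1, keepdims=True)         # softmax denominator per head
    out = jnp.einsum("bhs,bshd->bhd", p, value) / l
    return out, m, l


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 2, 8))
k = jax.random.normal(kk, (2, 6, 2, 8))
v = jax.random.normal(kv, (2, 6, 2, 8))
lengths = jnp.array([3, 6])

out, m, l = masked_mha(q, k, v, lengths)

# Masking item 0 at length 3 should match attention over its first 3 keys only.
out_short, _, _ = masked_mha(q[:1], k[:1, :3], v[:1, :3], jnp.array([3]))
print(bool(jnp.allclose(out[0], out_short[0], atol=1e-5)))
```

A finite mask value rather than negative infinity keeps `logits - m` finite even when every position in a row is masked, which avoids NaNs in the exponentials.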