easydel.kernels.tpu_ops.ragged_attention_pallas.__init__
- easydel.kernels.tpu_ops.ragged_attention_pallas.__init__.pallas_ragged_decode(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]
Dispatches to ragged Multi-Head or Grouped-Query Attention.
This function calls either the ragged Multi-Head Attention (MHA) kernel or the ragged Grouped-Query Attention (GQA) kernel, depending on the number of query heads relative to the number of key/value heads. It uses Pallas kernels optimized for TPUs to handle ragged inputs efficiently.
- Parameters
query – Query tensor with shape [batch_size, num_q_heads, head_dim].
key – Key tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].
value – Value tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].
lengths – Integer tensor with shape [batch_size] indicating the true sequence length for each item in the batch.
starts – Integer tensor indicating start indices.
block_size – The block size to use for Pallas kernels. Defaults to 256.
mask_value – The value to use for masking attention logits. Defaults to a large negative number.
- Returns
A tuple containing:
- The attention output tensor.
- The maximum logit values.
- The softmax denominator values.
The exact shapes depend on whether MHA or GQA is called.
- Return type
tuple[jax.Array, jax.Array, jax.Array]
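To make the three returned values concrete, here is a minimal plain-JAX sketch of ragged decode attention. It is not the Pallas kernel and does not call EasyDeL. The name `reference_ragged_decode` is hypothetical, and several details are assumptions rather than documented behavior: that position `p` of batch item `b` is attended when `starts[b] <= p < lengths[b]`, that `starts` has shape [batch_size], that query heads sharing a key/value head are stored contiguously, that no `1/sqrt(head_dim)` scaling is applied inside the function, and that the maximum and denominator come back with shape [batch_size, num_q_heads].

```python
import jax
import jax.numpy as jnp

MASK_VALUE = -2.381976426469702e38  # default mask_value from the signature above


def reference_ragged_decode(query, key, value, lengths, starts, mask_value=MASK_VALUE):
    """Plain-JAX sketch of ragged decode attention (hypothetical helper).

    Assumption: position p of item b is attended when starts[b] <= p < lengths[b].
    Assumption: logits are not scaled here; the caller pre-scales the query if needed.
    """
    batch, num_q_heads, head_dim = query.shape
    _, seq_len, num_kv_heads, _ = key.shape
    group = num_q_heads // num_kv_heads  # 1 for MHA, >1 for GQA

    # Assumption: query heads that share a key/value head are contiguous.
    q = query.reshape(batch, num_kv_heads, group, head_dim)
    logits = jnp.einsum("bkgd,bskd->bkgs", q, key).astype(jnp.float32)

    # Ragged mask: only positions inside [starts, lengths) contribute.
    pos = jnp.arange(seq_len)[None, :]
    valid = (pos >= starts[:, None]) & (pos < lengths[:, None])  # [batch, seq_len]
    logits = jnp.where(valid[:, None, None, :], logits, mask_value)

    # Numerically stable softmax, keeping the max and the denominator.
    row_max = logits.max(axis=-1, keepdims=True)
    unnormalized = jnp.exp(logits - row_max)
    denom = unnormalized.sum(axis=-1, keepdims=True)
    weights = unnormalized / denom
    out = jnp.einsum("bkgs,bskd->bkgd", weights, value.astype(jnp.float32))

    return (
        out.reshape(batch, num_q_heads, head_dim),
        row_max.reshape(batch, num_q_heads),
        denom.reshape(batch, num_q_heads),
    )
```

Returning the maximum logit and the softmax denominator alongside the output is what lets partial results computed over different key/value blocks be merged later with a log-sum-exp style rescaling. In this sketch, padding beyond `lengths` receives zero weight, so its contents do not affect the output.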