easydel.kernels.tpu_ops.__init__

easydel.kernels.tpu_ops.__init__.pallas_paged_attention(q: Array, k_pages: Array, v_pages: Array, lengths: Array, page_indices: Array, *, sm_scale: float = 1, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int, megacore_mode: str | None = None, inline_seq_dim: bool = True) → Array

Paged grouped query attention.

Parameters
  • q – A [batch_size, num_heads, head_dim] jax.Array.

  • k_pages – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.

  • v_pages – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.

  • lengths – An i32[batch_size] jax.Array holding the length of each example.

  • page_indices – An i32[batch_size, pages_per_sequence] jax.Array. Each entry should be in the range [0, total_num_pages) and indicates where to locate the page in k_pages or v_pages.

  • sm_scale – The softmax scale applied to the attention logits. Defaults to 1.0.

  • mask_value – The value used for padding in attention. Defaults to a very negative floating-point number.

  • attn_logits_soft_cap – The value used for soft capping the attention logits.

  • pages_per_compute_block – How many pages are processed in one flash attention block in the Pallas kernel.

  • megacore_mode – If set, enables megacore to parallelize the computation. Must be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is enabled, otherwise the kernel may hang. If you are not sure, leave it as None.

    • None: disable megacore parallelism.

    • kv_head: megacore parallelism on KV heads; requires the number of KV heads to be divisible by 2.

    • batch: megacore parallelism on the batch dimension; requires the batch size to be divisible by 2.

  • inline_seq_dim – Whether to fuse kernel instances along the sequence dimension into one kernel.
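The divisibility rules listed for megacore_mode can be checked before the kernel is called. The helper below is a hypothetical sketch and not part of EasyDeL: the name pick_megacore_mode, the megacore_available flag (which the caller must supply, since this page does not describe how to detect megacore), and the preference for 'kv_head' over 'batch' are all assumptions made for illustration.

```python
def pick_megacore_mode(num_kv_heads, batch_size, megacore_available):
    """Hypothetical helper: choose a megacore_mode that satisfies the documented rules."""
    # The docs warn that the kernel may hang if megacore is not enabled,
    # so fall back to None unless the caller confirms it is available.
    if not megacore_available:
        return None
    # 'kv_head' requires the number of KV heads to be divisible by 2.
    if num_kv_heads % 2 == 0:
        return "kv_head"
    # 'batch' requires the batch size to be divisible by 2.
    if batch_size % 2 == 0:
        return "batch"
    # Neither constraint holds: disable megacore parallelism.
    return None


mode = pick_megacore_mode(num_kv_heads=8, batch_size=3, megacore_available=True)
```

The returned value could then be passed as megacore_mode; when in doubt, None remains the safe choice.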

Returns

The attention output, with shape [batch_size, num_heads, head_dim].
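To make the shapes and the role of page_indices concrete, here is a plain NumPy reference for the computation described above. It is a sketch for checking semantics on CPU, not the Pallas kernel itself. Two details are assumptions: consecutive query heads are taken to share one KV head, and soft capping is taken to mean cap * tanh(logits / cap). The name paged_attention_reference is hypothetical.

```python
import numpy as np


def paged_attention_reference(q, k_pages, v_pages, lengths, page_indices,
                              sm_scale=1.0, mask_value=-2.381976426469702e38,
                              attn_logits_soft_cap=None):
    """Hypothetical NumPy reference for paged grouped-query decode attention."""
    batch_size, num_heads, head_dim = q.shape
    num_kv_heads, _, page_size, _ = k_pages.shape
    group = num_heads // num_kv_heads  # query heads per KV head (assumed consecutive)
    out = np.zeros(q.shape, dtype=np.float32)
    for b in range(batch_size):
        # Gather this sequence's pages and flatten them into one token axis:
        # [num_kv_heads, pages_per_sequence * page_size, head_dim]
        k = k_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_dim)
        v = v_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_dim)
        # Only the first lengths[b] tokens are real; the rest is padding.
        valid = np.arange(k.shape[1]) < lengths[b]
        for h in range(num_heads):
            kv = h // group
            logits = (k[kv] @ q[b, h]) * sm_scale
            if attn_logits_soft_cap is not None:
                # Assumed soft-capping form: cap * tanh(logits / cap).
                logits = attn_logits_soft_cap * np.tanh(logits / attn_logits_soft_cap)
            logits = np.where(valid, logits, mask_value)
            weights = np.exp(logits - logits.max())
            weights /= weights.sum()
            out[b, h] = weights @ v[kv]
    return out


# Small example: 2 sequences, 4 query heads, 2 KV heads, 6 pages of 4 tokens.
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8)).astype(np.float32)
k_pages = rng.standard_normal((2, 6, 4, 8)).astype(np.float32)
v_pages = rng.standard_normal((2, 6, 4, 8)).astype(np.float32)
lengths = np.array([1, 5], dtype=np.int32)
page_indices = np.array([[3, 1], [0, 5]], dtype=np.int32)
out = paged_attention_reference(q, k_pages, v_pages, lengths, page_indices)
```

Because the first sequence has length 1, every head of out[0] reduces to the value vector of the single valid token, which lives at offset 0 of page 3.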

easydel.kernels.tpu_ops.__init__.pallas_prefill_attention(q: Array, k_pages: Array, v_pages: Array, length: Array, page_indices: Array, sm_scale: Optional[float] = None)

Computes paged attention for the prefill phase.

This function wraps the prefill_attention_impl Pallas kernel, handling data layout transformations and launching the kernel. It processes one chunk of the query sequence against the corresponding KV cache pages.

Parameters
  • q – Query tensor for a chunk of the sequence.

  • k_pages – Key cache stored in paged layout in HBM.

  • v_pages – Value cache stored in paged layout in HBM.

  • length – The total sequence length for the item being processed.

  • page_indices – Array mapping sequence positions to page indices in k_pages/v_pages.

  • sm_scale – The softmax scale. Defaults to None, in which case the scale is chosen automatically.

Returns

The attention output for the query chunk, shape [chunk_size, num_attn_heads * head_dim].
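The mapping from sequence positions to pages can be illustrated with a short NumPy sketch. It assumes the same cache layout as pallas_paged_attention, [num_kv_heads, total_num_pages, page_size, head_dim], and a 1-D page_indices for a single sequence, so that position p lives in page page_indices[p // page_size] at offset p % page_size. The helper name gather_kv is hypothetical, and the actual kernel reads pages from HBM rather than materializing a contiguous copy.

```python
import numpy as np


def gather_kv(pages, page_indices, length):
    """Hypothetical helper: rebuild a contiguous [num_kv_heads, length, head_dim] view."""
    num_kv_heads, _, page_size, head_dim = pages.shape
    # Number of pages needed to cover `length` tokens (ceiling division).
    num_pages = -(-int(length) // page_size)
    # Pick the pages in sequence order: [num_kv_heads, num_pages, page_size, head_dim]
    gathered = pages[:, page_indices[:num_pages]]
    # Flatten pages into one token axis and drop the padding in the last page.
    flat = gathered.reshape(num_kv_heads, num_pages * page_size, head_dim)
    return flat[:, :length]


# 2 KV heads, 5 pages of 4 tokens each, head_dim 3.
pages = np.arange(2 * 5 * 4 * 3, dtype=np.float32).reshape(2, 5, 4, 3)
page_indices = np.array([4, 0, 2], dtype=np.int32)
keys = gather_kv(pages, page_indices, length=6)
```

Here position 5 falls in the second listed page (page 0) at offset 1, so keys[:, 5] equals pages[:, 0, 1].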

easydel.kernels.tpu_ops.__init__.pallas_ragged_decode(query: Array, key: Array, value: Array, lengths: Array, starts: Array, *, block_size: int = 256, mask_value: float = -2.381976426469702e+38) → tuple[jax.Array, jax.Array, jax.Array]

Dispatches to ragged Multi-Head or Grouped-Query Attention.

This function acts as a caller for either ragged Multi-Head Attention (MHA) or ragged Grouped-Query Attention (GQA) based on the number of query heads and key/value heads. It utilizes Pallas kernels optimized for TPUs to handle ragged inputs efficiently.

Parameters
  • query – Query tensor with shape [batch_size, num_q_heads, head_dim].

  • key – Key tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].

  • value – Value tensor with shape [batch_size, seq_len, num_kv_heads, head_dim].

  • lengths – Integer tensor with shape [batch_size] indicating the true sequence length for each item in the batch.

  • starts – Integer tensor indicating start indices.

  • block_size – The block size to use for Pallas kernels. Defaults to 256.

  • mask_value – The value to use for masking attention logits. Defaults to a large negative number.

Returns

A tuple containing:

  • The attention output tensor.

  • The maximum logit values.

  • The softmax denominator values.

The exact shapes depend on whether MHA or GQA is called.
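A NumPy reference can clarify what the three returned values represent. The sketch below is not the Pallas implementation and rests on several labeled assumptions: token t of item b is attended only when starts[b] <= t < lengths[b], the query is already scaled (the signature has no sm_scale), consecutive query heads share one KV head, and the returned output is normalized by the softmax denominator. The name ragged_decode_reference and the output shapes are illustrative only.

```python
import numpy as np


def ragged_decode_reference(query, key, value, lengths, starts,
                            mask_value=-2.381976426469702e38):
    """Hypothetical NumPy reference returning (output, max logits, denominators)."""
    batch_size, num_q_heads, head_dim = query.shape
    _, seq_len, num_kv_heads, _ = key.shape
    group = num_q_heads // num_kv_heads  # 1 for MHA, greater than 1 for GQA
    pos = np.arange(seq_len)
    # Assumed validity window per batch item: starts[b] <= t < lengths[b].
    valid = (pos[None, :] >= starts[:, None]) & (pos[None, :] < lengths[:, None])
    # Repeat KV heads so every query head has a matching K/V slice.
    k = np.repeat(key, group, axis=2)    # [batch, seq, num_q_heads, head_dim]
    v = np.repeat(value, group, axis=2)
    logits = np.einsum("bhd,bshd->bhs", query, k)
    logits = np.where(valid[:, None, :], logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)   # maximum logit per head
    p = np.exp(logits - m)
    l = p.sum(axis=-1, keepdims=True)        # softmax denominator per head
    out = np.einsum("bhs,bshd->bhd", p / l, v)
    return out, m, l


rng = np.random.default_rng(0)
query = rng.standard_normal((2, 4, 8)).astype(np.float32)
key = rng.standard_normal((2, 6, 2, 8)).astype(np.float32)
value = rng.standard_normal((2, 6, 2, 8)).astype(np.float32)
lengths = np.array([1, 6], dtype=np.int32)
starts = np.array([0, 2], dtype=np.int32)
out, m, l = ragged_decode_reference(query, key, value, lengths, starts)
```

With four query heads and two KV heads this exercises the GQA path; passing key and value with four KV heads would correspond to MHA.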