easydel.kernels.tpu_ops.paged_attention_pallas._paged_attention
- class easydel.kernels.tpu_ops.paged_attention_pallas._paged_attention.PagedAttention
Bases: object
- build_contiguous_kv_vectorized(pages, page_indices) → tuple[jax.Array, jax.Array]
Builds contiguous KV caches from paged KV caches using vectorized operations.
The sequence-length dimension of each output is max_blocks_per_seq * block_size, so shorter sequences come back padded. The caller needs external information (e.g., the original sequence lengths) to interpret or mask the padding positions in the returned tensors.
- Returns
A tuple containing (contiguous_k, contiguous_v).
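The paged-to-contiguous gather and the padding mask the caller is expected to build can be sketched in plain `jax.numpy`. This is a minimal sketch and not the library method. It assumes separate `k_pages`/`v_pages` tensors laid out as `[num_kv_heads, total_num_pages, page_size, head_dim]` (the layout documented for `paged_attention`) and `page_indices` of shape `i32[batch_size, max_blocks_per_seq]`. The helper names `gather_pages`, `build_contiguous_kv_sketch`, and `valid_positions` are hypothetical.

```python
import jax.numpy as jnp


def gather_pages(pages, page_indices):
    """Flatten each sequence's pages into one contiguous sequence axis.

    Hypothetical helper. Assumed shapes:
      pages:        [num_kv_heads, total_num_pages, page_size, head_dim]
      page_indices: i32[batch_size, max_blocks_per_seq]
    Returns [batch_size, num_kv_heads, max_blocks_per_seq * page_size, head_dim].
    """
    # Advanced indexing on the page axis: [num_kv_heads, batch, blocks, page_size, head_dim]
    g = pages[:, page_indices]
    num_kv_heads, batch, blocks, page_size, head_dim = g.shape
    # Put batch first, then merge (block, in-page offset) into one sequence axis.
    g = jnp.transpose(g, (1, 0, 2, 3, 4))
    return g.reshape(batch, num_kv_heads, blocks * page_size, head_dim)


def build_contiguous_kv_sketch(k_pages, v_pages, page_indices):
    """Return (contiguous_k, contiguous_v); positions past each length are padding."""
    return gather_pages(k_pages, page_indices), gather_pages(v_pages, page_indices)


def valid_positions(lengths, seq_len):
    """Boolean [batch_size, seq_len] mask: True for real tokens, False for padding."""
    return jnp.arange(seq_len)[None, :] < lengths[:, None]
```

The mask from `valid_positions(lengths, max_blocks_per_seq * page_size)` is the "external knowledge" the description refers to. Without it, the padded tail of each row cannot be told apart from real tokens.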
- easydel.kernels.tpu_ops.paged_attention_pallas._paged_attention.paged_attention(q: Array, k_pages: Array, v_pages: Array, lengths: Array, page_indices: Array, *, sm_scale: float = 1, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int, megacore_mode: str | None = None, inline_seq_dim: bool = True) → Array
Paged grouped query attention.
- Parameters
q – A [batch_size, num_heads, head_dim] jax.Array.
k_pages – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
v_pages – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
lengths – An i32[batch_size] jax.Array giving the length of each example.
page_indices – An i32[batch_size, pages_per_sequence] jax.Array. Each entry must be in the range [0, total_num_pages) and gives the location of a page in k_pages or v_pages.
sm_scale – Softmax scale. Defaults to 1.0.
mask_value – Value assigned to masked (padding) attention logits. Defaults to -2.381976426469702e+38, a very large negative number (0.7 times the most negative finite float32 value).
attn_logits_soft_cap – If set, the cap used to soft-cap the attention logits; None disables soft capping.
pages_per_compute_block – Number of pages processed in one flash-attention block of the Pallas kernel.
megacore_mode – If set, enables megacore to parallelize the computation. Must be one of 'kv_head', 'batch', or None. Caveat: set this only if megacore is enabled, otherwise the kernel may hang. If you are not sure, leave it as None.
  * None: disable megacore parallelism.
  * 'kv_head': megacore parallelism over KV heads; requires the number of KV heads to be divisible by 2.
  * 'batch': megacore parallelism over the batch dimension; requires the batch size to be divisible by 2.
inline_seq_dim – Whether to fuse kernel instances along the sequence dimension into a single kernel.
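The divisibility rules for megacore_mode listed above can be checked on the host before launching the kernel. The function below is a hypothetical helper, not part of EasyDeL. It only encodes those documented rules and cannot tell whether the TPU actually has megacore enabled.

```python
def check_megacore_mode(megacore_mode, batch_size, num_kv_heads):
    """Hypothetical host-side check of the documented megacore_mode rules.

    It cannot detect whether the TPU actually has megacore enabled; that
    caveat still has to be handled by the caller.
    """
    if megacore_mode is None:
        return  # megacore parallelism disabled, nothing to check
    if megacore_mode == "kv_head":
        if num_kv_heads % 2 != 0:
            raise ValueError(f"megacore_mode='kv_head' needs an even number of KV heads, got {num_kv_heads}")
    elif megacore_mode == "batch":
        if batch_size % 2 != 0:
            raise ValueError(f"megacore_mode='batch' needs an even batch size, got {batch_size}")
    else:
        raise ValueError(f"megacore_mode must be 'kv_head', 'batch', or None, got {megacore_mode!r}")
```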
- Returns
The attention output, with shape [batch_size, num_heads, head_dim].
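The math this kernel computes for one decode step can be written as a plain `jax.numpy` reference, which is useful for checking results on small inputs. This is a minimal sketch without Pallas blocking, so `pages_per_compute_block`, `megacore_mode`, and `inline_seq_dim` have no counterpart. It rests on three assumptions:

- Query head `h` uses KV head `h // (num_heads // num_kv_heads)`.
- `sm_scale` multiplies the logits.
- The soft cap is `cap * tanh(logits / cap)`, as in the upstream JAX Pallas paged-attention kernel.

The name `paged_attention_reference` is hypothetical.

```python
import jax
import jax.numpy as jnp


def paged_attention_reference(q, k_pages, v_pages, lengths, page_indices, *,
                              sm_scale=1.0, mask_value=-2.381976426469702e38,
                              attn_logits_soft_cap=None):
    """Plain-jnp reference for one decode step of paged grouped query attention.

    q: [batch, num_heads, head_dim]
    k_pages, v_pages: [num_kv_heads, total_num_pages, page_size, head_dim]
    lengths: i32[batch]; page_indices: i32[batch, pages_per_sequence]
    """
    batch, num_heads, head_dim = q.shape
    num_kv_heads = k_pages.shape[0]
    group = num_heads // num_kv_heads  # query heads sharing one KV head (assumed contiguous)

    def gather(pages):
        # [num_kv_heads, batch, pages_per_seq, page_size, head_dim] -> batch first, flatten pages
        g = jnp.transpose(pages[:, page_indices], (1, 0, 2, 3, 4))
        return g.reshape(batch, num_kv_heads, -1, head_dim)

    k, v = gather(k_pages), gather(v_pages)

    # Group query heads by the KV head they attend with: [batch, num_kv_heads, group, head_dim]
    qg = q.reshape(batch, num_kv_heads, group, head_dim).astype(jnp.float32)
    logits = jnp.einsum("bhgd,bhsd->bhgs", qg, k.astype(jnp.float32)) * sm_scale
    if attn_logits_soft_cap is not None:
        # Assumed tanh soft cap, mirroring the upstream JAX kernel.
        logits = attn_logits_soft_cap * jnp.tanh(logits / attn_logits_soft_cap)

    # Positions at or beyond each sequence's length are padding and get mask_value.
    valid = jnp.arange(k.shape[2])[None, :] < lengths[:, None]  # [batch, seq]
    logits = jnp.where(valid[:, None, None, :], logits, mask_value)

    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bhgs,bhsd->bhgd", probs, v.astype(jnp.float32))
    return out.reshape(batch, num_heads, head_dim).astype(q.dtype)
```

Because padded logits are set to `mask_value` rather than `-inf`, a sequence with length 0 yields a uniform average instead of NaNs. That is one reason a large finite negative default is a common choice.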
- easydel.kernels.tpu_ops.paged_attention_pallas._paged_attention.prefill_attention(q: Array, k_pages: Array, v_pages: Array, length: Array, page_indices: Array, sm_scale: Optional[float] = None)
Computes paged attention for the prefill phase.
This function wraps the prefill_attention_impl Pallas kernel, handling data layout transformations and launching the kernel. It processes one chunk of the query sequence against the corresponding KV cache pages.
- Parameters
q – Query tensor for one chunk of the sequence.
k_pages – Key cache, stored in paged layout in HBM.
v_pages – Value cache, stored in paged layout in HBM.
length – The total sequence length of the item being processed.
page_indices – Array mapping sequence positions to page indices in k_pages/v_pages.
sm_scale – Softmax scale. Defaults to None, in which case the scale is chosen automatically.
- Returns
The attention output for the query chunk, shape [chunk_size, num_attn_heads * head_dim].
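What one prefill chunk computes can be sketched as causal attention of the chunk's queries over the sequence's paged keys and values. This is a plain `jax.numpy` sketch, not the Pallas kernel, and several details are assumptions rather than documented behavior:

- `q` has shape `[chunk_size, num_attn_heads, head_dim]` and holds the last `chunk_size` tokens, so it covers positions `[length - chunk_size, length)`.
- `k_pages`/`v_pages` use the `[num_kv_heads, total_num_pages, page_size, head_dim]` layout of `paged_attention`.
- `page_indices` is a 1-D list of pages for this one sequence.
- The automatic `sm_scale` is `1 / sqrt(head_dim)`.

The name `prefill_attention_reference` is hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def prefill_attention_reference(q, k_pages, v_pages, length, page_indices, sm_scale=None):
    """Plain-jnp causal attention of one query chunk over a paged KV cache (sketch)."""
    chunk_size, num_heads, head_dim = q.shape
    num_kv_heads = k_pages.shape[0]
    group = num_heads // num_kv_heads  # query heads per KV head (assumed contiguous)
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # assumed "auto" default

    # Gather this sequence's pages: [num_kv_heads, pages_per_seq * page_size, head_dim]
    k = k_pages[:, page_indices].reshape(num_kv_heads, -1, head_dim)
    v = v_pages[:, page_indices].reshape(num_kv_heads, -1, head_dim)

    qg = q.reshape(chunk_size, num_kv_heads, group, head_dim)
    logits = jnp.einsum("chgd,hsd->hgcs", qg, k) * sm_scale  # [kv, group, chunk, seq]

    # Causal mask using absolute positions; it also hides slots at or beyond `length`.
    q_pos = length - chunk_size + jnp.arange(chunk_size)
    kv_pos = jnp.arange(k.shape[1])
    causal = kv_pos[None, :] <= q_pos[:, None]  # [chunk, seq]
    logits = jnp.where(causal, logits, -2.381976426469702e38)

    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("hgcs,hsd->chgd", probs, v)  # [chunk, kv, group, head_dim]
    return out.reshape(chunk_size, num_heads * head_dim)
```

The final reshape produces the documented `[chunk_size, num_attn_heads * head_dim]` output. Head `h` occupies columns `h * head_dim` through `(h + 1) * head_dim - 1`.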