easydel.kernels.tpu_ops.paged_attention._forward_pallas
- class easydel.kernels.tpu_ops.paged_attention._forward_pallas.MultiPageAsyncCopyDescriptor(pages_hbm_ref, vmem_buffer, sem, page_indices, page_indices_start_offset, num_pages_to_load, head_index)
Bases: object
Manages asynchronous copies of multiple K/V pages from HBM to VMEM.
This class simplifies initiating and waiting on multiple asynchronous DMA transfers for pages of the Key or Value cache. Given a reference to an array of page indices, a start offset into that array, and the number of pages to load, it issues one copy per page into the specified VMEM buffer.
- _vmem_buffer
The destination VMEM buffer slice for the copies.
- _num_pages_to_load
The number of pages to copy.
- _pages_hbm_ref
A Pallas reference to the K or V page cache in HBM.
- _sem
The semaphore used to coordinate the asynchronous copies.
- _page_indices
A Pallas reference to the array containing page indices.
- _page_indices_start_offset
The starting offset within _page_indices for the current set of pages.
- _async_copies
A list of AsyncCopy objects, one for each page.
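To make the gather pattern concrete, here is a pure-Python model of what the descriptor does. It is a sketch, not the EasyDeL class: MultiPageCopySketch and FakeAsyncCopy are hypothetical names, Python lists stand in for the HBM page cache and the VMEM buffer, each "copy" moves its data when it is waited on, and the semaphore and head_index arguments are omitted.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeAsyncCopy:
    """Stand-in for one DMA: copies src_pages[src_index] into dst_buffer[dst_index]."""

    src_pages: List[List[float]]
    dst_buffer: List[Optional[List[float]]]
    src_index: int
    dst_index: int
    started: bool = False

    def start(self) -> None:
        self.started = True

    def wait(self) -> None:
        # A real DMA runs in the background; in this model the data moves on wait().
        if not self.started:
            raise RuntimeError("wait() called before start()")
        self.dst_buffer[self.dst_index] = list(self.src_pages[self.src_index])


class MultiPageCopySketch:
    """Hypothetical model of MultiPageAsyncCopyDescriptor (semaphore and head omitted)."""

    def __init__(self, pages_hbm, vmem_buffer, page_indices,
                 page_indices_start_offset, num_pages_to_load):
        self._vmem_buffer = vmem_buffer
        # One copy per page: buffer slot j receives the page whose index is
        # stored at page_indices[page_indices_start_offset + j].
        self._async_copies = [
            FakeAsyncCopy(pages_hbm, vmem_buffer,
                          page_indices[page_indices_start_offset + j], j)
            for j in range(num_pages_to_load)
        ]

    def start(self) -> None:
        # Issue every copy before waiting on any of them.
        for copy in self._async_copies:
            copy.start()

    def wait(self):
        for copy in self._async_copies:
            copy.wait()
        return self._vmem_buffer


hbm = [[float(p), float(p)] for p in range(6)]  # page p holds the value p
buf = [None] * 3
desc = MultiPageCopySketch(hbm, buf, [5, 1, 4, 0, 2],
                           page_indices_start_offset=1, num_pages_to_load=3)
desc.start()
print(desc.wait())  # [[1.0, 1.0], [4.0, 4.0], [0.0, 0.0]]
```

Starting all copies before waiting on any is what lets the transfers proceed concurrently in the real kernel; the model only preserves that ordering.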
- easydel.kernels.tpu_ops.paged_attention._forward_pallas.paged_flash_attention_kernel(lengths_ref, page_indices_ref, buffer_index_ref, step_ref, q_ref, k_pages_hbm_ref, v_pages_hbm_ref, o_ref, m_ref, l_ref, k_vmem_buffer, v_vmem_buffer, sem, *, batch_size: int, pages_per_compute_block: int, pages_per_sequence: int, mask_value: float, attn_logits_soft_cap: float | None, megacore_mode: str | None, program_ids=())
Pallas kernel for paged attention, intended for the decode phase.
This kernel computes attention for a single query token against paged Key-Value caches stored in HBM. It processes the KV cache in blocks of pages, using double buffering to overlap asynchronous page loads with computation, and a FlashAttention-style online softmax.
The kernel grid is expected to be (num_cores, batch_size // b_step, num_kv_heads // h_step, num_blocks_per_sequence). The megacore_mode argument determines how work is distributed across TPU cores (by batch or by KV head).
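The sketch below shows one plausible way grid coordinates could map to batch and KV-head indices under each megacore_mode. It assumes a strided scheme modeled on the upstream JAX Pallas TPU paged attention kernel: in "batch" mode core c handles batch elements c, c + num_cores, and so on, and "kv_head" mode does the same over KV heads. map_program_ids and covered_pairs are hypothetical helpers, and EasyDeL's exact indexing may differ.

```python
from typing import List, Optional, Tuple


def map_program_ids(core_index: int, b: int, h: int, num_cores: int,
                    megacore_mode: Optional[str]) -> Tuple[int, int]:
    """Hypothetical map from grid coordinates to (batch_index, kv_head_index)."""
    # "batch": cores interleave over batch elements.
    b_step = num_cores if megacore_mode == "batch" else 1
    b_start = core_index if megacore_mode == "batch" else 0
    # "kv_head": cores interleave over KV heads instead.
    h_step = num_cores if megacore_mode == "kv_head" else 1
    h_start = core_index if megacore_mode == "kv_head" else 0
    return b * b_step + b_start, h * h_step + h_start


def covered_pairs(batch_size: int, num_kv_heads: int, num_cores: int,
                  megacore_mode: Optional[str]) -> List[Tuple[int, int]]:
    """Enumerate the grid (num_cores, batch_size // b_step, num_kv_heads // h_step)."""
    if megacore_mode is None and num_cores != 1:
        raise ValueError("this sketch assumes num_cores == 1 when megacore_mode is None")
    b_step = num_cores if megacore_mode == "batch" else 1
    h_step = num_cores if megacore_mode == "kv_head" else 1
    pairs = []
    for core in range(num_cores):
        for b in range(batch_size // b_step):
            for h in range(num_kv_heads // h_step):
                pairs.append(map_program_ids(core, b, h, num_cores, megacore_mode))
    return pairs
```

Under this scheme every (batch, KV head) pair is visited exactly once, provided batch_size (or num_kv_heads) is divisible by num_cores in the chosen mode.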
- Parameters
lengths_ref – SMEM Ref to sequence lengths for each batch item.
page_indices_ref – Ref to page indices mapping sequence positions to HBM pages.
buffer_index_ref – SMEM Ref storing the current VMEM buffer index (0 or 1) for double buffering.
step_ref – SMEM Ref storing the current step/block index being processed.
q_ref – VMEM Ref to the query vector(s) for the current token.
k_pages_hbm_ref – HBM Ref to the Key cache pages.
v_pages_hbm_ref – HBM Ref to the Value cache pages.
o_ref – VMEM Ref to store the computed output attention vector(s).
m_ref – VMEM Ref to store the running maximum logit (part of online softmax).
l_ref – VMEM Ref to store the running sum of exp(logit - max_logit) (part of online softmax).
k_vmem_buffer – VMEM Ref for the double buffer used to load Key pages.
v_vmem_buffer – VMEM Ref for the double buffer used to load Value pages.
sem – Pallas Ref for the semaphore used for async copy synchronization.
batch_size – Total batch size.
pages_per_compute_block – Number of KV cache pages processed per iteration.
pages_per_sequence – Maximum number of pages allocated per sequence.
mask_value – Value to use for masking attention logits (e.g., -inf).
attn_logits_soft_cap – If not None, apply tanh soft capping to the logits with this value as the cap.
megacore_mode – How to distribute work across TPU cores ("batch" or "kv_head").
program_ids – Optional tuple that supplies program IDs directly, used when this kernel is called from another kernel (such as paged_flash_attention_kernel_inline_seq_dim).
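The per-block online softmax described for this kernel can be sketched in pure Python for one query vector. This is an illustrative model under stated assumptions, not the kernel: decode_attention_online is a hypothetical name, dense lists replace paged storage, positions at or beyond length are masked rather than skipped, and the soft cap is assumed to take the common form cap * tanh(logit / cap).

```python
import math
from typing import List, Optional


def decode_attention_online(q: List[float], k: List[List[float]], v: List[List[float]],
                            length: int, block_size: int, mask_value: float = -1e30,
                            soft_cap: Optional[float] = None) -> List[float]:
    """Single-query attention over K/V rows, processed block_size rows at a time."""
    d_v = len(v[0])
    m = -math.inf       # running maximum logit (analogous to m_ref)
    l = 0.0             # running sum of exp(logit - m) (analogous to l_ref)
    acc = [0.0] * d_v   # unnormalized output accumulator
    for start in range(0, len(k), block_size):
        stop = min(start + block_size, len(k))
        logits = []
        for pos in range(start, stop):
            s = sum(qi * ki for qi, ki in zip(q, k[pos]))
            if soft_cap is not None:
                s = soft_cap * math.tanh(s / soft_cap)
            # Positions past the sequence length are masked, like unused page slots.
            logits.append(s if pos < length else mask_value)
        m_new = max(m, max(logits))
        alpha = math.exp(m - m_new)  # rescales everything accumulated so far
        p = [math.exp(s - m_new) for s in logits]
        l = alpha * l + sum(p)
        acc = [alpha * acc[j] + sum(pt * v[start + t][j] for t, pt in enumerate(p))
               for j in range(d_v)]
        m = m_new
    return [a / l for a in acc]
```

The result matches an ordinary softmax over the first length positions regardless of block_size, which is what allows the kernel to stream pages block by block.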
- easydel.kernels.tpu_ops.paged_attention._forward_pallas.paged_flash_attention_kernel_inline_seq_dim(lengths_ref, page_indices_ref, buffer_index_ref, step_ref, q_ref, k_pages_hbm_ref, v_pages_hbm_ref, o_ref, m_ref, l_ref, k_vmem_buffer, v_vmem_buffer, sem, *, batch_size: int, pages_per_compute_block: int, pages_per_sequence: int, mask_value: float, attn_logits_soft_cap: float | None, megacore_mode: str | None)
Pallas kernel for paged attention that loops over sequence blocks internally.
This kernel performs the same computation as paged_flash_attention_kernel but iterates over the sequence blocks internally with lax.fori_loop instead of receiving the block index i as a program ID. Its grid is typically (num_cores, batch_size // b_step, num_kv_heads // h_step).
- Parameters
lengths_ref – SMEM Ref to sequence lengths for each batch item.
page_indices_ref – Ref to page indices mapping sequence positions to HBM pages.
buffer_index_ref – SMEM Ref storing the current VMEM buffer index (0 or 1).
step_ref – SMEM Ref storing the current step/block index being processed.
q_ref – VMEM Ref to the query vector(s) for the current token.
k_pages_hbm_ref – HBM Ref to the Key cache pages.
v_pages_hbm_ref – HBM Ref to the Value cache pages.
o_ref – VMEM Ref to store the computed output attention vector(s).
m_ref – VMEM Ref to store the running maximum logit.
l_ref – VMEM Ref to store the running sum of exp(logit - max_logit).
k_vmem_buffer – VMEM Ref for the double buffer used to load Key pages.
v_vmem_buffer – VMEM Ref for the double buffer used to load Value pages.
sem – Pallas Ref for the semaphore used for async copy synchronization.
batch_size – Total batch size.
pages_per_compute_block – Number of KV cache pages processed per iteration.
pages_per_sequence – Maximum number of pages allocated per sequence.
mask_value – Value to use for masking attention logits.
attn_logits_soft_cap – If not None, apply tanh capping to logits.
megacore_mode – How to distribute work across TPU cores ("batch" or "kv_head").
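The internal loop with double buffering can be modeled as follows. This is a hedged sketch: fori_loop is a pure-Python stand-in with the documented semantics of jax.lax.fori_loop, run_double_buffered is a hypothetical helper, copies complete immediately (so the model shows the buffer alternation rather than real overlap), and the buffer index is threaded through the loop carry instead of being stored in buffer_index_ref.

```python
def fori_loop(lower, upper, body_fun, init_val):
    """Pure-Python stand-in with the same semantics as jax.lax.fori_loop."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def run_double_buffered(blocks, compute):
    """Process blocks in order while loading the next one into the idle buffer.

    blocks: stand-ins for K/V page blocks in HBM.
    compute: function(acc, block) -> acc applied to the block in the active buffer.
    Returns the final accumulator and a log of (event, block_index, buffer_index).
    """
    buffers = [None, None]
    log = []
    buffers[0] = blocks[0]  # prime buffer 0 before the loop starts
    log.append(("load", 0, 0))

    def body(i, carry):
        buffer_index, acc = carry
        next_index = 1 - buffer_index
        if i + 1 < len(blocks):
            # Issue the load of block i+1 into the other buffer before computing,
            # so the buffer being read below is never overwritten.
            buffers[next_index] = blocks[i + 1]
            log.append(("load", i + 1, next_index))
        acc = compute(acc, buffers[buffer_index])
        log.append(("compute", i, buffer_index))
        return next_index, acc

    _, acc = fori_loop(0, len(blocks), body, (0, 0))
    return acc, log


total, events = run_double_buffered([[1, 2], [3], [4, 5]], lambda acc, blk: acc + sum(blk))
print(total)  # 15
```

Because block i+1 always lands in the buffer that block i is not using, two buffers are enough for the load of the next block to proceed while the current one is consumed.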
- easydel.kernels.tpu_ops.paged_attention._forward_pallas.prefill_attention_impl(length_ref, page_indices_ref, buffer_index_ref, q_ref, k_pages_hbm_ref, v_pages_hbm_ref, out_ref, l_ref, m_ref, k_vmem_buffer, v_vmem_buffer, sem)
Pallas kernel implementation for the prefill phase of paged attention.
This kernel computes attention for a chunk of query tokens (part of the prompt) against the paged Key-Value cache built so far. It iterates through chunks of the KV cache, applying causal masking and an online softmax, and uses double buffering to load the KV cache chunks.
The grid for this kernel is typically (num_kv_heads,). Each program processes one query chunk for all query heads that share the same KV head.
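Which query heads belong to a KV head group depends on the head layout. Assuming the common contiguous grouped-query layout (query head i reads KV head i // group, where group = num_q_heads // num_kv_heads), the heads handled by one program could be computed as below; query_heads_for_kv_head is a hypothetical helper, not part of this module.

```python
from typing import List


def query_heads_for_kv_head(kv_head: int, num_q_heads: int, num_kv_heads: int) -> List[int]:
    """Hypothetical helper: query heads sharing kv_head under contiguous GQA grouping."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group = num_q_heads // num_kv_heads
    # Query head i reads KV head i // group, so each KV head owns a contiguous run.
    return list(range(kv_head * group, (kv_head + 1) * group))


print(query_heads_for_kv_head(1, num_q_heads=8, num_kv_heads=2))  # [4, 5, 6, 7]
```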
- Parameters
length_ref – SMEM Ref containing the total sequence length of the prompt.
page_indices_ref – SMEM Ref containing the page indices for this sequence.
buffer_index_ref – SMEM Ref storing the current VMEM buffer index (0 or 1).
q_ref – VMEM Ref to the current chunk of query vectors.
k_pages_hbm_ref – HBM Ref to the Key cache pages.
v_pages_hbm_ref – HBM Ref to the Value cache pages.
out_ref – VMEM Ref to store the computed output attention vectors for the chunk.
l_ref – VMEM Ref to store the running sum of the online softmax.
m_ref – VMEM Ref to store the running maximum logit of the online softmax.
k_vmem_buffer – VMEM Ref for the double buffer used to load Key chunks.
v_vmem_buffer – VMEM Ref for the double buffer used to load Value chunks.
sem – Pallas Ref for the semaphore used for async copy synchronization.
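Causal prefill attention for one query chunk can be sketched in pure Python as follows. This is a model under stated assumptions, not the kernel: prefill_chunk_attention and its q_offset parameter (the absolute position of the first query row) are hypothetical, rows are handled one at a time whereas the kernel processes the whole chunk together, and KV chunks in a row's causal future are masked here rather than skipped.

```python
import math
from typing import List


def prefill_chunk_attention(q_chunk: List[List[float]], q_offset: int,
                            k: List[List[float]], v: List[List[float]],
                            length: int, kv_chunk: int,
                            mask_value: float = -1e30) -> List[List[float]]:
    """Causal attention for query positions q_offset .. q_offset + len(q_chunk) - 1."""
    out = []
    for r, q in enumerate(q_chunk):
        q_pos = q_offset + r
        m, l = -math.inf, 0.0      # running max logit and running softmax sum
        o = [0.0] * len(v[0])      # unnormalized output for this row
        for start in range(0, length, kv_chunk):
            stop = min(start + kv_chunk, length)
            # Causal mask: a query may only attend to keys at or before its position.
            logits = [
                sum(a * b for a, b in zip(q, k[pos])) if pos <= q_pos else mask_value
                for pos in range(start, stop)
            ]
            m_new = max(m, max(logits))
            alpha = math.exp(m - m_new)
            p = [math.exp(s - m_new) for s in logits]
            l = alpha * l + sum(p)
            o = [alpha * o[j] + sum(pt * v[start + t][j] for t, pt in enumerate(p))
                 for j in range(len(o))]
            m = m_new
        out.append([x / l for x in o])
    return out
```

Since position 0 is always visible to every query, the first KV chunk always contributes a finite maximum, so the running statistics never stay at their initial values.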