Source code for easydel.layers.caching.ragged_page.utils

"""Utility functions for paged attention cache operations.

This module provides low-level utilities for efficient paged KV cache
updates, including both TPU-optimized kernel implementations and
pure JAX fallbacks for other backends.

The paged attention mechanism divides the KV cache into fixed-size
pages, enabling more efficient memory management and reducing
fragmentation in long-context scenarios.

Key Functions:
    - cdiv: Ceiling division utility
    - kv_cache_update: TPU-optimized paged cache update
    - kv_cache_update_jax: Pure JAX implementation for compatibility
    - _kv_cache_update_kernel: Low-level TPU kernel

Optimizations:
    - Asynchronous DMA transfers on TPU
    - Vectorized memory operations
    - Efficient page-based updates
    - Minimal memory copies
"""

import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jaxtyping import Array, Float, Int

from easydel.utils.compiling_utils import ejit


[docs]def cdiv(a: int, v: int) -> int: """Ceiling division: divide a by v and round up. Calculates the ceiling of a/v using integer arithmetic, avoiding floating point operations for efficiency. Args: a (int): Numerator (e.g., total items) v (int): Denominator (e.g., items per page) Returns: int: The ceiling of a/v Example: >>> cdiv(10, 3) # 10 items, 3 per page 4 # Need 4 pages """ return (a + v - 1) // v
def _kv_cache_update_kernel( slice_indices_ref, # Pallas reference type new_kv_tokens_hbm_ref, # Pallas reference type kv_cache_pages_hbm_ref, # Pallas reference type _, # Placeholder vmem_scratch_buffer, # VMEM buffer dma_semaphore, # Semaphore ) -> None: """Low-level TPU kernel for paged KV cache updates. Implements a two-phase DMA transfer strategy: 1. Copy new KV tokens from HBM to VMEM scratch buffer 2. Copy from scratch buffer to final cache pages This approach maximizes TPU memory bandwidth utilization by overlapping computation with asynchronous DMA transfers. Args: slice_indices_ref: Reference to slice mapping information. Shape: [3, num_slices] containing (cache_pos, new_kv_pos, length) new_kv_tokens_hbm_ref: Reference to new KV tokens in HBM. kv_cache_pages_hbm_ref: Reference to cache pages in HBM. _: Unused placeholder argument. vmem_scratch_buffer: VMEM scratch space for staging transfers. dma_semaphore: Semaphore for DMA synchronization. Note: This kernel is called by Pallas and executes on TPU. Each invocation processes one 'processing page' of slices. """ pending_async_copies = [] current_page_id = pl.program_id(0) slices_per_processing_page = vmem_scratch_buffer.shape[0] # First phase: Copy from new KV tokens to scratch buffer for slice_idx in range(slices_per_processing_page): global_slice_idx = slice_idx + current_page_id * slices_per_processing_page new_kv_start_pos = slice_indices_ref[1, global_slice_idx] slice_length = slice_indices_ref[2, global_slice_idx] copy_operation = pltpu.make_async_copy( new_kv_tokens_hbm_ref.at[pl.ds(new_kv_start_pos, slice_length), ...], vmem_scratch_buffer.at[slice_idx, pl.ds(0, slice_length), ...], dma_semaphore, ) copy_operation.start() pending_async_copies.append(copy_operation) # Wait for all copies to complete for copy_operation in pending_async_copies: copy_operation.wait() # Second phase: Copy from scratch buffer to KV cache pending_async_copies.clear() for slice_idx in range(slices_per_processing_page): global_slice_idx = slice_idx + current_page_id * slices_per_processing_page kv_cache_start_pos = slice_indices_ref[0, global_slice_idx] slice_length = slice_indices_ref[2, global_slice_idx] copy_operation = pltpu.make_async_copy( vmem_scratch_buffer.at[slice_idx, pl.ds(0, slice_length), ...], kv_cache_pages_hbm_ref.at[pl.ds(kv_cache_start_pos, slice_length), ...], dma_semaphore, ) copy_operation.start() pending_async_copies.append(copy_operation) # Wait for all final copies to complete for copy_operation in pending_async_copies: copy_operation.wait()
[docs]@ejit(static_argnames=["page_size", "slices_per_processing_page"]) def kv_cache_update( new_kv_tokens: Float[Array, "total_tokens num_combined_kv_heads head_dim"], slice_indices: Int[Array, "3 num_slices"], kv_cache_pages: Float[Array, "total_cache_positions num_combined_kv_heads head_dim"], total_update_slices: Int[Array, ""], *, page_size: int = 32, slices_per_processing_page: int = 8, ) -> Float[Array, "total_cache_positions num_combined_kv_heads head_dim"]: """TPU-optimized paged KV cache update using Pallas kernels. Efficiently updates the KV cache with new tokens using hardware-accelerated DMA transfers and vectorized operations. The update is performed in pages to minimize memory fragmentation and improve cache locality. This function: 1. Validates input dimensions and alignment 2. Configures VMEM scratch buffers for staging 3. Launches parallel Pallas kernels for updates 4. Returns updated cache pages Args: new_kv_tokens (jax.Array): New key-value tokens to insert. Shape: [total_tokens, num_combined_kv_heads, head_dim] where num_combined_kv_heads = 2 * num_kv_heads (keys + values) slice_indices (jax.Array): Mapping of update operations. Shape: [3, num_slices] where each column contains: - Row 0: Starting position in cache - Row 1: Starting position in new_kv_tokens - Row 2: Length of slice to copy kv_cache_pages (jax.Array): Existing cache pages to update. Shape: [total_pages * page_size, num_combined_kv_heads, head_dim] total_update_slices (jax.Array): Number of valid slices to process. Shape: [1] - scalar wrapped in array for XLA compatibility page_size (int): Number of tokens per cache page. Default: 32. Must be static for compilation. slices_per_processing_page (int): Slices processed per kernel invocation. Default: 8. Must divide slice_indices.shape[1] evenly. Returns: jax.Array: Updated KV cache pages with same shape as kv_cache_pages. Raises: AssertionError: If dimensions are incompatible or alignment is wrong. Note: - Requires TPU backend for hardware acceleration - head_dim must be divisible by 128 for alignment - Automatically falls back to JAX implementation on non-TPU Example: >>> cache = kv_cache_update( ... new_kv_tokens=new_kv, ... slice_indices=indices, ... kv_cache_pages=cache, ... total_update_slices=jnp.array([10]), ... page_size=32 ... ) """ assert slice_indices.shape[1] % slices_per_processing_page == 0, ( f"{slices_per_processing_page=}, {slice_indices.shape[1]=}" ) _, num_kv_heads, head_dimension = new_kv_tokens.shape assert kv_cache_pages.shape[1] == num_kv_heads assert kv_cache_pages.shape[2] == head_dimension, f"{kv_cache_pages.shape[2]=}!={head_dimension=}" assert head_dimension % 128 == 0 prefetch_scalars = [slice_indices] vmem_scratch_buffer = pltpu.VMEM( (slices_per_processing_page, page_size, num_kv_heads, head_dimension), new_kv_tokens.dtype ) pallas_kernel = pl.pallas_call( _kv_cache_update_kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=len(prefetch_scalars), in_specs=[ pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], out_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)], grid=(cdiv(total_update_slices[0], slices_per_processing_page),), scratch_shapes=[vmem_scratch_buffer, pltpu.SemaphoreType.DMA], ), out_shape=[jax.ShapeDtypeStruct(kv_cache_pages.shape, dtype=kv_cache_pages.dtype)], input_output_aliases={len(prefetch_scalars) + 1: 0}, ) return pallas_kernel(*prefetch_scalars, new_kv_tokens, kv_cache_pages)[0]
[docs]@ejit(static_argnames=["page_size"]) def kv_cache_update_jax( new_kv_tokens: Float[Array, "total_tokens num_kv_heads head_dim"], slice_indices: Int[Array, "3 num_slices"], kv_cache_pages: Float[Array, "total_cache_positions num_kv_heads head_dim"], total_update_slices: Int[Array, ""], *, page_size: int = 32, ) -> Float[Array, "total_cache_positions num_kv_heads head_dim"]: """Pure JAX implementation of paged KV cache update. Provides a portable fallback implementation using JAX operations instead of hardware-specific kernels. While slower than the TPU kernel version, this ensures compatibility across all backends. The implementation uses dynamic slicing and scanning to update cache pages functionally, maintaining JAX's immutability guarantees. Algorithm: 1. Pad new tokens to page boundaries 2. For each slice in slice_indices: - Extract slice from new tokens - Create mask for partial updates - Merge with existing cache content - Update cache slice 3. Return updated cache Args: new_kv_tokens (jax.Array): New key/value tokens to insert. Shape: [total_tokens, num_kv_heads, head_dim] slice_indices (jax.Array): Update mapping information. Shape: [3, num_slices] where each column contains: - Row 0: Cache starting position - Row 1: New tokens starting position - Row 2: Number of tokens to copy kv_cache_pages (jax.Array): Existing cache to update. Shape: [total_pages * page_size, num_kv_heads, head_dim] total_update_slices (jax.Array): Number of valid slices. Shape: [1] - wrapped scalar for XLA page_size (int): Tokens per cache page. Default: 32. Must be static for JIT compilation. Returns: jax.Array: Updated cache with same shape as kv_cache_pages. Note: This implementation is automatically used on non-TPU backends or when the TPU kernel is unavailable. Example: >>> # Fallback for CPU/GPU >>> updated_cache = kv_cache_update_jax( ... new_kv_tokens=tokens, ... slice_indices=indices, ... kv_cache_pages=cache, ... total_update_slices=jnp.array([5]), ... page_size=32 ... ) """ num_valid_slices = total_update_slices[0] padded_new_kv = jnp.pad(new_kv_tokens, [(0, page_size), (0, 0), (0, 0)], mode="constant") def update_single_slice( cache: Float[Array, "total_cache_positions num_kv_heads head_dim"], slice_idx: int ) -> Float[Array, "total_cache_positions num_kv_heads head_dim"]: """Update cache with a single slice of new tokens. Performs a masked update of a cache slice, handling partial page updates correctly by preserving unmodified elements. Args: cache: Current cache state. slice_idx: Index of slice to process. Returns: Updated cache array. """ cache_start_pos = slice_indices[0, slice_idx] new_kv_start_pos = slice_indices[1, slice_idx] actual_length = slice_indices[2, slice_idx] new_slice = jax.lax.dynamic_slice( padded_new_kv, (new_kv_start_pos, 0, 0), (page_size, padded_new_kv.shape[1], padded_new_kv.shape[2]), ) current_cache_slice = jax.lax.dynamic_slice( cache, (cache_start_pos, 0, 0), (page_size, cache.shape[1], cache.shape[2]), ) mask = jnp.arange(page_size)[:, None, None] < actual_length updated_slice = jnp.where(mask, new_slice, current_cache_slice) updated_cache = jax.lax.dynamic_update_slice(cache, updated_slice, (cache_start_pos, 0, 0)) return updated_cache def scan_fn( cache: Float[Array, "total_cache_positions num_kv_heads head_dim"], slice_idx: Int[Array, ""] ) -> tuple[Float[Array, "total_cache_positions num_kv_heads head_dim"], None]: """Scan function for iterating over cache slices. Conditionally updates cache based on slice validity. Args: cache: Current cache state. slice_idx: Current slice index. Returns: tuple: (updated_cache, None) """ updated_cache = jax.lax.cond( slice_idx < num_valid_slices, lambda c: update_single_slice(c, slice_idx), lambda c: c, cache, ) return updated_cache, None return jax.lax.scan(scan_fn, kv_cache_pages, jnp.arange(slice_indices.shape[1]))[0]