easydel.layers.caching.ragged_page.utils
Utility functions for paged attention cache operations.
This module provides low-level utilities for efficient paged KV cache updates, including both TPU-optimized kernel implementations and pure JAX fallbacks for other backends.
The paged attention mechanism divides the KV cache into fixed-size pages, enabling more efficient memory management and reducing fragmentation in long-context scenarios.
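To make the page layout concrete, here is a minimal sketch of how a flat cache position relates to a page index and an offset within that page. It assumes the flat `total_pages * page_size` layout described for `kv_cache_pages` in this module; the helper names `locate` and `flat_position` are hypothetical and not part of the library.

```python
def locate(position: int, page_size: int = 32) -> tuple[int, int]:
    """Map a flat cache position to (page index, offset within the page)."""
    return divmod(position, page_size)


def flat_position(page: int, offset: int, page_size: int = 32) -> int:
    """Inverse of locate: map (page index, offset) back to a flat position."""
    return page * page_size + offset


# Position 70 with 32 tokens per page sits in page 2 at offset 6.
print(locate(70))            # (2, 6)
print(flat_position(2, 6))   # 70
```

Because pages have a fixed size, both directions are plain integer arithmetic and need no lookup table.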
- Key Functions:
cdiv: Ceiling division utility
kv_cache_update: TPU-optimized paged cache update
kv_cache_update_jax: Pure JAX implementation for compatibility
_kv_cache_update_kernel: Low-level TPU kernel
- Optimizations:
Asynchronous DMA transfers on TPU
Vectorized memory operations
Efficient page-based updates
Minimal memory copies
- easydel.layers.caching.ragged_page.utils.cdiv(a: int, v: int) → int
Ceiling division: divide a by v and round up.
Calculates the ceiling of a/v using integer arithmetic, avoiding floating point operations for efficiency.
- Parameters
a (int) – Numerator (e.g., total items)
v (int) – Denominator (e.g., items per page)
- Returns
The ceiling of a/v
- Return type
int
Example
>>> # 10 items at 3 per page need 4 pages
>>> cdiv(10, 3)
4
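One common way to compute a ceiling division with integers only is sketched below. The function name `cdiv_sketch` is hypothetical, and the library's own `cdiv` may be written differently; the sketch also assumes a positive denominator.

```python
def cdiv_sketch(a: int, v: int) -> int:
    """Ceiling of a / v using only integer operations (assumes v > 0)."""
    # Adding v - 1 pushes any nonzero remainder over the next multiple of v,
    # so floor division then yields the rounded-up quotient.
    return (a + v - 1) // v


print(cdiv_sketch(10, 3))    # 4: ten items at three per page need four pages
print(cdiv_sketch(9, 3))     # 3: exact multiples are not rounded up
print(cdiv_sketch(100, 32))  # 4: pages needed for 100 tokens at page_size=32
```

Staying in integer arithmetic avoids the rounding error that a float-based ceiling can introduce for very large operands.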
- easydel.layers.caching.ragged_page.utils.kv_cache_update(new_kv_tokens: Float[Array, 'total_tokens num_combined_kv_heads head_dim'], slice_indices: Int[Array, '3 num_slices'], kv_cache_pages: Float[Array, 'total_cache_positions num_combined_kv_heads head_dim'], total_update_slices: Int[Array, ''], *, page_size: int = 32, slices_per_processing_page: int = 8) → Float[Array, 'total_cache_positions num_combined_kv_heads head_dim']
TPU-optimized paged KV cache update using Pallas kernels.
Efficiently updates the KV cache with new tokens using hardware-accelerated DMA transfers and vectorized operations. The update is performed in pages to minimize memory fragmentation and improve cache locality.
This function:
1. Validates input dimensions and alignment
2. Configures VMEM scratch buffers for staging
3. Launches parallel Pallas kernels for updates
4. Returns updated cache pages
- Parameters
new_kv_tokens (jax.Array) – New key-value tokens to insert. Shape: [total_tokens, num_combined_kv_heads, head_dim] where num_combined_kv_heads = 2 * num_kv_heads (keys + values)
slice_indices (jax.Array) – Mapping of update operations. Shape: [3, num_slices], where each column describes one slice:
  - Row 0: Starting position in cache
  - Row 1: Starting position in new_kv_tokens
  - Row 2: Length of slice to copy
kv_cache_pages (jax.Array) – Existing cache pages to update. Shape: [total_pages * page_size, num_combined_kv_heads, head_dim]
total_update_slices (jax.Array) – Number of valid slices to process. Shape: [1] (a scalar wrapped in an array for XLA compatibility)
page_size (int) – Number of tokens per cache page. Default: 32. Must be static for compilation.
slices_per_processing_page (int) – Slices processed per kernel invocation. Default: 8. Must divide slice_indices.shape[1] evenly.
- Returns
Updated KV cache pages with same shape as kv_cache_pages.
- Return type
jax.Array
- Raises
AssertionError – If dimensions are incompatible or alignment is wrong.
Note
Hardware acceleration requires a TPU backend.
head_dim must be divisible by 128 for alignment.
On non-TPU backends, the function automatically falls back to the JAX implementation.
Example
>>> cache = kv_cache_update(
...     new_kv_tokens=new_kv,
...     slice_indices=indices,
...     kv_cache_pages=cache,
...     total_update_slices=jnp.array([10]),
...     page_size=32
... )
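Since `slices_per_processing_page` must divide `slice_indices.shape[1]` evenly, callers typically need to pad the slice list. The sketch below shows one possible way to lay out the three rows and pad them with zero-length columns. The helper `pad_slice_columns` is hypothetical, and padding with `(0, 0, 0)` columns is an assumption that relies on `total_update_slices` marking only the leading columns as valid. Plain Python lists stand in for arrays here.

```python
def pad_slice_columns(slices, slices_per_processing_page=8):
    """Turn (cache_start, new_start, length) triples into a [3, num_slices] layout.

    The number of columns is rounded up to a multiple of
    slices_per_processing_page; padding columns are zero-length (assumption).
    """
    num_valid = len(slices)
    # Ceiling division, with at least one group so the result is never empty.
    groups = max(1, -(-num_valid // slices_per_processing_page))
    padded_len = groups * slices_per_processing_page
    padded = list(slices) + [(0, 0, 0)] * (padded_len - num_valid)
    # Transpose the triples into three rows:
    # row 0 = cache starts, row 1 = new-token starts, row 2 = lengths.
    rows = [list(row) for row in zip(*padded)]
    return rows, num_valid


slices = [(64, 0, 5), (96, 5, 32), (128, 37, 3)]
rows, num_valid = pad_slice_columns(slices)
print(rows[0])    # [64, 96, 128, 0, 0, 0, 0, 0]
print(rows[2])    # [5, 32, 3, 0, 0, 0, 0, 0]
print(num_valid)  # 3
```

In real use the rows would be converted to an integer array of shape `[3, num_slices]`, and `num_valid` would be passed as `total_update_slices`.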
- easydel.layers.caching.ragged_page.utils.kv_cache_update_jax(new_kv_tokens: Float[Array, 'total_tokens num_kv_heads head_dim'], slice_indices: Int[Array, '3 num_slices'], kv_cache_pages: Float[Array, 'total_cache_positions num_kv_heads head_dim'], total_update_slices: Int[Array, ''], *, page_size: int = 32) → Float[Array, 'total_cache_positions num_kv_heads head_dim']
Pure JAX implementation of paged KV cache update.
Provides a portable fallback implementation using JAX operations instead of hardware-specific kernels. While slower than the TPU kernel version, this ensures compatibility across all backends.
The implementation uses dynamic slicing and scanning to update cache pages functionally, maintaining JAX’s immutability guarantees.
Algorithm:
1. Pad new tokens to page boundaries
2. For each slice in slice_indices:
   - Extract slice from new tokens
   - Create mask for partial updates
   - Merge with existing cache content
   - Update cache slice
3. Return updated cache
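The steps above can be illustrated with a plain-Python reference of the update semantics. This is a sketch, not the library's code: `kv_cache_update_reference` is a hypothetical name, lists of scalars stand in for `[num_kv_heads, head_dim]` slabs, `total_update_slices` is a plain int rather than an array, and each slice is assumed to be no longer than `page_size`.

```python
def kv_cache_update_reference(new_kv_tokens, slice_indices, kv_cache_pages,
                              total_update_slices, page_size=32):
    """Copy each valid slice of new tokens into the cache and return a new cache."""
    cache = list(kv_cache_pages)  # work on a copy so the input stays unchanged
    # Step 1: pad so a fixed page_size window never runs past the end.
    padded = list(new_kv_tokens) + [None] * page_size
    # Step 2: process only the first total_update_slices columns.
    for i in range(total_update_slices):
        dst = slice_indices[0][i]     # row 0: start position in the cache
        src = slice_indices[1][i]     # row 1: start position in the new tokens
        length = slice_indices[2][i]  # row 2: number of tokens to copy
        window = padded[src:src + page_size]  # fixed-size extract
        for j in range(page_size):
            if j < length:  # mask: only the first `length` entries are written
                cache[dst + j] = window[j]
    # Step 3: return the updated cache.
    return cache


cache = [0] * 64
new_tokens = [1, 2, 3, 4, 5]
# Two slices: 2 tokens ending page 0, then 3 tokens starting page 1.
indices = [[30, 32], [0, 2], [2, 3]]
updated = kv_cache_update_reference(new_tokens, indices, cache, 2)
print(updated[28:36])  # [0, 0, 1, 2, 3, 4, 5, 0]
```

The fixed-size window plus mask mirrors why padding is needed in a traced setting, where slice shapes must be static even though the number of tokens to copy varies per slice.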
- Parameters
new_kv_tokens (jax.Array) – New key/value tokens to insert. Shape: [total_tokens, num_kv_heads, head_dim]
slice_indices (jax.Array) – Update mapping information. Shape: [3, num_slices], where each column describes one slice:
  - Row 0: Cache starting position
  - Row 1: New tokens starting position
  - Row 2: Number of tokens to copy
kv_cache_pages (jax.Array) – Existing cache to update. Shape: [total_pages * page_size, num_kv_heads, head_dim]
total_update_slices (jax.Array) – Number of valid slices. Shape: [1] (a scalar wrapped in an array for XLA)
page_size (int) – Tokens per cache page. Default: 32. Must be static for JIT compilation.
- Returns
Updated cache with same shape as kv_cache_pages.
- Return type
jax.Array
Note
This implementation is automatically used on non-TPU backends or when the TPU kernel is unavailable.
Example
>>> # Fallback for CPU/GPU
>>> updated_cache = kv_cache_update_jax(
...     new_kv_tokens=tokens,
...     slice_indices=indices,
...     kv_cache_pages=cache,
...     total_update_slices=jnp.array([5]),
...     page_size=32
... )