easydel.layers.caching.ragged_page.utils


Utility functions for paged attention cache operations.

This module provides low-level utilities for efficient paged KV cache updates, including both TPU-optimized kernel implementations and pure JAX fallbacks for other backends.

The paged attention mechanism divides the KV cache into fixed-size pages, enabling more efficient memory management and reducing fragmentation in long-context scenarios.

Key Functions:
  • cdiv: Ceiling division utility

  • kv_cache_update: TPU-optimized paged cache update

  • kv_cache_update_jax: Pure JAX implementation for compatibility

  • _kv_cache_update_kernel: Low-level TPU kernel

Optimizations:
  • Asynchronous DMA transfers on TPU

  • Vectorized memory operations

  • Efficient page-based updates

  • Minimal memory copies

easydel.layers.caching.ragged_page.utils.cdiv(a: int, v: int) -> int

Ceiling division: divide a by v and round up.

Calculates the ceiling of a/v using integer arithmetic, avoiding floating point operations for efficiency.

Parameters
  • a (int) – Numerator (e.g., total items)

  • v (int) – Denominator (e.g., items per page)

Returns

The ceiling of a/v

Return type

int

Example

>>> cdiv(10, 3)  # 10 items, 3 per page: need 4 pages
4

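The integer trick behind cdiv can be sketched as follows. This is a minimal illustration assuming a non-negative a and a positive v; the library's own implementation may be written differently (for example as -(-a // v)).

```python
def cdiv(a: int, v: int) -> int:
    """Ceiling of a / v for a >= 0 and v > 0, using only integer operations."""
    # Adding v - 1 pushes any non-zero remainder up to the next multiple of v.
    return (a + v - 1) // v


# Pages needed to hold a 1000-token sequence with the default page_size of 32.
num_pages = cdiv(1000, 32)
```

Integer floor division stays exact for arbitrarily large Python ints, which a float-based math.ceil(a / v) does not guarantee.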
easydel.layers.caching.ragged_page.utils.kv_cache_update(new_kv_tokens: Float[Array, 'total_tokens num_combined_kv_heads head_dim'], slice_indices: Int[Array, '3 num_slices'], kv_cache_pages: Float[Array, 'total_cache_positions num_combined_kv_heads head_dim'], total_update_slices: Int[Array, ''], *, page_size: int = 32, slices_per_processing_page: int = 8) -> Float[Array, 'total_cache_positions num_combined_kv_heads head_dim']

TPU-optimized paged KV cache update using Pallas kernels.

Efficiently updates the KV cache with new tokens using hardware-accelerated DMA transfers and vectorized operations. The update is performed in pages to minimize memory fragmentation and improve cache locality.

This function:
  1. Validates input dimensions and alignment
  2. Configures VMEM scratch buffers for staging
  3. Launches parallel Pallas kernels for the updates
  4. Returns the updated cache pages

Parameters
  • new_kv_tokens (jax.Array) – New key-value tokens to insert. Shape: [total_tokens, num_combined_kv_heads, head_dim] where num_combined_kv_heads = 2 * num_kv_heads (keys + values)

  • slice_indices (jax.Array) – Mapping of update operations. Shape: [3, num_slices], where each column describes one slice:
      - Row 0: Starting position in the cache
      - Row 1: Starting position in new_kv_tokens
      - Row 2: Length of the slice to copy

  • kv_cache_pages (jax.Array) – Existing cache pages to update. Shape: [total_pages * page_size, num_combined_kv_heads, head_dim]

  • total_update_slices (jax.Array) – Number of valid slices to process. Shape: [1] (a scalar wrapped in an array for XLA compatibility).

  • page_size (int) – Number of tokens per cache page. Default: 32. Must be static for compilation.

  • slices_per_processing_page (int) – Slices processed per kernel invocation. Default: 8. Must divide slice_indices.shape[1] evenly.

Returns

Updated KV cache pages with the same shape as kv_cache_pages.

Return type

jax.Array

Raises

AssertionError – If dimensions are incompatible or alignment is wrong.

Note

  • Requires a TPU backend for hardware acceleration

  • head_dim must be divisible by 128 for alignment

  • Automatically falls back to the pure JAX implementation (kv_cache_update_jax) on non-TPU backends
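The documented constraints can be checked on the host before launching the kernel. The sketch below is hypothetical (check_kv_update_shapes is not an EasyDeL function) and encodes only the requirements stated on this page: a matching head layout, an even number of combined heads (keys plus values), head_dim divisible by 128, cache rows equal to total_pages * page_size, and num_slices divisible by slices_per_processing_page. The kernel's actual assertions may differ.

```python
def check_kv_update_shapes(new_kv_shape, slice_indices_shape, cache_shape,
                           *, page_size=32, slices_per_processing_page=8):
    """Hypothetical pre-flight check mirroring the constraints documented above.

    Shapes are plain tuples, so this runs on any backend without a TPU.
    Raises AssertionError when a documented constraint is violated.
    """
    _, combined_heads, head_dim = new_kv_shape
    cache_rows, cache_heads, cache_head_dim = cache_shape
    assert len(slice_indices_shape) == 2 and slice_indices_shape[0] == 3, \
        "slice_indices must have shape [3, num_slices]"
    assert (combined_heads, head_dim) == (cache_heads, cache_head_dim), \
        "new tokens and cache must share [num_combined_kv_heads, head_dim]"
    assert combined_heads % 2 == 0, "num_combined_kv_heads = 2 * num_kv_heads"
    assert head_dim % 128 == 0, "head_dim must be divisible by 128"
    assert cache_rows % page_size == 0, "cache rows must be total_pages * page_size"
    assert slice_indices_shape[1] % slices_per_processing_page == 0, \
        "num_slices must be a multiple of slices_per_processing_page"


# 4 pages of 32 tokens, 8 KV heads (16 combined), head_dim 128, 8 slice slots.
check_kv_update_shapes((10, 16, 128), (3, 8), (128, 16, 128))
```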

Example

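>>> import jax.numpy as jnp
>>> from easydel.layers.caching.ragged_page.utils import kv_cache_update
>>> # new_kv: [total_tokens, 2 * num_kv_heads, head_dim]; indices: [3, num_slices]
>>> # cache: [total_pages * page_size, 2 * num_kv_heads, head_dim]
>>> cache = kv_cache_update(
...     new_kv_tokens=new_kv,
...     slice_indices=indices,
...     kv_cache_pages=cache,
...     total_update_slices=jnp.array([10]),
...     page_size=32,
... )
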
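To make the slice_indices layout concrete, here is a hypothetical host-side helper (build_slice_indices is not an EasyDeL function) that splits one sequence's new tokens into slices that never cross a page boundary, then zero-pads the slice dimension to a multiple of slices_per_processing_page. It assumes a per-sequence page table of physical page ids and a flat cache row index of page_id * page_size + offset, which follows from the documented [total_pages * page_size, ...] cache shape; how EasyDeL builds these indices internally may differ.

```python
import numpy as np


def build_slice_indices(page_table, num_cached, num_new, new_token_offset=0,
                        page_size=32, slices_per_processing_page=8):
    """Hypothetical helper: split one sequence's new tokens into page-local slices.

    page_table: physical page ids for the sequence (assumed layout).
    num_cached: tokens already stored for this sequence.
    num_new: tokens to append; new_token_offset: their start row in new_kv_tokens.
    Returns (slice_indices of shape [3, padded_num_slices], total_update_slices [1]).
    """
    cols = []
    written = 0
    while written < num_new:
        pos = num_cached + written                      # logical position in the sequence
        page, offset = divmod(pos, page_size)
        length = min(page_size - offset, num_new - written)  # stop at the page boundary
        dst = page_table[page] * page_size + offset     # flat row in kv_cache_pages
        cols.append((dst, new_token_offset + written, length))
        written += length
    n = len(cols)
    # Round the slice count up to a multiple of slices_per_processing_page.
    padded = -(-max(n, 1) // slices_per_processing_page) * slices_per_processing_page
    indices = np.zeros((3, padded), dtype=np.int32)     # unused columns stay zero
    if n:
        indices[:, :n] = np.array(cols, dtype=np.int32).T
    return indices, np.array([n], dtype=np.int32)


# 30 tokens already cached in pages [5, 2]; append 5 more with page_size 32.
idx, total = build_slice_indices([5, 2], num_cached=30, num_new=5, page_size=32)
```

Here the 5 new tokens become two slices: 2 rows that finish page 5 (cache row 190) and 3 rows that start page 2 (cache row 64). Both arrays can be converted with jnp.asarray before calling kv_cache_update.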
easydel.layers.caching.ragged_page.utils.kv_cache_update_jax(new_kv_tokens: Float[Array, 'total_tokens num_kv_heads head_dim'], slice_indices: Int[Array, '3 num_slices'], kv_cache_pages: Float[Array, 'total_cache_positions num_kv_heads head_dim'], total_update_slices: Int[Array, ''], *, page_size: int = 32) -> Float[Array, 'total_cache_positions num_kv_heads head_dim']

Pure JAX implementation of paged KV cache update.

Provides a portable fallback implementation using JAX operations instead of hardware-specific kernels. While slower than the TPU kernel version, this ensures compatibility across all backends.

The implementation uses dynamic slicing and scanning to update cache pages functionally, maintaining JAX’s immutability guarantees.

Algorithm:
  1. Pad new tokens to page boundaries
  2. For each slice in slice_indices:
      • Extract the slice from the new tokens
      • Create a mask for partial updates
      • Merge it with the existing cache content
      • Write the cache slice back
  3. Return the updated cache
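The algorithm above can be sketched in a few lines of JAX. This is an illustrative re-implementation, not EasyDeL's source: kv_cache_update_sketch is a hypothetical name, it uses jax.lax.fori_loop where the description mentions scanning, and it assumes every slice length is at most page_size, so a single page_size-row window always covers a slice. Both buffers are padded by page_size rows so that a window near the end is never clamped by dynamic_slice.

```python
import jax
import jax.numpy as jnp


def kv_cache_update_sketch(new_kv_tokens, slice_indices, kv_cache_pages,
                           total_update_slices, *, page_size=32):
    """Hypothetical pure-JAX paged cache update (illustration only).

    Assumes slice_indices[2, i] <= page_size for every slice and that each
    destination range lies inside kv_cache_pages.
    """
    num_slices = slice_indices.shape[1]
    pad = ((0, page_size), (0, 0), (0, 0))
    # Step 1: pad so a page_size window at any valid offset stays in bounds
    # (dynamic_slice would otherwise clamp the start index).
    tokens = jnp.pad(new_kv_tokens, pad).astype(kv_cache_pages.dtype)
    cache = jnp.pad(kv_cache_pages, pad)
    n_valid = jnp.reshape(total_update_slices, (-1,))[0]

    def body(i, cache):
        dst = slice_indices[0, i]
        src = slice_indices[1, i]
        length = slice_indices[2, i]
        # Step 2a: fixed-size windows of the new tokens and the current cache.
        window = jax.lax.dynamic_slice_in_dim(tokens, src, page_size, axis=0)
        current = jax.lax.dynamic_slice_in_dim(cache, dst, page_size, axis=0)
        # Step 2b: only the first `length` rows of a valid slice are written.
        mask = (jnp.arange(page_size) < length) & (i < n_valid)
        # Step 2c: merge new rows with the existing cache content.
        merged = jnp.where(mask[:, None, None], window, current)
        # Step 2d: write the merged window back into the cache.
        return jax.lax.dynamic_update_slice_in_dim(cache, merged, dst, axis=0)

    cache = jax.lax.fori_loop(0, num_slices, body, cache)
    # Step 3: drop the padding rows and return the updated cache.
    return cache[: kv_cache_pages.shape[0]]
```

Because masked-out rows are taken from current, the part of a window that extends past its slice rewrites the cache rows it already held, and slice slots at or beyond total_update_slices become no-ops.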

Parameters
  • new_kv_tokens (jax.Array) – New key/value tokens to insert. Shape: [total_tokens, num_kv_heads, head_dim]

  • slice_indices (jax.Array) – Update mapping information. Shape: [3, num_slices], where each column describes one slice:
      - Row 0: Cache starting position
      - Row 1: New tokens starting position
      - Row 2: Number of tokens to copy

  • kv_cache_pages (jax.Array) – Existing cache to update. Shape: [total_pages * page_size, num_kv_heads, head_dim]

  • total_update_slices (jax.Array) – Number of valid slices. Shape: [1] (a scalar wrapped in an array for XLA).

  • page_size (int) – Tokens per cache page. Default: 32. Must be static for JIT compilation.

Returns

Updated cache with the same shape as kv_cache_pages.

Return type

jax.Array

Note

This implementation is automatically used on non-TPU backends or when the TPU kernel is unavailable.
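A backend-based choice between the two implementations can be sketched with jax.default_backend(). pick_kv_update is a hypothetical helper, not EasyDeL's dispatch code; it covers only the backend check, not how EasyDeL detects that the TPU kernel is unavailable.

```python
import jax


def pick_kv_update(tpu_impl, fallback_impl):
    """Hypothetical dispatcher: the Pallas kernel on TPU, pure JAX elsewhere."""
    # jax.default_backend() returns the platform name, e.g. "cpu", "gpu", or "tpu".
    return tpu_impl if jax.default_backend() == "tpu" else fallback_impl


# Strings stand in for kv_cache_update / kv_cache_update_jax in this sketch.
update_fn = pick_kv_update("kv_cache_update", "kv_cache_update_jax")
```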

Example

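>>> # Fallback for CPU/GPU
>>> import jax.numpy as jnp
>>> from easydel.layers.caching.ragged_page.utils import kv_cache_update_jax
>>> # tokens: [total_tokens, num_kv_heads, head_dim]; indices: [3, num_slices]
>>> updated_cache = kv_cache_update_jax(
...     new_kv_tokens=tokens,
...     slice_indices=indices,
...     kv_cache_pages=cache,
...     total_update_slices=jnp.array([5]),
...     page_size=32,
... )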