# Source code for easydel.layers.caching.ragged_page.cache

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing as tp

import jax
import jax.experimental
import jax.numpy as jnp
from eformer import common_types
from eformer import escale as es
from eformer.escale import PartitionAxis, PartitionManager
from eformer.jaximus import ImplicitArray
from eformer.loggings import get_logger
from eformer.mpric import DTYPE_TO_STRING_MAP
from eformer.pytree import auto_pytree, field
from jax.sharding import Mesh
from jax.sharding import NamedSharding as Ns
from jaxtyping import Array, Float, Int

from easydel.utils.helpers import check_bool_flag

from .._abstracts import BaseCache, BaseCacheMetadata, BaseCacheView
from .utils import kv_cache_update, kv_cache_update_jax

if tp.TYPE_CHECKING:
    from easydel.layers.quantization.quantizers import EasyQuantizer
else:
    EasyQuantizer = object

# Short module-level aliases for the axis / mode markers exported by
# `eformer.common_types`; they are passed to `PartitionManager.resolve` below.
EMPTY = common_types.EMPTY
KV_HEAD = common_types.KV_HEAD
MODE_PREFILL = common_types.MODE_PREFILL

# Module-scoped logger.
logger = get_logger(__name__)

# Boolean flag looked up by name through `check_bool_flag`
# (presumably environment-driven -- TODO confirm against `check_bool_flag`).
# When truthy and the backend is TPU, `RaggedPagesCacheView.concatenate_to_cache`
# may use the `kv_cache_update` kernel instead of `kv_cache_update_jax`.
PERMITTED_KV_KERNELS = check_bool_flag("PERMITTED_KV_KERNELS")


def cdiv(a: int, b: int) -> int:
    """Integer division of ``a`` by ``b`` that rounds up instead of down.

    Computed as ``(a + b - 1) // b``, so it yields the ceiling of ``a / b``
    for a positive divisor ``b``.
    """
    biased_numerator = a + b - 1
    return biased_numerator // b
def previous_power_of_2(n: int) -> int:
    """Return the largest power of two that is ``<= n``.

    Non-positive inputs yield ``0``.
    """
    # `bit_length() - 1` is the index of the highest set bit of a positive int.
    return 1 << (n.bit_length() - 1) if n > 0 else 0
def get_num_slices_per_kv_cache_update_page(page_size_bytes: int) -> int:
    """Pick how many slices are processed per KV-cache update page.

    The count is the number of `page_size_bytes`-sized pages that fit in a
    fixed 16 MiB budget, rounded down to a power of two and capped at 64.

    Args:
        page_size_bytes: Size of a single KV-cache page in bytes.

    Returns:
        A power of two in the range ``[1, 64]``.
    """
    budget_bytes = 16 * 1024 * 1024
    slices_that_fit = budget_bytes // page_size_bytes
    assert slices_that_fit > 0, "Number of slices should be positive"
    return min(previous_power_of_2(slices_that_fit), 64)
def get_dtype_packing(dtype: jnp.dtype) -> int:
    """Return how many values of ``dtype`` fit into one 32-bit word.

    Args:
        dtype: A floating-point dtype (its width is read via ``jnp.finfo``).

    Returns:
        ``32 // bits`` where ``bits`` is the bit width of ``dtype``.

    Raises:
        ValueError: If the bit width of ``dtype`` does not evenly divide 32.
    """
    bits = jnp.finfo(dtype).bits
    if 32 % bits != 0:
        # The original message used `{{dtype}}`, which printed the literal
        # text "{dtype}" rather than the offending dtype.
        raise ValueError(f"The bit width must evenly divide 32, but got bits={bits}, dtype={dtype}")
    return 32 // bits
def align_to_multiple(value: int, multiple: int) -> int:
    """Round ``value`` up to the nearest multiple of ``multiple``."""
    num_chunks = cdiv(value, multiple)
    return num_chunks * multiple
def get_page_size_bytes(
    page_size: int,
    num_kv_heads: int,
    head_size: int,
    kv_cache_dtype: jnp.dtype,
) -> int:
    """Returns the size in bytes of one page of the KV cache.

    The head dimension is padded up to a multiple of 128 and the combined
    (key + value) head count is padded up to a multiple of the dtype packing
    factor before the size is computed.

    Args:
        page_size: Number of token slots in a page.
        num_kv_heads: Number of KV heads (keys and values are counted together
            as ``2 * num_kv_heads``).
        head_size: Unpadded head dimension.
        kv_cache_dtype: Floating-point dtype of the cache.

    Returns:
        Page size in bytes.
    """
    head_dim_padded = cdiv(head_size, 128) * 128

    pack_factor = get_dtype_packing(kv_cache_dtype)
    combined_heads_padded = cdiv(num_kv_heads * 2, pack_factor) * pack_factor

    bits_per_value = jnp.finfo(kv_cache_dtype).bits
    page_bits = page_size * combined_heads_padded * head_dim_padded * bits_per_value
    return page_bits // 8
def per_device_hbm_budget_bytes(util: float = 0.9, mode: str = "free", safety_margin: int = 256 << 20) -> int:
    """Estimate the per-device memory budget (in bytes) usable for the KV cache.

    Each local device is queried through ``memory_stats()``. The limit is taken
    from the first truthy value among ``bytes_limit``, ``bytes_reservable_limit``
    and ``bytes_total``; usage is the larger of ``bytes_in_use`` and
    ``bytes_reserved``.

    Args:
        util: Fraction of memory that may be used.
        mode: ``"free"`` applies ``util`` to the currently free memory
            (``limit - used``); any other value applies ``util`` to the limit
            and then subtracts the used memory.
        safety_margin: Bytes subtracted from every device budget.

    Returns:
        The smallest budget across all devices that report statistics, clamped
        at zero. If no device reports usable statistics, a 4 GiB fallback.
    """
    fallback_budget = 4 * (1024**3)
    budgets = []
    for device in jax.local_devices():
        try:
            stats = device.memory_stats()
        except Exception:
            # Best effort: a device that cannot be queried is skipped.
            continue
        if not stats:
            # `memory_stats()` yields None on backends without memory statistics
            # (e.g. CPU); calling `.get` on it used to raise AttributeError.
            continue

        limit = stats.get("bytes_limit") or stats.get("bytes_reservable_limit") or stats.get("bytes_total")
        if limit is None:
            continue
        used = max(stats.get("bytes_in_use", 0), stats.get("bytes_reserved", 0))

        free = max(0, int(limit) - int(used))
        if mode == "free":
            usable = max(0, int(free * float(util)) - safety_margin)
        else:
            usable = max(0, int(int(limit) * float(util)) - int(used) - safety_margin)
        budgets.append(usable)

    # The most constrained device bounds what every device can allocate.
    return min(budgets) if budgets else fallback_budget
@auto_pytree
class RaggedPagesCacheMetaData(BaseCacheMetadata):
    """
    Metadata holding configuration parameters for the Paged Attention KV cache.

    This class stores static configuration details required to initialize and manage
    a paged KV cache, such as dimensions, page sizes, and resource utilization hints.
    It inherits from `BaseCacheMetadata`. Every field is declared with
    `pytree_node=False`, i.e. it is static metadata rather than a pytree leaf.
    """

    # Model geometry.
    num_hidden_layers: int = field(pytree_node=False)
    max_model_length: int = field(pytree_node=False)
    num_kv_heads: int = field(pytree_node=False)
    k_headdim: int = field(pytree_node=False)
    v_headdim: int = field(pytree_node=False)

    # Paging / memory configuration. `-1` marks a value that has not been set.
    hbm_utilization: float = field(pytree_node=False, default=0.9)
    page_size: int = field(pytree_node=False, default=128)
    num_pages: int = field(pytree_node=False, default=-1)
    max_num_pages_per_req: int = field(pytree_node=False, default=-1)
    num_slices_per_kv_cache_update_page: int = field(pytree_node=False, default=-1)
    max_num_tokens: int = field(pytree_node=False, default=-1)
    max_num_reqs: int = field(pytree_node=False, default=-1)

    # Page layout version ("v3" or "v2") and the cache dtype stored by name.
    version: str | tp.Literal["v3", "v2"] = field(pytree_node=False, default="v3")
    _kvdtype_str: str = field(pytree_node=False, default="bf16")

    @staticmethod
    def _compute_free_hbm(mesh: Mesh, partition_manager: PartitionManager, hbm_utilization: float):
        """Return the memory budget (bytes) available for the cache.

        The per-device budget from `per_device_hbm_budget_bytes` is multiplied by
        the mesh size along the partition manager's KV-head axis.
        """
        kv_head_axis = partition_manager.paxis.kv_head_axis
        size = int(mesh.shape[kv_head_axis])
        budget = per_device_hbm_budget_bytes(hbm_utilization, mode="free")
        available_alloc = budget * size
        logger.info(f"{kv_head_axis=} {size=} {budget=} {available_alloc=} {hbm_utilization=}")
        return available_alloc

    @classmethod
    def create(
        cls,
        mesh: Mesh,
        partition_manager: PartitionManager,
        kvdtype: jnp.dtype,
        num_hidden_layers: int,
        num_kv_heads: int,
        max_model_length: int,
        kv_head_dim_size: int | None = None,
        k_headdim: int | None = None,
        v_headdim: int | None = None,
        hbm_utilization: float = 0.9,
        page_size: int = 128,
        version: tp.Literal["v3", "v2"] = "v3",
    ) -> RaggedPagesCacheMetaData:
        """Build metadata, sizing `num_pages` from the currently available memory.

        Args:
            mesh: The JAX device mesh.
            partition_manager: Supplies the KV-head axis used to scale the budget.
            kvdtype: Floating-point dtype of the cache.
            num_hidden_layers: Number of layers; must be positive.
            num_kv_heads: Number of KV heads; must be positive.
            max_model_length: Maximum sequence length per request.
            kv_head_dim_size: Shared head dim, used for whichever of
                `k_headdim` / `v_headdim` is not given.
            k_headdim: Key head dim (defaults to `kv_head_dim_size`).
            v_headdim: Value head dim (defaults to `kv_head_dim_size`).
            hbm_utilization: Fraction of free device memory to use.
            page_size: Number of token slots per page; must be positive.
            version: Page layout version, "v3" or "v2".

        Returns:
            A populated `RaggedPagesCacheMetaData`. `max_num_tokens` and
            `max_num_reqs` are left at their `-1` defaults.

        Raises:
            ValueError: On missing head dims, non-positive sizes, or an
                unknown `version`.
        """
        if k_headdim is None:
            if kv_head_dim_size is None:
                raise ValueError("Either `k_headdim` or `kv_head_dim_size` must be provided")
            k_headdim = kv_head_dim_size
        if v_headdim is None:
            if kv_head_dim_size is None:
                raise ValueError("Either `v_headdim` or `kv_head_dim_size` must be provided")
            v_headdim = kv_head_dim_size

        if num_hidden_layers <= 0:
            raise ValueError("`num_hidden_layers` must be positive")
        if num_kv_heads <= 0:
            raise ValueError("`num_kv_heads` must be positive")
        # `kv_head_dim_size` is optional: only validate it when it was supplied.
        if kv_head_dim_size is not None and kv_head_dim_size <= 0:
            raise ValueError("`kv_head_dim_size` must be positive")
        if k_headdim <= 0 or v_headdim <= 0:
            raise ValueError("`k_headdim` and `v_headdim` must be positive")
        if page_size <= 0:
            raise ValueError("`page_size` must be positive")
        # Reject a bad version before probing device memory.
        if version not in ("v3", "v2"):
            raise ValueError(f"got unknown version {version} it should be v3/v2.")

        free = cls._compute_free_hbm(mesh=mesh, partition_manager=partition_manager, hbm_utilization=hbm_utilization)

        # Keys and values live in one combined page, so budget for the larger of
        # the two head dims.
        # NOTE(review): `get_shape_and_axes` derives the storage head dim from
        # `k_headdim` only -- confirm the intended sizing when k != v.
        head_dim = max(k_headdim, v_headdim)
        kv_bits = jnp.finfo(kvdtype).bits
        # Divide by 8 last so sub-byte dtypes do not truncate to 0 bytes/value.
        page_bytes = (2 * num_hidden_layers * page_size * num_kv_heads * head_dim * kv_bits) // 8
        num_pages = int(free) // page_bytes
        logger.info(
            f"Creating PagesCacheMetadata with {num_pages=} {page_bytes=} "
            f"sequence_capacity={int((num_pages * page_size) / 1000)}K"
        )

        return cls(
            num_hidden_layers=num_hidden_layers,
            max_model_length=max_model_length,
            num_kv_heads=num_kv_heads,
            k_headdim=k_headdim,
            v_headdim=v_headdim,
            hbm_utilization=hbm_utilization,
            page_size=page_size,
            num_pages=num_pages,
            max_num_pages_per_req=cdiv(max_model_length, page_size),
            num_slices_per_kv_cache_update_page=get_num_slices_per_kv_cache_update_page(
                get_page_size_bytes(
                    page_size=page_size,
                    num_kv_heads=num_kv_heads,
                    head_size=k_headdim,
                    kv_cache_dtype=kvdtype,
                )
            ),
            version=version,
            _kvdtype_str=DTYPE_TO_STRING_MAP[kvdtype],
        )

    @property
    def kvdtype(self) -> jnp.dtype:
        """The cache dtype, resolved from its stored string name."""
        from eformer.mpric import STRING_TO_DTYPE_MAP

        return STRING_TO_DTYPE_MAP[self._kvdtype_str]

    @property
    def kv_head_packing(self) -> int:
        """Number of cache values that fit in one 32-bit word."""
        return get_dtype_packing(self.kvdtype)

    @property
    def storage_num_combined_kv_heads(self) -> int:
        """Combined KV head count of the v3 storage layout, padded to the packing factor.

        For `k_headdim == 64` the count is based on `num_kv_heads`; otherwise on
        `2 * num_kv_heads`.
        """
        if self.k_headdim == 64:
            return align_to_multiple(self.num_kv_heads, self.kv_head_packing)
        return align_to_multiple(self.num_kv_heads * 2, self.kv_head_packing)

    @property
    def storage_num_kv_groups(self) -> int:
        """Number of packed head groups (`combined heads // packing`)."""
        return self.storage_num_combined_kv_heads // self.kv_head_packing

    @property
    def storage_head_dim(self) -> int:
        """Head dim of the v3 storage layout: 128 for `k_headdim == 64`, else padded to 128."""
        if self.k_headdim == 64:
            return 128
        return align_to_multiple(self.k_headdim, 128)

    def get_padded_num_slices(
        self,
        num_tokens: int | None = None,
        max_num_reqs: int | None = None,
    ) -> int:
        """Return the slice count padded to `num_slices_per_kv_cache_update_page`.

        Args:
            num_tokens: Token count; when missing or non-positive, falls back to
                `max_num_tokens` (if set) and then `max_model_length`.
            max_num_reqs: Request count; when missing or non-positive, falls back
                to `max_num_reqs` (if set) and then `get_max_num_seqs()`.

        Returns:
            ``min(2 * max_num_reqs + num_tokens // page_size, num_tokens)`` rounded
            up to a multiple of `num_slices_per_kv_cache_update_page`.
        """
        if num_tokens is None or num_tokens <= 0:
            num_tokens = self.max_num_tokens if self.max_num_tokens > 0 else self.max_model_length
        if max_num_reqs is None or max_num_reqs <= 0:
            max_num_reqs = self.max_num_reqs if self.max_num_reqs > 0 else self.get_max_num_seqs()
        padded_num_slices = 2 * max_num_reqs + num_tokens // self.page_size
        padded_num_slices = min(padded_num_slices, num_tokens)
        return align_to_multiple(padded_num_slices, self.num_slices_per_kv_cache_update_page)

    def get_max_num_seqs(self) -> int:
        """Return a default upper bound on concurrent sequences.

        Derived from a fixed ``1024 * 1024`` budget divided by the number of
        pages a maximum-length request needs.
        """
        num_page_per_req = cdiv(self.max_model_length, self.page_size)
        return 1024 * 1024 // 2 // num_page_per_req // 4

    def get_shape_and_axes(self):
        """Return `(kv_pages_shape, axes)` for the configured layout version.

        Returns:
            v3: shape `(num_pages, page_size, kv_groups, packing, storage_head_dim)`.
            v2: shape `(num_pages, page_size, 2 * num_kv_heads, k_headdim)`.
            In both cases `axes` marks the third dimension as `HEAD` and the rest
            as `EMPTY`.

        Raises:
            ValueError: If `version` is neither "v3" nor "v2".
        """
        if self.version == "v3":
            kv_pages_shape = (
                self.num_pages,
                self.page_size,
                self.storage_num_kv_groups,
                self.kv_head_packing,
                self.storage_head_dim,
            )
            axes = [common_types.EMPTY, common_types.EMPTY, common_types.HEAD, common_types.EMPTY, common_types.EMPTY]
        elif self.version == "v2":
            kv_pages_shape = (
                self.num_pages,
                self.page_size,
                self.num_kv_heads * 2,
                self.k_headdim,
            )
            axes = [common_types.EMPTY, common_types.EMPTY, common_types.HEAD, common_types.EMPTY]
        else:
            raise ValueError(f"got unknown version {self.version} it should be v3/v2.")
        return kv_pages_shape, axes

    @property
    def is_v3(self):
        """True when the layout version is "v3"."""
        return self.version == "v3"

    @property
    def is_v2(self):
        """True when the layout version is "v2"."""
        return self.version == "v2"
@auto_pytree
class RaggedPagesCacheView(BaseCacheView):
    """
    Per-layer view of the Paged Attention KV cache.

    Holds the combined key/value pages of a single transformer layer together
    with the shared cache metadata, and offers helpers to write new key/value
    pairs into the pages and to read keys / values back out.
    It inherits from `BaseCacheView`.

    Attributes:
        metadata (RaggedPagesCacheMetaData): Static configuration of the whole paged cache.
        layer_index (int): Index of the transformer layer this view belongs to.
        kv_pages (jax.Array | ImplicitArray): Combined key/value pages of this layer.
            Shape (v3): (num_pages, page_size, storage_groups, packing, head_dim).
            Shape (v2): (num_pages, page_size, kv_head_combined, head_dim).
        partition_manager (PartitionManager): Resolves logical axes to sharding specs.
    """

    metadata: RaggedPagesCacheMetaData
    layer_index: int = field(pytree_node=False)
    kv_pages: (
        Float[Array, "num_pages page_size storage_groups packing head_dim"]
        | Float[Array, "num_pages page_size kv_head_combined head_dim"]
        | ImplicitArray
    )
    partition_manager: PartitionManager = field(
        pytree_node=False,
        default_factory=lambda: PartitionManager(PartitionAxis()),
    )

    @classmethod
    def init(
        cls,
        mesh: Mesh,
        metadata: RaggedPagesCacheMetaData,
        layer_index: int,
        partition_manager: es.PartitionManager,
        quantizer: EasyQuantizer | None = None,
    ) -> RaggedPagesCacheView:
        """
        Build a zero-filled cache view for one layer.

        The page shape and logical axes come from `metadata.get_shape_and_axes()`;
        the pages are created with dtype `metadata.kvdtype` on the resolved
        sharding and then passed through the quantizer.

        Args:
            mesh (Mesh): The JAX device mesh.
            metadata (RaggedPagesCacheMetaData): Static configuration for the cache.
            layer_index (int): The index of the layer this view is for.
            partition_manager (es.PartitionManager): Manages tensor sharding across the mesh.
            quantizer (tp.Optional["EasyQuantizer"]): Optional quantizer to apply to the pages;
                when omitted, an `EasyQuantizer(quantization_config=None)` is used.

        Returns:
            RaggedPagesCacheView: An initialized cache view for the specified layer.
        """
        from easydel.layers.quantization.quantizers import EasyQuantizer

        if not quantizer:
            quantizer = EasyQuantizer(quantization_config=None)

        pages_shape, page_axes = metadata.get_shape_and_axes()
        pages_spec = partition_manager.resolve(axes=page_axes, mode=common_types.MODE_PREFILL, shape=pages_shape)
        pages_sharding = Ns(mesh=mesh, spec=pages_spec)

        with jax.named_scope("easydel-paged-attention-cache-init"):
            empty_pages = jnp.zeros(shape=pages_shape, dtype=metadata.kvdtype, device=pages_sharding)
            kv_pages = quantizer(empty_pages)

        return cls(
            metadata=metadata,
            layer_index=layer_index,
            kv_pages=kv_pages,
            partition_manager=partition_manager,
        )

    def concatenate_to_cache(
        self,
        key: Float[Array, "batch seq_len num_key_heads head_dim"],
        value: Float[Array, "batch seq_len num_value_heads head_dim"],
        cache_metadata: RaggedPagesMetadata,
    ) -> RaggedPagesCacheView:
        """
        Write new keys / values into the pages (v2 layout only).

        For the v2 layout, keys and values are flattened over batch and sequence,
        interleaved along the head axis and scattered into `kv_pages` using
        `cache_metadata.slot_mapping`. For any other layout this method returns
        `self` unchanged.

        Args:
            key: New key states.
            value: New value states.
            cache_metadata: Runtime metadata providing `slot_mapping`,
                `num_kv_update_slices`, `page_size` and
                `num_slices_per_kv_cache_update_page`.

        Returns:
            RaggedPagesCacheView: A view with updated pages (v2) or `self`.
        """
        if not self.metadata.is_v2:
            return self

        num_kv_heads = key.shape[2]
        head_size = key.shape[3]
        pages_dtype = self.kv_pages.dtype
        key = key.reshape(-1, num_kv_heads, head_size).astype(pages_dtype)
        value = value.reshape(-1, num_kv_heads, head_size).astype(pages_dtype)

        # The kernel is only considered on TPU when the flag permits it; it is
        # dropped for head sizes other than 128, while the shard_map wrapping
        # requested alongside it is kept.
        use_kernel = jax.default_backend() == "tpu" and PERMITTED_KV_KERNELS
        use_shardmap = use_kernel
        if use_kernel and head_size != 128:
            use_kernel = False

        def _update_fn(
            kv: Float[Array, "num_tokens num_kv_heads_x2 head_dim"],
            slots: Int[Array, "num_tokens"],  # noqa: F821
            pages: Float[Array, "num_pages page_size num_kv_heads_x2 head_dim"],
            num_update_slices: Int[Array, ""],
        ) -> Float[Array, "num_pages page_size num_kv_heads_x2 head_dim"]:
            paged_shape = pages.shape
            # Merge the two leading dims (pages, slots-in-page) into one axis.
            flat_pages = pages.reshape(-1, *paged_shape[2:])
            if use_kernel:
                updated = kv_cache_update(
                    kv,
                    slots,
                    flat_pages,
                    num_update_slices,
                    page_size=cache_metadata.page_size,
                    slices_per_processing_page=cache_metadata.num_slices_per_kv_cache_update_page,
                )
            else:
                updated = kv_cache_update_jax(
                    kv,
                    slots,
                    flat_pages,
                    num_update_slices,
                    page_size=cache_metadata.page_size,
                )
            return updated.reshape(*paged_shape)

        if use_shardmap:
            resolve = self.partition_manager.resolve
            kv_spec = resolve([EMPTY, common_types.HEAD, EMPTY], mode=MODE_PREFILL)
            slots_spec = resolve([EMPTY, EMPTY], mode=MODE_PREFILL)
            pages_in_spec = resolve([EMPTY, EMPTY, common_types.HEAD, EMPTY], mode=MODE_PREFILL)
            count_spec = resolve([EMPTY], mode=MODE_PREFILL)
            pages_out_spec = resolve([EMPTY, EMPTY, common_types.HEAD, EMPTY], mode=MODE_PREFILL)
            _update_fn = jax.shard_map(
                _update_fn,
                in_specs=(kv_spec, slots_spec, pages_in_spec, count_spec),
                out_specs=pages_out_spec,
                mesh=es.get_incontext_mesh(),
                check_vma=False,
            )

        # Stacking on axis 2 then reshaping interleaves heads as k0, v0, k1, v1, ...
        stacked = jnp.stack([key, value], axis=2)
        kvs = stacked.reshape(-1, num_kv_heads * 2, head_size)
        kv_pages = _update_fn(
            kvs,
            cache_metadata.slot_mapping,
            self.kv_pages,
            cache_metadata.num_kv_update_slices,
        )
        return self.replace(kv_pages=kv_pages)

    def flattened_kv_pages(self) -> Float[Array, "num_pages page_size num_kv_heads_x2 head_dim"]:
        """Return the pages as a 4-D array.

        v2 pages are returned as-is; otherwise the group and packing axes
        (dims 2 and 3) are merged into a single head axis.
        """
        if self.metadata.is_v2:
            return self.kv_pages
        pages = self.kv_pages
        dims = pages.shape
        merged_heads = dims[2] * dims[3]
        return pages.reshape(dims[0], dims[1], merged_heads, dims[4])

    @property
    def key_pages(self) -> Float[Array, "num_pages page_size num_kv_heads head_dim"]:
        """Even-indexed heads of the flattened pages."""
        return self.flattened_kv_pages()[:, :, 0::2, :]

    @property
    def value_pages(self) -> Float[Array, "num_pages page_size num_kv_heads head_dim"]:
        """Odd-indexed heads of the flattened pages."""
        return self.flattened_kv_pages()[:, :, 1::2, :]

    def __repr__(self) -> str:
        """Short description: class name, layer index and key-page shape."""
        return f"{self.__class__.__name__}(layer_index={self.layer_index}, kv_shape={self.key_pages.shape})"

    __str__ = __repr__
@auto_pytree
class RaggedPagesCache(BaseCache):
    """
    Represents the complete Paged Attention KV cache for all layers of a model.

    It holds a list of `RaggedPagesCacheView` objects, one for each layer.
    It inherits from `BaseCache`.

    Attributes:
        views (tp.List[RaggedPagesCacheView]): A list containing the cache view
            for each layer in the model.
    """

    views: list[RaggedPagesCacheView]

    @property
    def metadata(self) -> RaggedPagesCacheMetaData | None:
        """Metadata of the last view, or None when there is no usable view.

        Returns None both for an empty `views` list and for a last entry that
        is None, instead of raising `IndexError` on the empty case.
        """
        if not self.views or self.views[-1] is None:
            return None
        return self.views[-1].metadata

    @classmethod
    def init_cache(
        cls,
        mesh: Mesh,
        metadata: RaggedPagesCacheMetaData,
        partition_manager: es.PartitionManager,
        quantizer: EasyQuantizer | None = None,
    ) -> RaggedPagesCache:
        """
        Initializes the entire RaggedPagesCache for all layers.

        Creates a list of `RaggedPagesCacheView` instances, one for each layer
        specified in the `metadata`, by calling `RaggedPagesCacheView.init`
        for each layer.

        Args:
            mesh (Mesh): The JAX device mesh.
            metadata (RaggedPagesCacheMetaData): Static configuration for the cache.
            partition_manager (es.PartitionManager): Manages tensor sharding.
            quantizer (tp.Optional["EasyQuantizer"]): Optional quantizer to apply.

        Returns:
            RaggedPagesCache: An initialized cache object containing views for all layers.
        """
        views = [
            RaggedPagesCacheView.init(
                mesh=mesh,
                metadata=metadata,
                quantizer=quantizer,
                layer_index=i,
                partition_manager=partition_manager,
            )
            for i in range(metadata.num_hidden_layers)
        ]
        return cls(views=views)

    def init_empty(self, *args, **kwargs) -> None:
        """Not typically used for RaggedPagesCache; returns None."""
        return None

    def __repr__(self) -> str:
        """Provides a string representation of the entire paged cache.

        Reports "Uninitialized" for the page shape when there are no views or
        the last view has no `kv_pages`, so `repr()` never raises on an empty cache.
        """
        last_view = self.views[-1] if self.views else None
        try:
            kv_shape = last_view.kv_pages.shape
        except AttributeError:
            kv_shape = "Uninitialized"
        return f"{self.__class__.__name__}(\n kv_pages={kv_shape},\n num_layers={len(self.views)},\n)"

    __str__ = __repr__
@auto_pytree(max_print_length=3000)
class RaggedPagesMetadata:
    """Runtime (per-step) metadata consumed by the ragged paged KV cache.

    Array fields are pytree leaves; `version`, `num_slices_per_kv_cache_update_page`,
    `page_size` and `prefill_chunk_size` are static (`pytree_node=False`).
    """

    pages_tables: Int[Array, "max_num_reqs max_pages"]
    context_lens: Int[Array, "max_num_reqs"]  # noqa: F821
    query_start_loc: Int[Array, "max_num_reqs_plus_1"]  # noqa: F821
    num_seqs: Int[Array, "max_num_reqs"]  # noqa: F821
    slot_mapping: Int[Array, "num_tokens"] | None = None  # noqa: F821
    position_ids: Int[Array, "num_tokens"] | None = None  # noqa: F821
    request_distribution: Int[Array, "3"] | None = None
    num_kv_update_slices: Int[Array, "1"] | None = None
    version: str | tp.Literal["v3", "v2"] = field(pytree_node=False, default="v3")
    num_slices_per_kv_cache_update_page: int | None = field(pytree_node=False, default_factory=lambda: None)
    page_size: int = field(pytree_node=False, default=128)
    prefill_chunk_size: int = field(pytree_node=False, default=512)

    @classmethod
    def create_empty(
        cls,
        num_tokens: int,
        max_num_reqs: int,
        max_pages: int,
        page_size: int = 128,
        version: tp.Literal["v3", "v2"] = "v3",
    ) -> RaggedPagesMetadata:
        """Create zero-filled int32 metadata with the requested shapes.

        `slot_mapping` and `num_kv_update_slices` are only populated for "v2";
        `request_distribution` is only populated for "v3". The other optional
        fields of each version are left as None.
        """
        int_dtype = jnp.int32
        is_v2 = version == "v2"
        is_v3 = version == "v3"

        slot_mapping = jnp.zeros([num_tokens], dtype=int_dtype) if is_v2 else None
        num_kv_update_slices = jnp.zeros((1,), dtype=int_dtype) if is_v2 else None
        request_distribution = jnp.zeros((3,), dtype=int_dtype) if is_v3 else None

        # NOTE(review): `position_ids` is sized by `max_num_reqs` here although
        # the field is annotated as "num_tokens" -- confirm which one is intended.
        position_ids = jnp.zeros([max_num_reqs], dtype=int_dtype)

        return cls(
            pages_tables=jnp.zeros((max_num_reqs, max_pages), dtype=int_dtype),
            context_lens=jnp.zeros([max_num_reqs], dtype=int_dtype),
            query_start_loc=jnp.zeros([max_num_reqs + 1], dtype=int_dtype),
            num_seqs=jnp.zeros([max_num_reqs], dtype=int_dtype),
            slot_mapping=slot_mapping,
            position_ids=position_ids,
            request_distribution=request_distribution,
            num_kv_update_slices=num_kv_update_slices,
            version=version,
            page_size=page_size,
        )