# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import typing as tp
import chex as cx
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
from eformer import common_types
from eformer import escale as es
from eformer.jaximus import ImplicitArray
from eformer.pytree import auto_pytree
from jax.sharding import Mesh
from jax.sharding import NamedSharding as Ns
from jax.sharding import PartitionSpec as Ps
from .._abstracts import (
BaseCache,
BaseCacheMetadata,
BaseCacheView,
)
if tp.TYPE_CHECKING:
from easydel.layers.quantization.quantizers import EasyQuantizer
else:
EasyQuantizer = object
@auto_pytree
class PagedAttentionCacheMetaData(BaseCacheMetadata):
    """
    Metadata holding configuration parameters for the Paged Attention KV cache.

    This class stores static configuration details required to initialize and manage
    a paged KV cache, such as dimensions, page sizes, and resource utilization hints.
    It inherits from `BaseCacheMetadata`.

    Attributes:
        batch_size (int): The maximum number of sequences processed concurrently during decoding.
        num_hidden_layers (int): The total number of transformer layers in the model.
        num_pages_per_layer (int): The total number of physical memory pages allocated
            for the KV cache per layer across all sequences. This is calculated based on
            available memory and `hbm_utilization`.
        num_pages_per_sequence (int): The maximum number of pages a single sequence
            can occupy, i.e. ceil(`max_sequences` / `page_size`).
        max_sequences (int): The maximum sequence length supported by the cache allocation.
        page_size (int): The number of tokens stored per page in the KV cache.
        num_kv_heads (int): The number of key/value heads in the attention mechanism.
        kv_head_dim_size (int): The dimension size of each key/value head.
        hbm_utilization (float): The target fraction of available High Bandwidth Memory (HBM)
            to be utilized for the KV cache pages. Must be strictly between 0.0 and 1.0.
    """

    batch_size: int
    num_hidden_layers: int
    num_pages_per_layer: int
    num_pages_per_sequence: int
    max_sequences: int
    page_size: int
    num_kv_heads: int
    kv_head_dim_size: int
    hbm_utilization: float

    @staticmethod
    def _usable_hbm(hbm_utilization: float, mesh: Mesh) -> int:
        """
        Estimate the HBM (in bytes) available for KV pages across the whole mesh.

        The estimate is taken from one device that belongs to `mesh` and is addressable
        by this process:
        `(limit * hbm_utilization - bytes_in_use) * mesh.devices.size`, clamped at 0.
        It assumes every device in the mesh has the same memory budget.

        Args:
            hbm_utilization (float): Fraction of the device memory limit to target.
            mesh (Mesh): The JAX device mesh the cache will live on.

        Returns:
            int: Estimated usable bytes summed over all devices in the mesh (>= 0).

        Raises:
            RuntimeError: If the backend does not report memory statistics.
        """
        # Query a device that is both part of the mesh and local to this process.
        # `jax.devices()[0]` is global device 0, which is not addressable on non-zero
        # hosts in multi-host setups and need not belong to `mesh` at all.
        device = mesh.local_devices[0]
        stats = device.memory_stats()
        if not stats:
            raise RuntimeError(
                f"Device {device} does not report memory statistics; cannot size the "
                "paged attention KV cache from `hbm_utilization` on this backend."
            )
        # Prefer the reservable limit; fall back to the raw limit when the backend
        # does not expose the former.
        limit = stats.get("bytes_reservable_limit", stats.get("bytes_limit"))
        if limit is None:
            raise RuntimeError(
                f"Device {device} memory statistics contain neither "
                "`bytes_reservable_limit` nor `bytes_limit`."
            )
        used = stats.get("bytes_in_use", 0)
        usable_per_device = int(limit * hbm_utilization) - used
        # NOTE(review): this scales by every device in the mesh, which assumes the
        # page tensors are fully sharded over the mesh. Pages are sharded on the HEAD
        # axis only, so confirm this holds for the meshes in use.
        return max(usable_per_device, 0) * mesh.devices.size

    @classmethod
    def create(
        cls,
        mesh: Mesh,
        batch_size: int,
        num_hidden_layers: int,
        max_sequences: int,
        page_size: int,
        num_kv_heads: int,
        kv_head_dim_size: int,
        hbm_utilization: float,
        dtype: jnp.dtype = jnp.bfloat16,
    ) -> "PagedAttentionCacheMetaData":
        """
        Factory method to create and initialize a PagedAttentionCacheMetaData instance.

        Calculates derived values like `num_pages_per_layer` and `num_pages_per_sequence`
        based on the provided parameters and estimated available memory.

        Args:
            mesh (Mesh): The JAX device mesh.
            batch_size (int): Maximum concurrent sequences for decode.
            num_hidden_layers (int): Number of transformer layers.
            max_sequences (int): Maximum supported sequence length.
            page_size (int): Number of tokens per cache page.
            num_kv_heads (int): Number of KV heads.
            kv_head_dim_size (int): Dimension of each KV head.
            hbm_utilization (float): Target HBM utilization fraction (0.0 to 1.0, exclusive).
            dtype (jnp.dtype): Data type used for cache size calculation.

        Returns:
            PagedAttentionCacheMetaData: An initialized metadata object.

        Raises:
            ValueError: If input parameters are invalid (e.g., non-positive dimensions,
                invalid utilization factor), or if the available memory is too small
                to hold even a single page per layer.
            RuntimeError: If device memory statistics are unavailable.
        """
        if batch_size <= 0:
            raise ValueError("`batch_size` must be positive")
        if num_hidden_layers <= 0:
            raise ValueError("`num_hidden_layers` must be positive")
        if max_sequences <= 0:
            raise ValueError("`max_sequences` must be positive")
        if page_size <= 0:
            raise ValueError("`page_size` must be positive")
        if num_kv_heads <= 0:
            raise ValueError("`num_kv_heads` must be positive")
        if kv_head_dim_size <= 0:
            raise ValueError("`kv_head_dim_size` must be positive")
        if not (0.0 < hbm_utilization < 1.0):
            raise ValueError("`hbm_utilization` must be positive float value in range 0~1")

        hbm_bytes = cls._usable_hbm(hbm_utilization, mesh)
        item_size = np.dtype(dtype).itemsize
        # Bytes per cached token across all KV heads; the factor 2 accounts for K and V.
        per_kv_bytes = num_kv_heads * kv_head_dim_size * item_size * 2
        num_pages_per_layer = (hbm_bytes // num_hidden_layers) // (page_size * per_kv_bytes)
        if num_pages_per_layer <= 0:
            # Fail early with a clear message instead of producing an empty or
            # negative page allocation that only errors later at `jnp.zeros`.
            raise ValueError(
                f"Not enough device memory for the paged KV cache: computed "
                f"num_pages_per_layer={num_pages_per_layer} from {hbm_bytes} usable bytes "
                f"(hbm_utilization={hbm_utilization}, num_hidden_layers={num_hidden_layers}, "
                f"page_size={page_size}). Increase `hbm_utilization` or reduce the model/cache size."
            )
        num_pages_per_sequence = math.ceil(max_sequences / page_size)
        return cls(
            batch_size=batch_size,
            num_hidden_layers=num_hidden_layers,
            num_pages_per_layer=num_pages_per_layer,
            num_pages_per_sequence=num_pages_per_sequence,
            max_sequences=max_sequences,
            page_size=page_size,
            num_kv_heads=num_kv_heads,
            kv_head_dim_size=kv_head_dim_size,
            hbm_utilization=hbm_utilization,
        )
@auto_pytree
class PagedAttentionCacheView(BaseCacheView):
    """
    Represents the view of the Paged Attention KV cache for a single transformer layer.

    It holds references to the physical key and value pages allocated for this layer
    and the associated metadata. It provides methods to write new key/value pairs
    into the correct pages based on runtime metadata. It inherits from `BaseCacheView`.

    Attributes:
        metadata (PagedAttentionCacheMetaData): The static configuration metadata for the
            entire paged cache.
        layer_index (int): The index of the transformer layer this view corresponds to.
        key_pages (tp.Union[cx.Array, ImplicitArray]): The tensor holding all key pages for this layer.
            Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size).
            Can be a JAX array or an ImplicitArray if quantization is used.
        value_pages (tp.Union[cx.Array, ImplicitArray]): The tensor holding all value pages for this layer.
            Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size).
            Can be a JAX array or an ImplicitArray if quantization is used.
    """

    metadata: PagedAttentionCacheMetaData
    layer_index: int
    key_pages: tp.Union[cx.Array, ImplicitArray]
    value_pages: tp.Union[cx.Array, ImplicitArray]

    @classmethod
    def init(
        cls,
        mesh: Mesh,
        dtype: jnp.dtype,
        metadata: PagedAttentionCacheMetaData,
        layer_index: int,
        partition_manager: es.PartitionManager,
        quantizer: tp.Optional["EasyQuantizer"] = None,
    ):
        """
        Initializes the PagedAttentionCacheView for a specific layer.

        Allocates zero-filled `key_pages` and `value_pages` tensors with the
        appropriate shape, dtype, and sharding based on the provided metadata and
        partition manager. The pages are passed through `quantizer`; when none is
        given, an `EasyQuantizer` configured with the NONE method is used.

        Args:
            mesh (Mesh): The JAX device mesh.
            dtype (jnp.dtype): The data type for the cache pages (e.g., jnp.bfloat16).
            metadata (PagedAttentionCacheMetaData): Static configuration for the cache.
            layer_index (int): The index of the layer this view is for.
            partition_manager (es.PartitionManager): Manages tensor sharding across the mesh.
            quantizer (tp.Optional["EasyQuantizer"]): Optional quantizer to apply to the pages.

        Returns:
            PagedAttentionCacheView: An initialized cache view for the specified layer.
        """
        # Local imports, presumably to avoid import cycles at module load time.
        from easydel.infra.etils import EasyDeLQuantizationMethods
        from easydel.layers.quantization.quantizers import EasyQuantizer

        quantizer = quantizer or EasyQuantizer(EasyDeLQuantizationMethods.NONE)
        kv_pages_shape = (
            metadata.num_kv_heads,
            metadata.num_pages_per_layer,
            metadata.page_size,
            metadata.kv_head_dim_size,
        )
        # Only the head axis is mapped to a mesh axis; pages, tokens and head_dim
        # are requested as EMPTY.
        kv_pages_spec = partition_manager.resolve(
            [
                common_types.HEAD,
                common_types.EMPTY,
                common_types.EMPTY,
                common_types.EMPTY,
            ],
            mode=common_types.MODE_PREFILL,
            shape=kv_pages_shape,
        )
        kv_pages_sharding = Ns(mesh=mesh, spec=kv_pages_spec)
        with jax.named_scope("easydel-paged-attention-cache-init"):
            key_pages = jnp.zeros(
                shape=kv_pages_shape,
                dtype=dtype,
                device=kv_pages_sharding,
            )
            value_pages = jnp.zeros(
                shape=kv_pages_shape,
                dtype=dtype,
                device=kv_pages_sharding,
            )
            key_pages = quantizer(key_pages)
            value_pages = quantizer(value_pages)
        return cls(
            metadata=metadata,
            layer_index=layer_index,
            key_pages=key_pages,
            value_pages=value_pages,
        )

    def concatenate_to_cache(self, *args, **kwargs):
        """
        Concatenation is not applicable for Paged Attention.

        Raises:
            NotImplementedError: Always; use `write_prefill_to_cache` or
                `write_decodes_to_cache` instead.
        """
        raise NotImplementedError(
            "`concatenate_to_cache` is not supported by paged attention; use "
            "`write_prefill_to_cache` or `write_decodes_to_cache` instead."
        )

    def write_prefill_to_cache(
        self,
        key: cx.Array,
        value: cx.Array,
        metadata: PagedAttentionMetadata,
    ):
        """
        Writes the key/value pairs from a prefill step into the appropriate cache pages.

        The token axis of `key`/`value` is padded with zeros up to a whole number of
        pages. This is a no-op when its length is already a multiple of `page_size`.
        The tensors are then split into logical pages. The first
        ceil(`prefill_length` / `page_size`) pages are written, clamped to the number
        of pages actually present in `key`. They are copied into the physical pages
        given by `metadata.prefill_page_table`, using
        `jax.lax.dynamic_update_slice_in_dim` inside a `jax.lax.while_loop`.

        Args:
            key (cx.Array): Key tensor for the prefill sequence. Shape
                (padded_prefill_len, num_kv_heads, kv_head_dim_size).
            value (cx.Array): Value tensor for the prefill sequence. Shape
                (padded_prefill_len, num_kv_heads, kv_head_dim_size).
            metadata (PagedAttentionMetadata): Runtime metadata containing the
                `prefill_length` and `prefill_page_table`.

        Returns:
            PagedAttentionCacheView: `self`, with `key_pages`/`value_pages` reassigned
                to the updated tensors.
        """
        kv_heads = self.key_pages.shape[0]
        page_size = self.key_pages.shape[2]
        head_dim = self.key_pages.shape[-1]

        # Pad the token axis up to a multiple of `page_size` so the page reshape
        # below is valid even for prompts shorter than (or not aligned to) a page.
        # The pad amount depends only on static shapes, so this is trace-time logic.
        pad_tokens = (-key.shape[0]) % page_size
        if pad_tokens:
            pad_width = ((0, pad_tokens), (0, 0), (0, 0))
            key = jnp.pad(key, pad_width)
            value = jnp.pad(value, pad_width)
        num_key_pages = key.shape[0] // page_size

        # ceil(prefill_length / page_size) pages hold real tokens.
        num_full_pages, remainder = jnp.divmod(metadata.prefill_length, page_size)
        num_active_pages = num_full_pages + jnp.where(remainder > 0, 1, 0)
        # Never read past the pages present in `key`: JAX clamps out-of-range gather
        # indices, which would otherwise duplicate the last page into extra slots.
        num_active_pages = jnp.minimum(num_active_pages, num_key_pages)

        # (tokens, heads, dim) -> (heads, tokens, dim) -> (heads, pages, page_size, dim)
        key = (
            key.transpose((1, 0, 2))
            .reshape((kv_heads, -1, page_size, head_dim))
            .astype(self.key_pages.dtype)
        )
        value = (
            value.transpose((1, 0, 2))
            .reshape((kv_heads, -1, page_size, head_dim))
            .astype(self.value_pages.dtype)
        )

        def update_cond(carry):
            _, page = carry
            return page < num_active_pages

        def per_page_update(carry):
            (kp, vp), page = carry
            # Keep a singleton page axis so the slice can be written along axis 1.
            page_k = key[:, page, :, :][:, None, :, :]
            page_v = value[:, page, :, :][:, None, :, :]
            physical_page = metadata.prefill_page_table[page]
            kp = jax.lax.dynamic_update_slice_in_dim(kp, page_k, physical_page, axis=1)
            vp = jax.lax.dynamic_update_slice_in_dim(vp, page_v, physical_page, axis=1)
            return (kp, vp), page + 1

        (self.key_pages, self.value_pages), _ = jax.lax.while_loop(
            update_cond,
            per_page_update,
            ((self.key_pages, self.value_pages), 0),
        )
        return self

    def write_decodes_to_cache(
        self,
        key: cx.Array,
        value: cx.Array,
        metadata: PagedAttentionMetadata,
    ):
        """
        Writes the key/value pairs from a decode step into the appropriate cache pages.

        Uses `decodes_position` and `decodes_page_table` from the runtime `metadata`
        to compute, for every sequence in the batch, the physical page and the offset
        within that page for the new token. The page tensors are viewed as flat
        (num_kv_heads * num_pages * page_size, head_dim) row arrays. All heads are
        then written with a single scattered `.at[...].set(...)` per tensor.

        Args:
            key (cx.Array): Key tensor for the decode tokens. Shape
                (batch_size, num_kv_heads, kv_head_dim_size).
            value (cx.Array): Value tensor for the decode tokens. Shape
                (batch_size, num_kv_heads, kv_head_dim_size).
            metadata (PagedAttentionMetadata): Runtime metadata containing
                `decodes_position` and `decodes_page_table`.

        Returns:
            PagedAttentionCacheView: `self`, with `key_pages`/`value_pages` reassigned
                to the updated tensors.
        """
        # (batch, heads, dim) -> (heads, batch, dim): head-major order matches the
        # flattened page layout used for the scatter below.
        key = key.transpose((1, 0, 2)).astype(self.key_pages.dtype)
        value = value.transpose((1, 0, 2)).astype(self.value_pages.dtype)
        num_tokens = key.shape[1]
        kv_heads, num_pages, page_size, head_dim = self.key_pages.shape

        # Logical page index and in-page offset of each sequence's current position.
        page_idx, offset = jnp.divmod(metadata.decodes_position, page_size)
        physical_page = metadata.decodes_page_table[jnp.arange(0, num_tokens), page_idx]
        # Row of the new token inside a single head's (num_pages * page_size) slab.
        row_in_head = physical_page * page_size + offset
        # Broadcast to every head: row = head * num_pages * page_size + row_in_head,
        # ordered head-major to line up with `key.reshape(-1, head_dim)`.
        head_base = jnp.repeat(jnp.arange(0, kv_heads), num_tokens) * num_pages * page_size
        flat_rows = head_base + jnp.tile(row_in_head, kv_heads)

        key = key.reshape(-1, head_dim)
        value = value.reshape(-1, head_dim)
        flat_key_pages = self.key_pages.reshape(-1, head_dim)
        flat_value_pages = self.value_pages.reshape(-1, head_dim)
        flat_key_pages = flat_key_pages.at[flat_rows, :].set(key)
        flat_value_pages = flat_value_pages.at[flat_rows, :].set(value)
        self.key_pages = flat_key_pages.reshape(kv_heads, num_pages, page_size, head_dim)
        self.value_pages = flat_value_pages.reshape(kv_heads, num_pages, page_size, head_dim)
        return self

    def __repr__(self):
        return f"{self.__class__.__name__}(layer_index={self.layer_index}, kv_shape={self.key_pages.shape})"

    __str__ = __repr__
@auto_pytree
class PagedAttentionCache(BaseCache):
    """
    Represents the complete Paged Attention KV cache for all layers of a model.

    It holds a list of `PagedAttentionCacheView` objects, one for each layer.
    It inherits from `BaseCache`.

    Attributes:
        views (tp.List[PagedAttentionCacheView]): A list containing the cache view
            for each layer in the model.
    """

    views: tp.List[PagedAttentionCacheView]

    @classmethod
    def init_cache(
        cls,
        mesh: Mesh,
        dtype: jnp.dtype,
        metadata: PagedAttentionCacheMetaData,
        partition_manager: es.PartitionManager,
        quantizer: tp.Optional["EasyQuantizer"] = None,
    ):
        """
        Initializes the entire PagedAttentionCache for all layers.

        Creates one `PagedAttentionCacheView` per layer (`metadata.num_hidden_layers`
        of them) by calling `PagedAttentionCacheView.init` with the layer's index.

        Args:
            mesh (Mesh): The JAX device mesh.
            dtype (jnp.dtype): The data type for the cache pages.
            metadata (PagedAttentionCacheMetaData): Static configuration for the cache.
            partition_manager (es.PartitionManager): Manages tensor sharding.
            quantizer (tp.Optional["EasyQuantizer"]): Optional quantizer to apply.

        Returns:
            PagedAttentionCache: An initialized cache object containing views for all layers.
        """
        views = [
            PagedAttentionCacheView.init(
                mesh=mesh,
                dtype=dtype,
                metadata=metadata,
                quantizer=quantizer,
                layer_index=layer_index,
                partition_manager=partition_manager,
            )
            for layer_index in range(metadata.num_hidden_layers)
        ]
        return cls(views=views)

    def init_empty(self, *args, **kwargs):
        """Not typically used for PagedAttentionCache; returns None."""
        return None

    def __repr__(self):
        """
        Provides a string representation of the entire paged cache.

        Shapes are taken from the last layer's view. They are reported as
        "Uninitialized" when there are no views or the view lacks page tensors.
        """
        k_shape = "Uninitialized"
        v_shape = "Uninitialized"
        # Guard against an empty/None `views` list, which previously raised IndexError.
        if self.views:
            last_view = self.views[-1]
            try:
                k_shape = last_view.key_pages.shape
                v_shape = last_view.value_pages.shape
            except AttributeError:
                k_shape = "Uninitialized"
                v_shape = "Uninitialized"
        return (
            f"{self.__class__.__name__}(\n"
            f"  key_pages={k_shape},\n"
            f"  value_pages={v_shape},\n"
            f"  num_layers={len(self.views) if self.views else 0},\n"
            ")"
        )

    __str__ = __repr__
@auto_pytree
class PagedAttentionMetadata:
    """
    Runtime metadata required for performing a Paged Attention computation step.

    This object holds the necessary information for a single forward pass of the
    paged attention mechanism, distinguishing between prefill and decode steps
    and providing the mappings (page tables) from logical sequence positions to
    physical cache pages.

    Attributes:
        prefill_length (jax.Array): Scalar JAX array containing the actual length of the
            prompt being processed in a prefill step. Shape (). Set to 0 if not in prefill.
        prefill_position (jax.Array): JAX array of positions for the prefill tokens.
            Shape (padded_prompt_length,). Empty shape () if not in prefill.
        prefill_page_table (jax.Array): JAX array mapping logical page indices of the
            prefill sequence to physical page indices in the KV cache. Shape (num_pages_for_prefill,).
            Empty shape () if not in prefill.
        decodes_position (jax.Array): JAX array containing the current sequence position
            (or length - 1) for each sequence in the decode batch. Shape (batch_size,).
            Empty shape () if not in decode.
        decodes_page_table (jax.Array): JAX array mapping logical page indices to physical
            page indices for each sequence in the decode batch.
            Shape (batch_size, num_pages_per_sequence). Empty shape () if not in decode.
    """

    prefill_length: jax.Array
    prefill_position: jax.Array
    prefill_page_table: jax.Array
    decodes_position: jax.Array
    decodes_page_table: jax.Array

    def is_prefill_mode(self) -> bool:
        """
        Checks if the current metadata represents a prefill-only step.

        Returns:
            bool: True if `decodes_position` is a 0-d array (no decode information),
                False otherwise.
        """
        return (
            hasattr(self.decodes_position, "shape") and len(self.decodes_position.shape) == 0
        )

    def is_decode_mode(self) -> bool:
        """
        Checks if the current metadata represents a decode-only step.

        Returns:
            bool: True if `prefill_position` is a 0-d array (no prefill information),
                False otherwise.
        """
        return (
            hasattr(self.prefill_position, "shape") and len(self.prefill_position.shape) == 0
        )

    @classmethod
    def init_empty(cls):
        """
        Creates a placeholder metadata object.

        Every field is the same replicated 0-d int32 array holding 1_000_000, placed
        with `PartitionSpec()` on the in-context mesh. Because every field is 0-d,
        both `is_prefill_mode()` and `is_decode_mode()` return True for this object.

        Returns:
            PagedAttentionMetadata: An instance with scalar placeholder values.
        """
        scalar = jax.device_put(
            jnp.asarray(1e6, dtype=jnp.int32),
            Ns(es.get_incontext_mesh(), Ps()),
        )
        return cls(
            prefill_length=scalar,
            prefill_position=scalar,
            prefill_page_table=scalar,
            decodes_position=scalar,
            decodes_page_table=scalar,
        )