Source code for easydel.layers.caching.paged_attention.paged_attention_cache

from __future__ import annotations

import math
import queue
import typing as tp

import chex as cx
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
from eformer.jaximus import ImplicitArray
from eformer import escale as es
from eformer.pytree import auto_pytree
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from .._abstracts import (
	BaseCache,
	BaseCacheMetadata,
	BaseCacheView,
)

if tp.TYPE_CHECKING:
	from easydel.layers.quantization.quantizers import EasyQuantizer
else:
	EasyQuantizer = object


@auto_pytree
class PagedAttentionCacheMetaData(BaseCacheMetadata):
	"""Metadata for Paged Attention KV cache configuration.

	Instances are normally built with :meth:`create`, which validates the
	configuration and sizes the per-layer page pool from the memory that is
	still free on the mesh devices.
	"""

	partition_axis: es.PartitionAxis
	dtype: jnp.dtype
	batch_size: int
	num_hidden_layers: int
	# Number of physical pages per layer, computed in `create` from `hbm_bytes`.
	num_pages_per_layer: int
	# ceil(max_sequences / page_size), computed in `create`.
	num_pages_per_sequence: int
	# NOTE(review): `create` divides this by `page_size` to get pages per
	# sequence, so it behaves like a maximum sequence *length* — confirm naming.
	max_sequences: int
	page_size: int
	num_kv_heads: int
	kv_head_dim_size: int
	hbm_utilization: float
	# Usable bytes summed over every device of the mesh (see `_usable_hbm`).
	hbm_bytes: float

	@staticmethod
	def _usable_hbm(hbm_utilization: float, mesh: Mesh) -> int:
		"""Estimate the bytes the KV cache may use across the whole mesh.

		The memory statistics of the first device in ``mesh`` are read and every
		device of the mesh is assumed to have the same budget.

		Args:
			hbm_utilization: Fraction of the device memory limit the process may
				occupy; bytes already in use are subtracted from that budget.
			mesh: Device mesh the cache will be placed on.

		Returns:
			``(int(limit * hbm_utilization) - bytes_in_use) * mesh.devices.size``.
			The value can be zero or negative when the device is already full.

		Raises:
			RuntimeError: If the device reports no usable memory statistics.
		"""
		# Query a device that actually belongs to the mesh rather than the
		# process-global `jax.devices()[0]`, which may lie outside the mesh.
		device = mesh.devices.flat[0]
		stats = device.memory_stats()
		if not stats:
			raise RuntimeError(
				f"Device {device} does not report memory statistics; "
				"cannot size the paged attention KV cache."
			)
		# Not every backend exposes `bytes_reservable_limit`; fall back to `bytes_limit`.
		limit = stats.get("bytes_reservable_limit", stats.get("bytes_limit"))
		if limit is None:
			raise RuntimeError(
				f"Device {device} reports neither `bytes_reservable_limit` nor `bytes_limit`."
			)
		used = stats.get("bytes_in_use", 0)
		return (int(limit * hbm_utilization) - used) * mesh.devices.size

	@classmethod
	def create(
		cls,
		mesh: Mesh,
		partition_axis: es.PartitionAxis,
		batch_size: int,
		num_hidden_layers: int,
		max_sequences: int,
		page_size: int,
		num_kv_heads: int,
		kv_head_dim_size: int,
		hbm_utilization: float,
		dtype: jnp.dtype = jnp.bfloat16,
	) -> "PagedAttentionCacheMetaData":
		"""Validate the configuration and size the paged KV cache.

		Args:
			mesh: Device mesh used to measure available memory.
			partition_axis: Partitioning axes stored on the metadata.
			batch_size: Must be positive.
			num_hidden_layers: Number of layers sharing the memory budget.
			max_sequences: Divided by ``page_size`` (rounded up) to obtain
				``num_pages_per_sequence``.
			page_size: Tokens per page.
			num_kv_heads: Key/value heads stored per token.
			kv_head_dim_size: Size of each key/value head.
			hbm_utilization: Fraction of device memory to target, in ``(0, 1)``.
			dtype: Storage dtype of the pages.

		Returns:
			A fully populated :class:`PagedAttentionCacheMetaData`.

		Raises:
			ValueError: If an argument is out of range, or if the available memory
				cannot hold even one page per layer.
		"""
		for name, value in (
			("batch_size", batch_size),
			("num_hidden_layers", num_hidden_layers),
			("max_sequences", max_sequences),
			("page_size", page_size),
			("num_kv_heads", num_kv_heads),
			("kv_head_dim_size", kv_head_dim_size),
		):
			if value <= 0:
				raise ValueError(f"`{name}` must be positive")
		if not (0.0 < hbm_utilization < 1.0):
			raise ValueError("`hbm_utilization` must be positive float value in range 0~1")

		hbm_bytes = cls._usable_hbm(hbm_utilization, mesh)
		item_size = np.dtype(dtype).itemsize
		# Bytes one token occupies in one layer: a key and a value vector per KV head.
		per_kv_bytes = num_kv_heads * kv_head_dim_size * item_size * 2
		page_bytes = page_size * per_kv_bytes
		num_pages_per_layer = (hbm_bytes // num_hidden_layers) // page_bytes
		if num_pages_per_layer <= 0:
			# Fail here with a clear message instead of allocating an empty or
			# negatively-sized page pool later on.
			raise ValueError(
				"Not enough free device memory for the paged attention KV cache: "
				f"{hbm_bytes} usable bytes cannot hold one {page_bytes}-byte page for each of "
				f"{num_hidden_layers} layers. Increase `hbm_utilization` or reduce "
				"`page_size`, `num_kv_heads` or `kv_head_dim_size`."
			)
		num_pages_per_sequence = math.ceil(max_sequences / page_size)
		return cls(
			partition_axis=partition_axis,
			dtype=dtype,
			batch_size=batch_size,
			num_hidden_layers=num_hidden_layers,
			num_pages_per_layer=num_pages_per_layer,
			num_pages_per_sequence=num_pages_per_sequence,
			max_sequences=max_sequences,
			page_size=page_size,
			num_kv_heads=num_kv_heads,
			kv_head_dim_size=kv_head_dim_size,
			hbm_utilization=hbm_utilization,
			hbm_bytes=hbm_bytes,
		)
@auto_pytree
class PagedAttentionCacheView(BaseCacheView):
	"""Minimal view for a layer within the PagedAttentionCache.

	Holds one layer's key and value page pools. :meth:`init` creates each of
	them with shape ``(num_kv_heads, num_pages_per_layer, page_size,
	kv_head_dim_size)``.
	"""

	metadata: PagedAttentionCacheMetaData
	layer_index: int
	key_pages: tp.Union[cx.Array, ImplicitArray]
	value_pages: tp.Union[cx.Array, ImplicitArray]
	kv_pages_sharding: NamedSharding

	@classmethod
	def init(
		cls,
		mesh: Mesh,
		metadata: PagedAttentionCacheMetaData,
		layer_index: int,
		quantizer: tp.Optional["EasyQuantizer"] = None,
		kv_pages_sharding: tp.Optional[PartitionSpec] = None,
	):
		"""Allocate zero-filled key/value page pools for one layer.

		Args:
			mesh: Mesh the pools are placed on.
			metadata: Cache configuration providing the pool shape and dtype.
			layer_index: Index of the layer this view belongs to.
			quantizer: Applied to both pools. Defaults to
				``EasyQuantizer(EasyDeLQuantizationMethods.NONE)``.
			kv_pages_sharding: Partition spec for the pools. Defaults to
				``PartitionSpec(metadata.partition_axis.head_axis, None, None, None)``.
				An explicit empty ``PartitionSpec()`` is honoured.

		Returns:
			A new :class:`PagedAttentionCacheView`.
		"""
		from easydel.infra.etils import EasyDeLQuantizationMethods
		from easydel.layers.quantization.quantizers import EasyQuantizer

		if quantizer is None:
			quantizer = EasyQuantizer(EasyDeLQuantizationMethods.NONE)
		# Use `is None` rather than `or`: an empty `PartitionSpec()` (fully
		# replicated) is falsy and would otherwise be replaced by the default.
		if kv_pages_sharding is None:
			kv_pages_sharding = PartitionSpec(metadata.partition_axis.head_axis, None, None, None)
		sharding = NamedSharding(mesh=mesh, spec=kv_pages_sharding)
		pages_shape = (
			metadata.num_kv_heads,
			metadata.num_pages_per_layer,
			metadata.page_size,
			metadata.kv_head_dim_size,
		)
		with jax.named_scope("easydel-paged-attention-cache-init"):
			key_pages = jnp.zeros(shape=pages_shape, dtype=metadata.dtype, device=sharding)
			value_pages = jnp.zeros(shape=pages_shape, dtype=metadata.dtype, device=sharding)
			key_pages = quantizer(key_pages)
			value_pages = quantizer(value_pages)
		return cls(
			metadata=metadata,
			layer_index=layer_index,
			key_pages=key_pages,
			value_pages=value_pages,
			kv_pages_sharding=sharding,
		)

	def concatenate_to_cache(self, *args, **kwargs):
		"""Not supported by the paged cache; always raises ``NotImplementedError``."""
		raise NotImplementedError(
			"PagedAttentionCacheView does not support `concatenate_to_cache`; use "
			"`write_prefill_to_cache` or `write_generate_to_cache` instead."
		)

	def write_prefill_to_cache(
		self,
		key: cx.Array,
		value: cx.Array,
		metadata: PagedAttentionMetadata,
	):
		"""Scatter a prompt's keys/values into this layer's page pools.

		The prompt is cut into ``page_size``-token chunks. Chunk ``i`` is copied to
		physical page ``metadata.prefill_page_table[i]`` for every chunk that
		overlaps the first ``metadata.prefill_length`` tokens. ``self.key_pages``
		and ``self.value_pages`` are replaced with the updated arrays.

		Args:
			key: Prompt keys, expected as ``(seq_len, kv_heads, head_dim)``; axes 0
				and 1 are swapped before paging. ``seq_len`` need not be a multiple
				of ``page_size`` because the tail is zero-padded.
			value: Prompt values, same layout as ``key``.
			metadata: Provides ``prefill_length`` and ``prefill_page_table``.
		"""
		num_kv_heads_per_device, _, page_size, head_dim = self.key_pages.shape
		seq_len = key.shape[0]
		# Zero-pad the sequence axis to a whole number of pages so the page
		# reshape below is valid even for prompts shorter than one page.
		pad_len = (-seq_len) % page_size
		if pad_len:
			pad_width = ((0, pad_len), (0, 0), (0, 0))
			key = jnp.pad(key, pad_width)
			value = jnp.pad(value, pad_width)
		num_chunks = (seq_len + pad_len) // page_size

		# ceil(prefill_length / page_size), never more than the chunks present,
		# so the loop cannot read past the end of `key` / `value`.
		full_pages, remainder = jnp.divmod(metadata.prefill_length, page_size)
		num_active_pages = jnp.minimum(full_pages + jnp.where(remainder > 0, 1, 0), num_chunks)

		def to_pages(x, pool):
			# (seq, heads, dim) -> (heads, chunks, page_size, dim) in the pool dtype.
			return (
				x.transpose((1, 0, 2))
				.reshape((num_kv_heads_per_device, -1, page_size, head_dim))
				.astype(pool.dtype)
			)

		key_chunks = to_pages(key, self.key_pages)
		value_chunks = to_pages(value, self.value_pages)

		def has_more(carry):
			_, chunk_idx = carry
			return chunk_idx < num_active_pages

		def copy_chunk(carry):
			(k_pool, v_pool), chunk_idx = carry
			dst_page = metadata.prefill_page_table[chunk_idx]
			k_pool = jax.lax.dynamic_update_slice_in_dim(
				k_pool,
				jax.lax.dynamic_slice_in_dim(key_chunks, chunk_idx, 1, axis=1),
				dst_page,
				axis=1,
			)
			v_pool = jax.lax.dynamic_update_slice_in_dim(
				v_pool,
				jax.lax.dynamic_slice_in_dim(value_chunks, chunk_idx, 1, axis=1),
				dst_page,
				axis=1,
			)
			return (k_pool, v_pool), chunk_idx + 1

		(self.key_pages, self.value_pages), _ = jax.lax.while_loop(
			has_more,
			copy_chunk,
			((self.key_pages, self.value_pages), 0),
		)

	def write_generate_to_cache(
		self,
		key: cx.Array,
		value: cx.Array,
		metadata: PagedAttentionMetadata,
	):
		"""Write one new key/value token per sequence into the page pools.

		Token ``t`` is written, for every KV head, to physical page
		``metadata.generate_page_table[t, generate_pos[t] // page_size]`` at slot
		``generate_pos[t] % page_size``. ``self.key_pages`` and
		``self.value_pages`` are replaced with the updated arrays.

		Args:
			key: New keys, expected as ``(num_tokens, kv_heads, head_dim)``.
			value: New values, same layout as ``key``.
			metadata: Provides ``generate_pos`` (one position per token) and
				``generate_page_table`` (one row of page indices per token).
		"""
		page_size = self.key_pages.shape[2]
		key = key.transpose((1, 0, 2)).astype(self.key_pages.dtype)
		value = value.transpose((1, 0, 2)).astype(self.value_pages.dtype)
		num_tokens = key.shape[1]

		logical_page, slot = jnp.divmod(metadata.generate_pos, page_size)
		physical_page = metadata.generate_page_table[jnp.arange(0, num_tokens), logical_page]
		# The two index arrays sit on adjacent axes (1 and 2), so the indexed
		# region keeps its position and has shape (kv_heads, num_tokens, head_dim),
		# matching the transposed `key` / `value`. Indexing the 4-D pools
		# directly avoids flattening them and hand-computing flat offsets.
		self.key_pages = self.key_pages.at[:, physical_page, slot, :].set(key)
		self.value_pages = self.value_pages.at[:, physical_page, slot, :].set(value)

	def __repr__(self):
		return f"{self.__class__.__name__}(layer_index={self.layer_index}, kv_shape={self.key_pages.shape})"

	__str__ = __repr__
@auto_pytree
class PagedAttentionCache(BaseCache):
	"""Paged KV cache made of one :class:`PagedAttentionCacheView` per layer."""

	views: tp.List[PagedAttentionCacheView]

	@classmethod
	def init_cache(
		cls,
		mesh: Mesh,
		metadata: PagedAttentionCacheMetaData,
		quantizer: tp.Optional["EasyQuantizer"] = None,
		kv_pages_sharding: tp.Optional[PartitionSpec] = None,
	):
		"""Allocate a view for each of ``metadata.num_hidden_layers`` layers.

		Args:
			mesh: Mesh the page pools are placed on.
			metadata: Cache configuration shared by every layer.
			quantizer: Forwarded to :meth:`PagedAttentionCacheView.init`.
			kv_pages_sharding: Forwarded to :meth:`PagedAttentionCacheView.init`.

		Returns:
			A cache whose ``views[i]`` has ``layer_index == i``.
		"""
		return cls(
			views=[
				PagedAttentionCacheView.init(
					mesh=mesh,
					metadata=metadata,
					layer_index=layer_index,
					quantizer=quantizer,
					kv_pages_sharding=kv_pages_sharding,
				)
				for layer_index in range(metadata.num_hidden_layers)
			]
		)

	def init_empty(self, *args, **kwargs):
		"""Return ``None``; all arguments are ignored."""
		return None

	def __repr__(self):
		# Summarise using the last layer's view. Tolerate an empty `views` list
		# and views whose pools or sharding are missing or None.
		last_view = self.views[-1] if self.views else None
		k_shape = getattr(getattr(last_view, "key_pages", None), "shape", "Uninitialized")
		v_shape = getattr(getattr(last_view, "value_pages", None), "shape", "Uninitialized")
		sharding = getattr(last_view, "kv_pages_sharding", None)
		return (
			f"{self.__class__.__name__}(\n"
			f" key_pages={k_shape},\n"
			f" value_pages={v_shape},\n"
			f" num_layers={len(self.views)},\n"
			f" kv_pages_sharding={sharding}\n"
			")"
		)

	__str__ = __repr__
@auto_pytree
class PagedAttentionMetadata:
	"""Per-call inputs for the paged KV cache write paths.

	``PagedAttentionCacheView.write_prefill_to_cache`` reads
	``prefill_length`` and ``prefill_page_table``.
	``PagedAttentionCacheView.write_generate_to_cache`` reads ``generate_pos``
	and ``generate_page_table``.
	"""

	prefill_length: jax.Array
	prefill_pos: jax.Array
	prefill_page_table: jax.Array
	generate_pos: jax.Array
	generate_page_table: jax.Array

	@classmethod
	def _init_decode_state(cls):
		"""Build an instance whose every field is the int32 scalar ``10_000_000``.

		NOTE(review): presumably a placeholder used before real decode inputs
		are available — confirm against callers.

		Returns:
			An instance of ``cls``, so subclasses get their own type back.
		"""

		def placeholder():
			# Integer literal avoids the float -> int32 conversion of `1e7`.
			return jnp.asarray(10_000_000, dtype=jnp.int32)

		return cls(
			prefill_length=placeholder(),
			prefill_pos=placeholder(),
			prefill_page_table=placeholder(),
			generate_pos=placeholder(),
			generate_page_table=placeholder(),
		)
class PagedAttentionCacheManager:
	"""Logical KV Cache Manager.

	Tracks which physical pages of a layer's page pool are free. Pages
	``1 .. num_pages_per_layer - 1`` are handed out from a FIFO free list.
	Page ``0`` is a reserved dummy page that is never allocated or freed.
	"""

	def __init__(self, metadata: "PagedAttentionCacheMetaData"):
		"""Initializes the manager with every non-dummy page free.

		Args:
			metadata: Supplies ``page_size`` and ``num_pages_per_layer``.
		"""
		self._metadata = metadata
		self._current_page_index = 0
		self._available_hbm_pages = queue.SimpleQueue()
		# Pages currently handed out. Only these may go back to the free list,
		# so a double free can never give one page to two sequences.
		self._allocated_hbm_pages = set()
		for page in range(1, metadata.num_pages_per_layer):
			self._available_hbm_pages.put_nowait(page)

	@property
	def page_size(self):
		"""Returns the page size in the number of per-token kv cache items."""
		return self._metadata.page_size

	@property
	def current_page_index(self):
		"""Returns the dummy page index (0)."""
		return self._current_page_index

	def alloc_prefill_hbm_pages(self, prompt_len) -> list[int]:
		"""Allocates enough HBM pages to hold a ``prompt_len``-token prompt.

		Returns:
			``ceil(prompt_len / page_size)`` page indices, or ``[]`` when that many
			pages are not free (or ``prompt_len <= 0``).
		"""
		n = math.ceil(prompt_len / self._metadata.page_size)
		return self.alloc_hbm_pages(n)

	def alloc_hbm_pages(self, n: int) -> list[int]:
		"""Allocates `n` HBM pages, all or nothing.

		Returns:
			``n`` page indices taken from the free list, or ``[]`` if ``n <= 0`` or
			fewer than ``n`` pages are free. Nothing is allocated in the ``[]`` case.
		"""
		if not 0 < n <= self._available_hbm_pages.qsize():
			return []
		pages = [self._available_hbm_pages.get(block=True) for _ in range(n)]
		self._allocated_hbm_pages.update(pages)
		return pages

	def free_hbm_pages(self, pages: list[int]):
		"""Frees the given HBM pages.

		Indices that are not currently allocated are ignored, so the free list
		can never contain a page twice. This covers the dummy page, pages that
		were already freed, and values never handed out.

		Args:
			pages: Page indices; each is normalised with ``int()``.
		"""
		for p in pages:
			page = int(p)
			if page in self._allocated_hbm_pages:
				self._allocated_hbm_pages.remove(page)
				self._available_hbm_pages.put_nowait(page)