easydel.layers.caching.paged_attention.paged_attention_cache
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCache(views: List[PagedAttentionCacheView])
Bases: BaseCache
Represents the complete Paged Attention KV cache for all layers of a model.
It holds a list of PagedAttentionCacheView objects, one for each layer. It inherits from BaseCache.
- views
A list containing the cache view for each layer in the model.
- Type
tp.List[PagedAttentionCacheView]
- classmethod from_dict(data: Dict[str, Any]) T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, partition_manager: PartitionManager, quantizer: Optional[object] = None)
Initializes the entire PagedAttentionCache for all layers.
Creates a list of PagedAttentionCacheView instances, one for each layer specified in the metadata, by calling PagedAttentionCacheView.init for each layer.
- Parameters
mesh (Mesh) – The JAX device mesh.
dtype (jnp.dtype) – The data type for the cache pages.
metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.
partition_manager (es.PartitionManager) – Manages tensor sharding.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.
- Returns
An initialized cache object containing views for all layers.
- Return type
PagedAttentionCache
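The plain-JAX sketch below shows the per-layer structure that init_cache builds. It allocates one zero-filled key/value page pair per layer, using the page shape documented for PagedAttentionCacheView. It is not EasyDeL's implementation. `ToyLayerPages` and `init_toy_cache` are hypothetical names, and the sketch leaves out sharding (partition_manager) and quantization.

```python
from typing import List, NamedTuple

import jax.numpy as jnp


class ToyLayerPages(NamedTuple):
    """Hypothetical stand-in for one PagedAttentionCacheView (not EasyDeL's class)."""

    layer_index: int
    key_pages: jnp.ndarray
    value_pages: jnp.ndarray


def init_toy_cache(
    num_hidden_layers: int,
    num_kv_heads: int,
    num_pages_per_layer: int,
    page_size: int,
    kv_head_dim_size: int,
    dtype=jnp.bfloat16,
) -> List[ToyLayerPages]:
    # Documented page layout for key_pages and value_pages
    shape = (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size)
    # One entry per layer, mirroring init_cache's per-layer call to PagedAttentionCacheView.init
    return [
        ToyLayerPages(layer, jnp.zeros(shape, dtype), jnp.zeros(shape, dtype))
        for layer in range(num_hidden_layers)
    ]


views = init_toy_cache(
    num_hidden_layers=2, num_kv_heads=4, num_pages_per_layer=16, page_size=8, kv_head_dim_size=64
)
print(len(views), views[0].key_pages.shape)  # 2 (4, 16, 8, 64)
```

Because a NamedTuple of arrays is already a JAX PyTree, a list of such entries can be passed through jax.jit much like the real list of views.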
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str
Serializes the PyTree object to a JSON string.
- views: List[PagedAttentionCacheView]
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheMetaData(batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float)
Bases: BaseCacheMetadata
Metadata holding configuration parameters for the Paged Attention KV cache.
This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.
- batch_size
The maximum number of sequences processed concurrently during decoding.
- Type
int
- num_hidden_layers
The total number of transformer layers in the model.
- Type
int
- num_pages_per_layer
The total number of physical memory pages allocated for the KV cache per layer across all sequences. It is calculated from the available memory and hbm_utilization.
- Type
int
- num_pages_per_sequence
The maximum number of pages a single sequence can occupy, determined by max_sequences and page_size.
- Type
int
- max_sequences
The maximum sequence length (in tokens) supported by the cache allocation. Despite its name, this is a length, not a count of sequences.
- Type
int
- page_size
The number of tokens stored per page in the KV cache.
- Type
int
- num_kv_heads
The number of key/value heads in the attention mechanism.
- Type
int
- kv_head_dim_size
The dimension size of each key/value head.
- Type
int
- hbm_utilization
The target fraction of available High Bandwidth Memory (HBM) to be utilized for the KV cache pages. Should be between 0.0 and 1.0.
- Type
float
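Integer arithmetic is enough to show how page_size, max_sequences and num_pages_per_sequence relate. The sketch assumes two things. First, num_pages_per_sequence is the ceiling of max_sequences / page_size; the attribute description only says it is determined by these two values. Second, token positions are 0-based. Both helper functions are hypothetical.

```python
def pages_per_sequence(max_sequences: int, page_size: int) -> int:
    # Assumed rule: enough pages to hold max_sequences tokens, rounding up
    return (max_sequences + page_size - 1) // page_size


def locate_token(position: int, page_size: int) -> tuple:
    # Logical page index and slot within that page for a 0-based token position
    return position // page_size, position % page_size


print(pages_per_sequence(max_sequences=4096, page_size=128))  # 32
print(pages_per_sequence(max_sequences=1000, page_size=128))  # 8
print(locate_token(300, page_size=128))  # (2, 44)
```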
- batch_size: int
- classmethod create(mesh: Mesh, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: dtype = jnp.bfloat16) PagedAttentionCacheMetaData
Factory method to create and initialize a PagedAttentionCacheMetaData instance.
Calculates derived values like num_pages_per_layer and num_pages_per_sequence based on the provided parameters and estimated available memory.
- Parameters
mesh (Mesh) – The JAX device mesh.
batch_size (int) – Maximum concurrent sequences for decode.
num_hidden_layers (int) – Number of transformer layers.
max_sequences (int) – Maximum supported sequence length.
page_size (int) – Number of tokens per cache page.
num_kv_heads (int) – Number of KV heads.
kv_head_dim_size (int) – Dimension of each KV head.
hbm_utilization (float) – Target HBM utilization fraction (0.0 to 1.0).
dtype (jnp.dtype) – Data type used for cache size calculation.
- Returns
An initialized metadata object.
- Return type
PagedAttentionCacheMetaData
- Raises
ValueError – If input parameters are invalid (e.g., non-positive dimensions, invalid utilization factor).
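create estimates the available memory itself, and this page does not document its exact formula. The hypothetical `estimate_pages_per_layer` below is a rough sketch under stated assumptions:

- the available bytes are passed in as an explicit argument;
- hbm_utilization is applied to that amount;
- the resulting budget is divided evenly across layers;
- each slot counts one key page and one value page.

It also shows the kind of ValueError checks described in the Raises entry.

```python
import jax.numpy as jnp


def estimate_pages_per_layer(
    available_bytes: int,
    hbm_utilization: float,
    num_hidden_layers: int,
    num_kv_heads: int,
    page_size: int,
    kv_head_dim_size: int,
    dtype=jnp.bfloat16,
) -> int:
    # Validation in the spirit of the documented ValueError
    if not 0.0 < hbm_utilization <= 1.0:
        raise ValueError(f"hbm_utilization must be in (0.0, 1.0], got {hbm_utilization}")
    dims = (num_hidden_layers, num_kv_heads, page_size, kv_head_dim_size)
    if any(d <= 0 for d in dims):
        raise ValueError(f"all dimensions must be positive, got {dims}")
    itemsize = jnp.dtype(dtype).itemsize  # 2 bytes for bfloat16
    # Assumed cost: one page holds page_size tokens for every KV head, for keys and for values
    bytes_per_page = 2 * num_kv_heads * page_size * kv_head_dim_size * itemsize
    budget = int(available_bytes * hbm_utilization)
    # Split the budget evenly across layers, rounding down to whole pages
    return budget // (num_hidden_layers * bytes_per_page)


pages = estimate_pages_per_layer(
    available_bytes=16 * 1024**3,
    hbm_utilization=0.5,
    num_hidden_layers=32,
    num_kv_heads=8,
    page_size=128,
    kv_head_dim_size=128,
)
print(pages)  # 512
```

In this example, each bfloat16 page pair takes 512 KiB. Half of 16 GiB is 8 GiB, and 8 GiB divided by (32 layers times 512 KiB) gives 512 pages per layer.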
- classmethod from_dict(data: Dict[str, Any]) T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T
Deserializes a JSON string into a PyTree object.
- hbm_utilization: float
- kv_head_dim_size: int
- max_sequences: int
- num_hidden_layers: int
- num_kv_heads: int
- num_pages_per_layer: int
- num_pages_per_sequence: int
- page_size: int
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray])
Bases: BaseCacheView
Represents the view of the Paged Attention KV cache for a single transformer layer.
It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.
- metadata
The static configuration metadata for the entire paged cache.
- Type
PagedAttentionCacheMetaData
- layer_index
The index of the transformer layer this view corresponds to.
- Type
int
- key_pages
The tensor holding all key pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). It is a JAX array, or an ImplicitArray when quantization is used.
- Type
tp.Union[cx.Array, ImplicitArray]
- value_pages
The tensor holding all value pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). It is a JAX array, or an ImplicitArray when quantization is used.
- Type
tp.Union[cx.Array, ImplicitArray]
- concatenate_to_cache(*args, **kwargs)
Concatenation is not applicable to Paged Attention; calling this method raises NotImplementedError.
- classmethod from_dict(data: Dict[str, Any]) T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T
Deserializes a JSON string into a PyTree object.
- classmethod init(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: Optional[object] = None)
Initializes the PagedAttentionCacheView for a specific layer.
Allocates the key_pages and value_pages tensors with the appropriate shape, dtype, and sharding based on the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.
- Parameters
mesh (Mesh) – The JAX device mesh.
dtype (jnp.dtype) – The data type for the cache pages (e.g., jnp.bfloat16).
metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.
layer_index (int) – The index of the layer this view is for.
partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.
- Returns
An initialized cache view for the specified layer.
- Return type
PagedAttentionCacheView
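Standard jax.sharding primitives can approximate the sharded allocation that init performs. Two choices in this sketch are assumptions: the mesh axis name `"tp"`, and splitting the KV-head axis. In EasyDeL, the PartitionManager decides the actual partition spec. Quantization is not shown.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh with a tensor-parallel axis named "tp". A single device keeps
# the sketch runnable anywhere; a real deployment would span all devices.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("tp",))

num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size = 8, 16, 8, 64
shape = (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size)

# Assumed layout: split KV heads across "tp", keep page/slot/head-dim axes unsplit
sharding = NamedSharding(mesh, P("tp", None, None, None))

# Create the zeros under jit with out_shardings so each device materializes only its shard
alloc_pages = jax.jit(lambda: jnp.zeros(shape, jnp.bfloat16), out_shardings=sharding)
key_pages = alloc_pages()
value_pages = alloc_pages()
print(key_pages.shape)  # (8, 16, 8, 64)
```

Sharding the head axis requires num_kv_heads to be divisible by the size of the `"tp"` mesh axis.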
- layer_index: int
- metadata: PagedAttentionCacheMetaData
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str
Serializes the PyTree object to a JSON string.
- write_decodes_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)
Writes the key/value pairs from a decode step into the appropriate cache pages.
Uses the decodes_position and decodes_page_table from the runtime metadata to calculate the exact page index and offset within that page where the new key/value pair for each sequence in the batch should be written. It reshapes the cache pages and input keys/values for efficient scattered updates using .at[…].set(…).
- Parameters
key (cx.Array) – Key tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).
value (cx.Array) – Value tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).
metadata (PagedAttentionMetadata) – Runtime metadata containing decodes_position and decodes_page_table.
- Returns
Returns self after updating the pages.
- Return type
PagedAttentionCacheView
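The page-index arithmetic described above can be sketched as a single scatter over a flattened (page, slot) axis. `write_decodes` is a hypothetical helper that handles one tensor (key or value), not the library method. It assumes decodes_position is the 0-based position of the token being written and that every page-table entry it reads is valid.

```python
import jax.numpy as jnp


def write_decodes(pages, new, decodes_position, decodes_page_table):
    """Scatter one new token per sequence into paged storage (sketch).

    pages:              (num_kv_heads, num_pages, page_size, head_dim)
    new:                (batch, num_kv_heads, head_dim) key or value for the new token
    decodes_position:   (batch,) 0-based position of the token being written
    decodes_page_table: (batch, pages_per_sequence) logical -> physical page ids
    """
    num_kv_heads, num_pages, page_size, head_dim = pages.shape
    batch = new.shape[0]
    # Logical page and in-page slot for each sequence's current token
    logical_page = decodes_position // page_size
    slot = decodes_position % page_size
    # Look up the physical page for each sequence through its page-table row
    physical_page = decodes_page_table[jnp.arange(batch), logical_page]
    # Flatten (page, slot) into one token axis so a single scatter covers the batch
    flat_index = physical_page * page_size + slot
    flat = pages.reshape(num_kv_heads, num_pages * page_size, head_dim)
    # new is (batch, heads, dim); move heads first to match the page layout
    flat = flat.at[:, flat_index, :].set(new.transpose(1, 0, 2))
    return flat.reshape(num_kv_heads, num_pages, page_size, head_dim)


key_pages = jnp.zeros((2, 6, 4, 3))      # 2 heads, 6 pages of 4 tokens, head_dim 3
key = jnp.arange(12.0).reshape(2, 2, 3)  # batch of 2 sequences
positions = jnp.array([5, 0])            # seq 0 writes token 5, seq 1 writes token 0
page_table = jnp.array([[3, 1], [4, 0]])
key_pages = write_decodes(key_pages, key, positions, page_table)
# seq 0: logical page 1 -> physical page 1, slot 1
# seq 1: logical page 0 -> physical page 4, slot 0
```

The same helper would be applied to value_pages with the value tensor.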
- write_prefill_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)
Writes the key/value pairs from a prefill step into the appropriate cache pages.
Uses the prefill_page_table from the runtime metadata to determine which physical pages (key_pages, value_pages) correspond to the logical pages of the prefill sequence. It transposes and reshapes the input key/value tensors and uses jax.lax.dynamic_update_slice_in_dim within a while_loop to update the relevant pages.
- Parameters
key (cx.Array) – Key tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).
value (cx.Array) – Value tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).
metadata (PagedAttentionMetadata) – Runtime metadata containing the prefill_length and prefill_page_table.
- Returns
Returns self after updating the pages.
- Return type
PagedAttentionCacheView
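The following is a minimal sketch of the prefill write using jax.lax.while_loop and jax.lax.dynamic_update_slice_in_dim. It assumes that the padded prompt length is a multiple of page_size. It also assumes that the loop runs only over pages containing real tokens, that is, ceil(prefill_length / page_size) pages. `write_prefill` is a hypothetical helper that operates on one tensor; it is not the library method.

```python
import jax
import jax.numpy as jnp


def write_prefill(pages, new, prefill_length, prefill_page_table):
    """Copy a prompt's keys (or values) into its physical pages, one page per loop step.

    pages:              (num_kv_heads, num_pages, page_size, head_dim)
    new:                (padded_len, num_kv_heads, head_dim); padded_len assumed to be
                        a multiple of page_size
    prefill_length:     scalar int array, number of real prompt tokens
    prefill_page_table: (num_pages_for_prefill,) logical -> physical page ids
    """
    num_kv_heads, _, page_size, head_dim = pages.shape
    padded_len = new.shape[0]
    # (padded_len, heads, dim) -> (heads, logical_pages, page_size, dim)
    chunks = new.transpose(1, 0, 2).reshape(
        num_kv_heads, padded_len // page_size, page_size, head_dim
    )
    # Only logical pages that contain at least one real token are written
    num_used_pages = (prefill_length + page_size - 1) // page_size

    def cond(state):
        i, _ = state
        return i < num_used_pages

    def body(state):
        i, buf = state
        # Logical page i, keeping a length-1 page axis for the update
        chunk = jax.lax.dynamic_slice_in_dim(chunks, i, 1, axis=1)
        # Write it at the physical page index given by the page table
        buf = jax.lax.dynamic_update_slice_in_dim(buf, chunk, prefill_page_table[i], axis=1)
        return i + 1, buf

    _, pages = jax.lax.while_loop(cond, body, (jnp.int32(0), pages))
    return pages


key_pages = jnp.zeros((2, 6, 4, 3))
prompt_keys = jnp.arange(1.0, 49.0).reshape(8, 2, 3)  # 8 padded tokens = 2 logical pages
key_pages = write_prefill(key_pages, prompt_keys, jnp.array(5), jnp.array([4, 2]))
# logical page 0 -> physical page 4, logical page 1 -> physical page 2
```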
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionMetadata(prefill_length: Array, prefill_position: Array, prefill_page_table: Array, decodes_position: Array, decodes_page_table: Array)
Bases: object
Runtime metadata required for performing a Paged Attention computation step.
This object holds the necessary information for a single forward pass of the paged attention mechanism, distinguishing between prefill and decode steps and providing the mappings (page tables) from logical sequence positions to physical cache pages.
- prefill_length
Scalar JAX array containing the actual length of the prompt being processed in a prefill step. Shape (). Set to 0 if not in prefill.
- Type
Array
- prefill_position
JAX array of positions for the prefill tokens. Shape (padded_prompt_length,). Empty shape () if not in prefill.
- Type
Array
- prefill_page_table
JAX array mapping logical page indices of the prefill sequence to physical page indices in the KV cache. Shape (num_pages_for_prefill,). Empty shape () if not in prefill.
- Type
Array
- decodes_position
JAX array containing the current sequence position (or length - 1) for each sequence in the decode batch. Shape (batch_size,). Empty shape () if not in decode.
- Type
Array
- decodes_page_table
JAX array mapping logical page indices to physical page indices for each sequence in the decode batch. Shape (batch_size, num_pages_per_sequence). Empty shape () if not in decode.
- Type
Array
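The hypothetical allocator below makes the page-table fields concrete. It hands out physical pages from a free list and produces a decodes_page_table-shaped array plus matching decodes_position values. This page does not describe EasyDeL's real page management. The free-list policy and the use of 0 as padding for unused entries are assumptions.

```python
import numpy as np


def build_page_table(seq_lens, page_size, num_pages_per_sequence, free_pages):
    """Assign physical pages from a free list to each sequence (hypothetical allocator).

    Returns a (batch, num_pages_per_sequence) int32 table; unused trailing entries
    are left as 0 padding and are assumed never to be read.
    """
    table = np.zeros((len(seq_lens), num_pages_per_sequence), dtype=np.int32)
    free = list(free_pages)
    for b, length in enumerate(seq_lens):
        needed = (length + page_size - 1) // page_size
        if needed > num_pages_per_sequence or needed > len(free):
            raise ValueError(f"sequence {b} needs {needed} pages, cannot allocate")
        for j in range(needed):
            # Logical page j of sequence b maps to the next free physical page
            table[b, j] = free.pop(0)
    return table


seq_lens = [5, 9]
table = build_page_table(seq_lens, page_size=4, num_pages_per_sequence=3, free_pages=[7, 2, 5, 0, 3])
decodes_position = np.array(seq_lens) - 1  # position of each sequence's latest token
print(table.tolist())  # [[7, 2, 0], [5, 0, 3]]
print(decodes_position.tolist())  # [4, 8]
```

With page_size=4, token 8 of the second sequence falls in logical page 2, slot 0. That is physical page 3 in this table.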
- classmethod from_dict(data: Dict[str, Any]) T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T
Deserializes a JSON string into a PyTree object.
- is_decode_mode() bool
Checks if the current metadata represents a decode-only step.
- Returns
True if only decode information is present (prefill arrays have empty shape), False otherwise.
- Return type
bool
- is_prefill_mode() bool
Checks if the current metadata represents a prefill-only step.
- Returns
True if only prefill information is present (decode arrays have empty shape), False otherwise.
- Return type
bool
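JAX array shapes are static, so mode checks like these can be plain Python shape comparisons that also work under jax.jit. The sketch uses a hypothetical `ToyRuntimeMetadata` NamedTuple in place of PagedAttentionMetadata. Its is_decode_mode is assumed to mirror is_prefill_mode.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ToyRuntimeMetadata(NamedTuple):
    """Hypothetical stand-in for PagedAttentionMetadata, used only to illustrate mode checks."""

    prefill_length: jnp.ndarray
    prefill_position: jnp.ndarray
    prefill_page_table: jnp.ndarray
    decodes_position: jnp.ndarray
    decodes_page_table: jnp.ndarray

    def is_prefill_mode(self) -> bool:
        # Prefill-only when both decode arrays are shape-() placeholders
        return self.decodes_position.shape == () and self.decodes_page_table.shape == ()

    def is_decode_mode(self) -> bool:
        # Assumed mirror image: decode-only when the prefill position/page arrays are placeholders
        return self.prefill_position.shape == () and self.prefill_page_table.shape == ()


empty = jnp.zeros((), dtype=jnp.int32)  # shape () placeholder
prefill = ToyRuntimeMetadata(jnp.array(5), jnp.arange(8), jnp.array([4, 2]), empty, empty)
decode = ToyRuntimeMetadata(empty, empty, empty, jnp.array([4, 8]), jnp.array([[7, 2, 0], [5, 0, 3]]))
print(prefill.is_prefill_mode(), prefill.is_decode_mode())  # True False
print(decode.is_prefill_mode(), decode.is_decode_mode())  # False True
```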
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str
Serializes the PyTree object to a JSON string.