easydel.layers.caching.ragged_page.cache
- class easydel.layers.caching.ragged_page.cache.RaggedPagesCache(views: list[easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView])
Bases: BaseCache
Represents the complete Paged Attention KV cache for all layers of a model.
It holds a list of RaggedPagesCacheView objects, one for each layer. It inherits from BaseCache.
- views
A list containing the cache view for each layer in the model.
- Type
tp.List[RaggedPagesCacheView]
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(mesh: Mesh, metadata: RaggedPagesCacheMetaData, partition_manager: PartitionManager, quantizer: object | None = None) → RaggedPagesCache
Initializes the entire RaggedPagesCache for all layers.
Creates a list of RaggedPagesCacheView instances, one for each layer specified in the metadata, by calling RaggedPagesCacheView.init for each layer.
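The one-view-per-layer construction can be pictured with a minimal, self-contained sketch. `ToyMeta`, `ToyView`, and `toy_init_cache` are hypothetical stand-ins, not EasyDeL classes; the simplified page layout `(num_pages, page_size, 2 * num_kv_heads, head_dim)` is an assumption, and sharding and quantization are omitted.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyMeta:
    # Hypothetical stand-in for RaggedPagesCacheMetaData (only the fields used here).
    num_hidden_layers: int
    num_pages: int
    page_size: int
    num_kv_heads: int
    head_dim: int


@dataclass
class ToyView:
    # Hypothetical stand-in for RaggedPagesCacheView.
    layer_index: int
    kv_pages: np.ndarray

    @classmethod
    def init(cls, meta: ToyMeta, layer_index: int) -> "ToyView":
        # Assumed layout: keys and values share one combined head axis (2 * num_kv_heads).
        shape = (meta.num_pages, meta.page_size, 2 * meta.num_kv_heads, meta.head_dim)
        return cls(layer_index=layer_index, kv_pages=np.zeros(shape, dtype=np.float32))


def toy_init_cache(meta: ToyMeta) -> list[ToyView]:
    # One view per layer, each with its own freshly allocated page buffer.
    return [ToyView.init(meta, i) for i in range(meta.num_hidden_layers)]


views = toy_init_cache(ToyMeta(num_hidden_layers=4, num_pages=16, page_size=8, num_kv_heads=2, head_dim=64))
print(len(views), views[0].kv_pages.shape)  # 4 (16, 8, 4, 64)
```

Because every layer owns a separate buffer, updating one layer's view never touches another layer's pages.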
- Parameters
mesh (Mesh) – The JAX device mesh.
metadata (RaggedPagesCacheMetaData) – Static configuration for the cache.
partition_manager (es.PartitionManager) – Manages tensor sharding.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.
- Returns
An initialized cache object containing views for all layers.
- Return type
RaggedPagesCache
- property metadata: easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData | None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData(num_hidden_layers: int, max_model_length: int, num_kv_heads: int, k_headdim: int, v_headdim: int, hbm_utilization: float = 0.9, page_size: int = 128, num_pages: int = -1, max_num_pages_per_req: int = -1, num_slices_per_kv_cache_update_page: int = -1, max_num_tokens: int = -1, max_num_reqs: int = -1, version: Union[str, Literal['v3', 'v2']] = 'v3', _kvdtype_str: str = 'bf16')
Bases: BaseCacheMetadata
Metadata holding configuration parameters for the Paged Attention KV cache.
This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.
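To illustrate how a utilization hint such as hbm_utilization could turn into a page count, here is a back-of-the-envelope sketch. The formula (usable bytes split evenly across layers, then divided by the bytes per page) is an assumption for illustration only; EasyDeL's actual derivation of num_pages may differ, and `estimate_num_pages` is a hypothetical helper.

```python
def estimate_num_pages(
    hbm_bytes_per_device: int,
    hbm_utilization: float,
    num_hidden_layers: int,
    page_size_bytes: int,
) -> int:
    """Hypothetical estimate of how many pages per layer fit in the HBM budget."""
    if not 0.0 < hbm_utilization <= 1.0:
        raise ValueError("hbm_utilization must be in (0, 1]")
    usable = int(hbm_bytes_per_device * hbm_utilization)
    # Every layer gets the same number of pages, so split the budget across layers.
    return usable // (num_hidden_layers * page_size_bytes)


# 16 GiB device, 90% utilization, 32 layers, 512 KiB pages.
print(estimate_num_pages(16 * 1024**3, 0.9, 32, 512 * 1024))  # 921
```

A real implementation would also have to subtract memory already taken by weights and activations before applying the ratio; the sketch ignores that.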
- classmethod create(mesh: Mesh, partition_manager: PartitionManager, kvdtype: dtype, num_hidden_layers: int, num_kv_heads: int, max_model_length: int, kv_head_dim_size: int | None = None, k_headdim: int | None = None, v_headdim: int | None = None, hbm_utilization: float = 0.9, page_size: int = 128, version: Literal['v3', 'v2'] = 'v3') → RaggedPagesCacheMetaData
Factory method to create and validate a metadata instance.
This method serves as the primary constructor for metadata objects, providing a centralized location for parameter validation and initialization logic. It should be used instead of direct instantiation to ensure all metadata objects are properly validated.
The factory pattern allows for:
- Complex initialization logic beyond simple assignment
- Parameter validation before object creation
- Derived parameter calculation
- Consistent error handling across implementations
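A minimal sketch of the validate-then-derive factory pattern listed above. `ToyPagedMeta` is hypothetical, and deriving max_num_pages_per_req as ceil(max_model_length / page_size) is an assumption about how such a field would naturally be computed, not a statement of EasyDeL's logic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyPagedMeta:
    # Hypothetical immutable metadata object.
    num_hidden_layers: int
    max_model_length: int
    num_kv_heads: int
    page_size: int
    max_num_pages_per_req: int

    @classmethod
    def create(cls, num_hidden_layers: int, max_model_length: int, num_kv_heads: int, page_size: int = 128) -> "ToyPagedMeta":
        # Validate every input before the frozen object is constructed.
        checks = {
            "num_hidden_layers": num_hidden_layers,
            "max_model_length": max_model_length,
            "num_kv_heads": num_kv_heads,
            "page_size": page_size,
        }
        for name, value in checks.items():
            if not isinstance(value, int):
                raise TypeError(f"{name} must be an int, got {type(value).__name__}")
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        # Derived parameter (assumed): pages needed for one full-length request, by ceiling division.
        max_pages = -(-max_model_length // page_size)
        return cls(num_hidden_layers, max_model_length, num_kv_heads, page_size, max_pages)


meta = ToyPagedMeta.create(num_hidden_layers=12, max_model_length=1000, num_kv_heads=8)
print(meta.max_num_pages_per_req)  # 8
```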
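- Parameters
mesh (Mesh) – The JAX device mesh.
partition_manager (PartitionManager) – Manages tensor sharding.
kvdtype (jnp.dtype) – The data type for the cache pages (e.g., jnp.bfloat16).
num_hidden_layers (int) – Number of transformer layers; one cache view is created per layer.
num_kv_heads (int) – Number of key/value heads.
max_model_length (int) – Maximum sequence length the cache must support.
kv_head_dim_size (int | None) – Head dimension for keys and values.
k_headdim (int | None) – Head dimension for keys.
v_headdim (int | None) – Head dimension for values.
hbm_utilization (float) – Target fraction of device HBM to use for the cache. Defaults to 0.9.
page_size (int) – Number of token slots per page. Defaults to 128.
version (Literal['v3', 'v2']) – Cache layout version. Defaults to 'v3'.
- Returns
A validated instance of the concrete metadata implementation. The returned object is immutable and ready for use in cache initialization.
- Return type
RaggedPagesCacheMetaData
- Raises
ValueError – If any validation check fails. Common validations include:
- Positive integer checks for dimensions and sizes
- Range checks for ratios and percentages
- Consistency checks between related parameters
- Resource availability checks
TypeError – If required parameters are missing or have incorrect types.
Example
>>> metadata = RaggedPagesCacheMetaData.create(
...     mesh=mesh,
...     partition_manager=partition_manager,
...     kvdtype=jnp.bfloat16,
...     num_hidden_layers=32,
...     num_kv_heads=8,
...     max_model_length=4096,
...     kv_head_dim_size=128,
... )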
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- hbm_utilization: float = 0.9
- property is_v2
- property is_v3
- k_headdim: int
- property kv_head_packing: int
- max_model_length: int
- max_num_pages_per_req: int = -1
- max_num_reqs: int = -1
- max_num_tokens: int = -1
- num_kv_heads: int
- num_pages: int = -1
- num_slices_per_kv_cache_update_page: int = -1
- page_size: int = 128
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- property storage_head_dim: int
- property storage_num_combined_kv_heads: int
- property storage_num_kv_groups: int
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- v_headdim: int
- version: Union[str, Literal['v3', 'v2']] = 'v3'
- class easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView(metadata: ~easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData, layer_index: int, kv_pages: jaxtyping.Float[Array, 'num_pages page_size storage_groups packing head_dim'] | jaxtyping.Float[Array, 'num_pages page_size kv_head_combined head_dim'] | eformer.jaximus._imus.ImplicitArray, partition_manager: ~eformer.escale.partition.manager.PartitionManager = <factory>)
Bases: BaseCacheView
Represents the view of the Paged Attention KV cache for a single transformer layer.
It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.
- metadata
The static configuration metadata for the entire paged cache.
- layer_index
The index of the transformer layer this view corresponds to.
- Type
int
- kv_pages
The tensor holding all key/value pages for this layer. Shape: (num_pages, page_size, aligned_kv_groups, packing, aligned_head_dim); the type annotation also allows a combined layout of shape (num_pages, page_size, kv_head_combined, head_dim). Can be a JAX array or an ImplicitArray if quantization is used.
- Type
tp.Union[cx.Array, ImplicitArray]
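The aligned shape can be sketched as follows. This assumes the packed layout used by Pallas ragged paged attention kernels on TPU: keys and values share 2 * num_kv_heads combined slots, sub-32-bit dtypes pack several heads into one 32-bit word (packing = 32 // dtype bits), and head_dim is padded to a multiple of 128. These alignment rules and the helper `storage_shape` are assumptions, not taken from EasyDeL's source.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers.
    return -(-a // b)


def storage_shape(num_pages: int, page_size: int, num_kv_heads: int, head_dim: int, dtype_bits: int) -> tuple[int, int, int, int, int]:
    """Hypothetical shape (num_pages, page_size, groups, packing, aligned_head_dim)."""
    packing = 32 // dtype_bits  # assumed: bf16 -> 2, int8/fp8 -> 4, f32 -> 1 (dtype_bits <= 32)
    combined_heads = 2 * num_kv_heads  # keys and values share one head axis
    groups = cdiv(combined_heads, packing)  # round up when heads don't fill the last group
    aligned_head_dim = cdiv(head_dim, 128) * 128  # pad to the assumed 128-lane boundary
    return (num_pages, page_size, groups, packing, aligned_head_dim)


print(storage_shape(1024, 128, 8, 128, dtype_bits=16))  # (1024, 128, 8, 2, 128)
print(storage_shape(1024, 128, 1, 64, dtype_bits=8))    # (1024, 128, 1, 4, 128)
```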
- concatenate_to_cache(key: Float[Array, 'batch seq_len num_key_heads head_dim'], value: Float[Array, 'batch seq_len num_value_heads head_dim'], cache_metadata: RaggedPagesMetadata) → RaggedPagesCacheView
Update the cache with newly computed key/value states.
This is the primary method for cache updates during inference. It takes newly computed key and value states, writes them into the correct pages of this layer's cache, and returns the updated view.
The update process typically:
1. Validates input shapes and dtypes
2. Determines the update position in the cache
3. Applies quantization if configured
4. Updates cache tensors functionally
5. Adjusts masks and indices
6. Returns the updated state for the next computation
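Steps 2 and 4 can be illustrated with the usual paged-attention addressing scheme: a token at logical position pos of request r lands in physical page pages_tables[r, pos // page_size] at offset pos % page_size. The sketch below is a NumPy stand-in with a simplified (num_pages, page_size, heads, head_dim) layout, and it copies the buffer before writing to mimic a functional update. `write_tokens` is hypothetical; the real method relies on a kernel whose exact layout may differ.

```python
import numpy as np


def write_tokens(
    kv_pages: np.ndarray,
    pages_table: np.ndarray,
    request_ids: np.ndarray,
    positions: np.ndarray,
    new_kv: np.ndarray,
) -> np.ndarray:
    """Hypothetical functional write of one (heads, head_dim) slice per token."""
    page_size = kv_pages.shape[1]
    # Map each token's logical position to (physical page, offset within that page).
    physical_pages = pages_table[request_ids, positions // page_size]
    offsets = positions % page_size
    updated = kv_pages.copy()  # never mutate the input, mirroring JAX's functional style
    updated[physical_pages, offsets] = new_kv
    return updated


pages = np.zeros((6, 4, 2, 3), dtype=np.float32)  # 6 pages of 4 slots, 2 heads, head_dim 3
table = np.array([[5, 2], [0, 1]])                # request 0 owns pages 5, 2; request 1 owns 0, 1
req = np.array([0, 0, 1])
pos = np.array([3, 4, 0])                          # request 0 crosses into its second page
vals = np.arange(3 * 2 * 3, dtype=np.float32).reshape(3, 2, 3)
out = write_tokens(pages, table, req, pos, vals)
# Token 0 -> page 5, slot 3; token 1 -> page 2, slot 0; token 2 -> page 0, slot 0.
```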
- Parameters
key (Float[Array, 'batch seq_len num_key_heads head_dim']) – New key states to write.
value (Float[Array, 'batch seq_len num_value_heads head_dim']) – New value states to write.
cache_metadata (RaggedPagesMetadata) – Runtime metadata (page tables, context lengths, query start locations, slot mapping) that determines where each token is written.
- Returns
The updated cache view.
- Return type
RaggedPagesCacheView
Note
This method should be functional, returning new tensors rather than modifying existing ones in-place. This ensures compatibility with JAX’s functional programming model.
Example
>>> new_view = view.concatenate_to_cache(
...     key=key_states,
...     value=value_states,
...     cache_metadata=cache_metadata,
... )
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- classmethod init(mesh: Mesh, metadata: RaggedPagesCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: object | None = None) → RaggedPagesCacheView
Initializes the RaggedPagesCacheView for a specific layer.
Allocates the kv_pages tensors with the appropriate shape, dtype, and sharding based on the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.
- Parameters
mesh (Mesh) – The JAX device mesh.
metadata (RaggedPagesCacheMetaData) – Static configuration for the cache.
layer_index (int) – The index of the layer this view is for.
partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.
- Returns
An initialized cache view for the specified layer.
- Return type
RaggedPagesCacheView
- property key_pages: Float[Array, 'num_pages page_size num_kv_heads head_dim']
- kv_pages: jaxtyping.Float[Array, 'num_pages page_size storage_groups packing head_dim'] | jaxtyping.Float[Array, 'num_pages page_size kv_head_combined head_dim'] | eformer.jaximus._imus.ImplicitArray
- layer_index: int
- metadata: RaggedPagesCacheMetaData
- partition_manager: PartitionManager
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- property value_pages: Float[Array, 'num_pages page_size num_kv_heads head_dim']
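For the combined layout (kv_head_combined axis), key_pages and value_pages can be understood as strided views over kv_pages. The sketch assumes keys occupy even indices and values odd indices on the combined head axis, the interleaving used by vLLM's TPU ragged paged attention cache; EasyDeL's properties may use a different ordering, and `split_key_value` is hypothetical.

```python
import numpy as np


def split_key_value(kv_pages: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical split of (num_pages, page_size, 2 * num_kv_heads, head_dim)."""
    if kv_pages.shape[2] % 2:
        raise ValueError("combined head axis must have even length")
    key_pages = kv_pages[:, :, 0::2, :]    # assumed: even slots hold keys
    value_pages = kv_pages[:, :, 1::2, :]  # assumed: odd slots hold values
    return key_pages, value_pages


kv = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # 2 pages, 3 slots, 2 KV heads combined into 4, head_dim 5
k, v = split_key_value(kv)
print(k.shape, v.shape)  # (2, 3, 2, 5) (2, 3, 2, 5)
```

Because these are basic slices, no data is copied. Under the packed (storage_groups, packing) layout, those two axes would first have to be merged back into one combined head axis.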
- class easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata(pages_tables: "Int[Array, 'max_num_reqs max_pages']", context_lens: "Int[Array, 'max_num_reqs']", query_start_loc: "Int[Array, 'max_num_reqs_plus_1']", num_seqs: "Int[Array, 'max_num_reqs']", slot_mapping: "Int[Array, 'num_tokens'] | None" = None, position_ids: "Int[Array, 'num_tokens'] | None" = None, request_distribution: "Int[Array, '3'] | None" = None, num_kv_update_slices: "Int[Array, '1'] | None" = None, version: "str | tp.Literal['v3', 'v2']" = 'v3', num_slices_per_kv_cache_update_page: 'int | None' = <factory>, page_size: 'int' = 128, prefill_chunk_size: 'int' = 512)
Bases: object
- context_lens: Int[Array, 'max_num_reqs']
- classmethod create_empty(num_tokens: int, max_num_reqs: int, max_pages: int, page_size: int = 128, version: Literal['v3', 'v2'] = 'v3') → RaggedPagesMetadata
Create empty metadata with proper shapes.
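The shapes create_empty has to produce can be read off the field annotations of RaggedPagesMetadata. The sketch below builds zero-filled NumPy arrays with those shapes; `empty_ragged_metadata` is hypothetical, the int32 dtype is an assumption, and only the main array fields are covered (num_seqs follows its annotation exactly as written in this reference).

```python
import numpy as np


def empty_ragged_metadata(num_tokens: int, max_num_reqs: int, max_pages: int) -> dict[str, np.ndarray]:
    """Hypothetical zero-filled arrays shaped like the RaggedPagesMetadata annotations."""
    i32 = np.int32  # assumed integer dtype
    return {
        "pages_tables": np.zeros((max_num_reqs, max_pages), dtype=i32),
        "context_lens": np.zeros((max_num_reqs,), dtype=i32),
        "query_start_loc": np.zeros((max_num_reqs + 1,), dtype=i32),  # one extra entry for the end offset
        "num_seqs": np.zeros((max_num_reqs,), dtype=i32),  # shape as annotated in this reference
        "slot_mapping": np.zeros((num_tokens,), dtype=i32),
        "position_ids": np.zeros((num_tokens,), dtype=i32),
    }


meta = empty_ragged_metadata(num_tokens=256, max_num_reqs=8, max_pages=32)
for name, arr in meta.items():
    print(name, arr.shape)
```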
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- num_seqs: Int[Array, 'max_num_reqs']
- page_size: int = 128
- pages_tables: Int[Array, 'max_num_reqs max_pages']
- prefill_chunk_size: int = 512
- query_start_loc: Int[Array, 'max_num_reqs_plus_1']
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- version: Union[str, Literal['v3', 'v2']] = 'v3'
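A worked example of filling these fields for one ragged batch. It assumes the common ragged-batch conventions suggested by the field names and shapes: query_start_loc is an exclusive prefix sum of per-request query lengths (hence max_num_reqs + 1 entries), context_lens counts cached plus new tokens, and each request needs ceil(context_len / page_size) pages. The helper `build_ragged_batch` and its sequential page allocator are hypothetical.

```python
import numpy as np


def build_ragged_batch(query_lens: list[int], cached_lens: list[int], max_num_reqs: int, max_pages: int, page_size: int = 128):
    """Hypothetical construction of query_start_loc, context_lens and pages_tables."""
    n = len(query_lens)
    if n > max_num_reqs:
        raise ValueError("too many requests for max_num_reqs")
    query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
    query_start_loc[1 : n + 1] = np.cumsum(query_lens)
    # Assumed padding convention: unused slots repeat the final offset (zero-length queries).
    query_start_loc[n + 1 :] = query_start_loc[n]
    context_lens = np.zeros(max_num_reqs, dtype=np.int32)
    context_lens[:n] = np.asarray(cached_lens) + np.asarray(query_lens)
    pages_tables = np.zeros((max_num_reqs, max_pages), dtype=np.int32)
    next_free_page = 0  # naive allocator: hand out physical pages in order
    for r in range(n):
        needed = -(-int(context_lens[r]) // page_size)
        if needed > max_pages:
            raise ValueError(f"request {r} needs {needed} pages > max_pages")
        pages_tables[r, :needed] = np.arange(next_free_page, next_free_page + needed)
        next_free_page += needed
    return query_start_loc, context_lens, pages_tables


# Request 0: prefill of 300 tokens; request 1: one decode step after 130 cached tokens.
qsl, ctx, tables = build_ragged_batch([300, 1], [0, 130], max_num_reqs=4, max_pages=4)
print(qsl.tolist())     # [0, 300, 301, 301, 301]
print(ctx.tolist())     # [300, 131, 0, 0]
print(tables.tolist())  # [[0, 1, 2, 0], [3, 4, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
```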
- easydel.layers.caching.ragged_page.cache.get_num_slices_per_kv_cache_update_page(page_size_bytes: int) → int
- easydel.layers.caching.ragged_page.cache.get_page_size_bytes(page_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: dtype) → int
Returns the size in bytes of one page of the KV cache.
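One plausible way to compute this size, following the padding convention of the similarly named helper in vLLM's TPU backend (an assumption here, not confirmed for EasyDeL): count keys and values together (2 * num_kv_heads), round that up to the dtype packing factor, pad head_size to a multiple of 128, then multiply by the dtype's bit width. `page_size_bytes_sketch` is hypothetical and takes the bit width as an integer to stay dependency-free.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers.
    return -(-a // b)


def page_size_bytes_sketch(page_size: int, num_kv_heads: int, head_size: int, dtype_bits: int) -> int:
    """Hypothetical bytes for one page holding both keys and values (dtype_bits in {8, 16, 32})."""
    padded_head_size = cdiv(head_size, 128) * 128  # assumed 128-lane alignment
    packing = 32 // dtype_bits  # sub-32-bit dtypes pack several heads per 32-bit word
    combined_heads = cdiv(2 * num_kv_heads, packing) * packing  # K and V, rounded up to packing
    return page_size * combined_heads * padded_head_size * dtype_bits // 8


print(page_size_bytes_sketch(128, 8, 128, dtype_bits=16))  # 524288 (512 KiB)
print(page_size_bytes_sketch(16, 1, 64, dtype_bits=8))     # 8192
```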