easydel.layers.caching.ragged_page.cache

class easydel.layers.caching.ragged_page.cache.RaggedPagesCache(views: list[easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView])[source]#

Bases: BaseCache

Represents the complete Paged Attention KV cache for all layers of a model.

It holds a list of RaggedPagesCacheView objects, one for each layer. It inherits from BaseCache.

views#

A list containing the cache view for each layer in the model.

Type

tp.List[RaggedPagesCacheView]

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_cache(mesh: Mesh, metadata: RaggedPagesCacheMetaData, partition_manager: PartitionManager, quantizer: object | None = None) RaggedPagesCache[source]#

Initializes the entire RaggedPagesCache for all layers.

Creates a list of RaggedPagesCacheView instances, one for each layer specified in the metadata, by calling RaggedPagesCacheView.init for each layer.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • metadata (RaggedPagesCacheMetaData) – Static configuration for the cache, including the page data type (see its kvdtype property); there is no separate dtype argument.

  • partition_manager (es.PartitionManager) – Manages tensor sharding.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.

Returns

An initialized cache object containing views for all layers.

Return type

RaggedPagesCache

init_empty(*args, **kwargs) None[source]#

Not typically used for RaggedPagesCache; returns None.

property metadata: easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData | None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.
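
As a rough illustration of the functional-update pattern behind replace, the sketch below uses the standard library's dataclasses.replace on a frozen dataclass. ToyView is a hypothetical stand-in, not a library class, and the assumption that the library's replace behaves analogously (returns a new object, leaves the original untouched) is ours.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyView:
    """Hypothetical stand-in for an immutable cache object."""

    layer_index: int
    num_pages: int


view = ToyView(layer_index=0, num_pages=16)

# replace() builds a new instance; the original object is not modified.
bigger = replace(view, num_pages=32)
```

Because nothing is mutated in place, the old and new objects can both be passed through JAX transformations safely.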

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

views: list[easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView]#
class easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData(num_hidden_layers: int, max_model_length: int, num_kv_heads: int, k_headdim: int, v_headdim: int, hbm_utilization: float = 0.9, page_size: int = 128, num_pages: int = -1, max_num_pages_per_req: int = -1, num_slices_per_kv_cache_update_page: int = -1, max_num_tokens: int = -1, max_num_reqs: int = -1, version: Union[str, Literal['v3', 'v2']] = 'v3', _kvdtype_str: str = 'bf16')[source]#

Bases: BaseCacheMetadata

Metadata holding configuration parameters for the Paged Attention KV cache.

This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.
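
To make the relationship between page_size and max_model_length concrete, here is a minimal sketch of the page arithmetic. It assumes a request of n tokens occupies ceil(n / page_size) pages; pages_needed is a hypothetical helper, and how the library actually derives max_num_pages_per_req is not stated on this page.

```python
def pages_needed(num_tokens: int, page_size: int = 128) -> int:
    """Number of fixed-size pages required to hold num_tokens tokens."""
    if num_tokens < 0 or page_size <= 0:
        raise ValueError("num_tokens must be >= 0 and page_size must be > 0")
    # Ceiling division using integer arithmetic only.
    return -(-num_tokens // page_size)


# A request at an illustrative maximum length of 4096 tokens.
max_pages_for_request = pages_needed(4096, page_size=128)
```

A request one token longer than a page boundary needs a whole extra page, which is why smaller pages waste less memory per request but need larger page tables.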

classmethod create(mesh: Mesh, partition_manager: PartitionManager, kvdtype: dtype, num_hidden_layers: int, num_kv_heads: int, max_model_length: int, kv_head_dim_size: int | None = None, k_headdim: int | None = None, v_headdim: int | None = None, hbm_utilization: float = 0.9, page_size: int = 128, version: Literal['v3', 'v2'] = 'v3') RaggedPagesCacheMetaData[source]#

Factory method to create and validate a metadata instance.

This method serves as the primary constructor for metadata objects, providing a centralized location for parameter validation and initialization logic. It should be used instead of direct instantiation to ensure all metadata objects are properly validated.

The factory pattern allows for:
  • Complex initialization logic beyond simple assignment

  • Parameter validation before object creation

  • Derived parameter calculation

  • Consistent error handling across implementations

Parameters
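  • mesh (Mesh) – The JAX device mesh.

  • partition_manager (PartitionManager) – Manages tensor sharding.

  • kvdtype (dtype) – The data type for the cache pages.

  • num_hidden_layers (int) – Number of layers that need a cache view.

  • num_kv_heads (int) – Number of key/value heads.

  • max_model_length (int) – Maximum sequence length the cache must support.

  • kv_head_dim_size (int | None) – Head dimension; presumably a shared value used when k_headdim and v_headdim are not given separately.

  • k_headdim (int | None) – Head dimension of the keys.

  • v_headdim (int | None) – Head dimension of the values.

  • hbm_utilization (float) – Fraction of device HBM the cache may use. Defaults to 0.9.

  • page_size (int) – Number of token slots per page. Defaults to 128.

  • version (Literal['v3', 'v2']) – Cache layout version. Defaults to 'v3'.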

Returns

A validated RaggedPagesCacheMetaData instance. The returned object is immutable and ready for use in cache initialization.

Return type

RaggedPagesCacheMetaData

Raises
  • ValueError – If any validation checks fail. Common validations include:
      - Positive integer checks for dimensions and sizes
      - Range checks for ratios and percentages
      - Consistency checks between related parameters
      - Resource availability checks

  • TypeError – If required parameters are missing or have incorrect types.

Example

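>>> # mesh and partition_manager are assumed to have been constructed already
>>> metadata = RaggedPagesCacheMetaData.create(
...     mesh=mesh,
...     partition_manager=partition_manager,
...     kvdtype=jnp.bfloat16,
...     num_hidden_layers=12,
...     num_kv_heads=8,
...     max_model_length=1024,
...     kv_head_dim_size=64,
... )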
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

get_max_num_seqs() int[source]#
get_padded_num_slices(num_tokens: int | None = None, max_num_reqs: int | None = None) int[source]#
get_shape_and_axes()[source]#
hbm_utilization: float = 0.9#
property is_v2#
property is_v3#
k_headdim: int#
property kv_head_packing: int#
property kvdtype: dtype#
max_model_length: int#
max_num_pages_per_req: int = -1#
max_num_reqs: int = -1#
max_num_tokens: int = -1#
num_hidden_layers: int#
num_kv_heads: int#
num_pages: int = -1#
num_slices_per_kv_cache_update_page: int = -1#
page_size: int = 128#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

property storage_head_dim: int#
property storage_num_combined_kv_heads: int#
property storage_num_kv_groups: int#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

v_headdim: int#
version: Union[str, Literal['v3', 'v2']] = 'v3'#
class easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView(metadata: ~easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData, layer_index: int, kv_pages: jaxtyping.Float[Array, 'num_pages page_size storage_groups packing head_dim'] | jaxtyping.Float[Array, 'num_pages page_size kv_head_combined head_dim'] | eformer.jaximus._imus.ImplicitArray, partition_manager: ~eformer.escale.partition.manager.PartitionManager = <factory>)[source]#

Bases: BaseCacheView

Represents the view of the Paged Attention KV cache for a single transformer layer.

It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.
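
The idea of writing into "the correct pages" can be sketched with the usual paged-attention address calculation. This is a minimal illustration, not the library's kernel: token_slot is a hypothetical helper, and it assumes a per-request page table that maps logical page numbers to physical page ids, with slots numbered page by page.

```python
def token_slot(pages_table_row: list[int], position: int, page_size: int = 128) -> int:
    """Map a token position within one request to a flat slot index.

    pages_table_row[i] is assumed to hold the physical page id that backs
    the request's i-th logical page.
    """
    logical_page = position // page_size   # which page of the request
    offset = position % page_size          # where inside that page
    physical_page = pages_table_row[logical_page]
    return physical_page * page_size + offset


# A request whose three logical pages live in physical pages 7, 2 and 9.
row = [7, 2, 9]
first = token_slot(row, 0, page_size=4)    # logical page 0, offset 0
middle = token_slot(row, 5, page_size=4)   # logical page 1, offset 1
last = token_slot(row, 11, page_size=4)    # logical page 2, offset 3
```

Physical pages need not be contiguous or ordered, which is what lets requests of different lengths share one pool of pages.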

metadata#

The static configuration metadata for the entire paged cache.

Type

RaggedPagesCacheMetaData

layer_index#

The index of the transformer layer this view corresponds to.

Type

int

kv_pages#

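The tensor holding all key/value pages for this layer. Shape: (num_pages, page_size, aligned_kv_groups, packing, aligned_head_dim); the type annotation also allows a four-dimensional layout (num_pages, page_size, kv_head_combined, head_dim). Can be a JAX array, or an ImplicitArray if quantization is used.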

Type

tp.Union[cx.Array, ImplicitArray]

concatenate_to_cache(key: Float[Array, 'batch seq_len num_key_heads head_dim'], value: Float[Array, 'batch seq_len num_value_heads head_dim'], cache_metadata: RaggedPagesMetadata) RaggedPagesCacheView[source]#

Update the cache with new computed states.

This is the primary method for cache updates during inference. It takes newly computed states (keys, values, hidden states, etc.) and incorporates them into the cache, returning updated tensors and any additional information needed for computation.

The update process typically:
  1. Validates input shapes and dtypes

  2. Determines update position in cache

  3. Applies quantization if configured

  4. Updates cache tensors functionally

  5. Adjusts masks and indices

  6. Returns updated state for next computation

Parameters
  • key (Float[Array, 'batch seq_len num_key_heads head_dim']) – New key states.

  • value (Float[Array, 'batch seq_len num_value_heads head_dim']) – New value states.

  • cache_metadata (RaggedPagesMetadata) – Runtime metadata that determines which pages the new key/value pairs are written to.

Returns

The updated view for this layer. Per the signature, the paged implementation returns only the view, not separate key/value tensors or attention masks as some other cache types do.

Return type

RaggedPagesCacheView

Note

This method should be functional, returning new tensors rather than modifying existing ones in-place. This ensures compatibility with JAX’s functional programming model.

Example

>>> new_view = view.concatenate_to_cache(
...     key=key_states,
...     value=value_states,
...     cache_metadata=cache_metadata,
... )
flattened_kv_pages() Float[Array, 'num_pages page_size num_kv_heads_x2 head_dim'][source]#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init(mesh: Mesh, metadata: RaggedPagesCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: object | None = None) RaggedPagesCacheView[source]#

Initializes the RaggedPagesCacheView for a specific layer.

Allocates the kv_pages tensor with the shape, dtype, and sharding derived from the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • metadata (RaggedPagesCacheMetaData) – Static configuration for the cache, including the page data type (e.g., jnp.bfloat16, see its kvdtype property); there is no separate dtype argument.

  • layer_index (int) – The index of the layer this view is for.

  • partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.

Returns

An initialized cache view for the specified layer.

Return type

RaggedPagesCacheView

property key_pages: Float[Array, 'num_pages page_size num_kv_heads head_dim']#
kv_pages: jaxtyping.Float[Array, 'num_pages page_size storage_groups packing head_dim'] | jaxtyping.Float[Array, 'num_pages page_size kv_head_combined head_dim'] | eformer.jaximus._imus.ImplicitArray#
layer_index: int#
metadata: RaggedPagesCacheMetaData#
partition_manager: PartitionManager#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

property value_pages: Float[Array, 'num_pages page_size num_kv_heads head_dim']#
class easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata(pages_tables: "Int[Array, 'max_num_reqs max_pages']", context_lens: "Int[Array, 'max_num_reqs']", query_start_loc: "Int[Array, 'max_num_reqs_plus_1']", num_seqs: "Int[Array, 'max_num_reqs']", slot_mapping: "Int[Array, 'num_tokens'] | None" = None, position_ids: "Int[Array, 'num_tokens'] | None" = None, request_distribution: "Int[Array, '3'] | None" = None, num_kv_update_slices: "Int[Array, '1'] | None" = None, version: "str | tp.Literal['v3', 'v2']" = 'v3', num_slices_per_kv_cache_update_page: 'int | None' = <factory>, page_size: 'int' = 128, prefill_chunk_size: 'int' = 512)[source]#

Bases: object

context_lens: Int[Array, 'max_num_reqs']#
classmethod create_empty(num_tokens: int, max_num_reqs: int, max_pages: int, page_size: int = 128, version: Literal['v3', 'v2'] = 'v3') RaggedPagesMetadata[source]#

Create empty metadata with proper shapes.
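
The fields of this class have fixed shapes (for example, query_start_loc always has max_num_reqs + 1 entries) regardless of how many requests are active. A minimal sketch of how such a ragged offset array could be filled is shown below; build_query_start_loc is a hypothetical helper, and padding the unused tail with the final offset is an assumption, not documented behavior.

```python
from itertools import accumulate


def build_query_start_loc(query_lens: list[int], max_num_reqs: int) -> list[int]:
    """Cumulative start offsets of each request's tokens in a flat token buffer."""
    if len(query_lens) > max_num_reqs:
        raise ValueError("more active requests than max_num_reqs")
    # Entry i is where request i starts; the last real entry is the total token count.
    starts = [0, *accumulate(query_lens)]
    # Pad so the result always has max_num_reqs + 1 entries (assumed convention).
    starts += [starts[-1]] * (max_num_reqs + 1 - len(starts))
    return starts


# Three active requests with 3, 1 and 2 new tokens; capacity for four requests.
query_start_loc = build_query_start_loc([3, 1, 2], max_num_reqs=4)
```

Keeping the shape fixed avoids recompilation under jit when the number of active requests changes between steps.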

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

num_kv_update_slices: jaxtyping.Int[Array, '1'] | None = None#
num_seqs: Int[Array, 'max_num_reqs']#
num_slices_per_kv_cache_update_page: int | None#
page_size: int = 128#
pages_tables: Int[Array, 'max_num_reqs max_pages']#
position_ids: jaxtyping.Int[Array, 'num_tokens'] | None = None#
prefill_chunk_size: int = 512#
query_start_loc: Int[Array, 'max_num_reqs_plus_1']#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

request_distribution: jaxtyping.Int[Array, '3'] | None = None#
slot_mapping: jaxtyping.Int[Array, 'num_tokens'] | None = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

version: Union[str, Literal['v3', 'v2']] = 'v3'#
easydel.layers.caching.ragged_page.cache.align_to_multiple(value: int, multiple: int) int[source]#
easydel.layers.caching.ragged_page.cache.cdiv(a: int, b: int) int[source]#
easydel.layers.caching.ragged_page.cache.get_dtype_packing(dtype: dtype) int[source]#
easydel.layers.caching.ragged_page.cache.get_num_slices_per_kv_cache_update_page(page_size_bytes: int) int[source]#
easydel.layers.caching.ragged_page.cache.get_page_size_bytes(page_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: dtype) int[source]#

Returns the size in bytes of one page of the KV cache.
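
As a back-of-the-envelope sketch of what "the size of one page" means and how it bounds the number of pages, the code below assumes a page stores both keys and values for page_size tokens and ignores any head-dimension alignment or dtype packing padding the library may apply. page_bytes and pages_that_fit are hypothetical helpers, and the 8 GiB budget is an illustrative number only.

```python
def page_bytes(page_size: int, num_kv_heads: int, head_size: int, itemsize: int) -> int:
    """Bytes in one page, counting keys and values (factor 2), without padding."""
    return page_size * 2 * num_kv_heads * head_size * itemsize


def pages_that_fit(budget_bytes: int, num_layers: int, bytes_per_page: int) -> int:
    """How many pages per layer fit if every layer gets the same number of pages."""
    return budget_bytes // (num_layers * bytes_per_page)


# 128 tokens per page, 8 KV heads, head size 128, bfloat16 (2 bytes per element).
one_page = page_bytes(page_size=128, num_kv_heads=8, head_size=128, itemsize=2)

# Illustrative 8 GiB budget shared by 32 layers.
num_pages = pages_that_fit(8 * 1024**3, num_layers=32, bytes_per_page=one_page)
```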

easydel.layers.caching.ragged_page.cache.per_device_hbm_budget_bytes(util: float = 0.9, mode: str = 'free', safety_margin: int = 268435456) int[source]#
easydel.layers.caching.ragged_page.cache.previous_power_of_2(n: int) int[source]#
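
The integer helpers above are listed without descriptions. The sketch below shows the conventional semantics their names suggest (ceiling division, rounding up to a multiple, and the largest power of two not exceeding n). These are assumptions based on the names alone, not the library's source, and the handling of edge cases may differ.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer >= a / b, for b > 0."""
    return -(-a // b)


def align_to_multiple(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple`."""
    return cdiv(value, multiple) * multiple


def previous_power_of_2(n: int) -> int:
    """Largest power of two that is <= n (assumed meaning), for n >= 1."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n.bit_length() - 1)
```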