easydel.layers.caching.paged_attention.paged_attention_cache

class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCache(views: List[PagedAttentionCacheView])[source]#

Bases: BaseCache

Represents the complete Paged Attention KV cache for all layers of a model.

It holds a list of PagedAttentionCacheView objects, one for each layer. It inherits from BaseCache.

views#

A list containing the cache view for each layer in the model.

Type

tp.List[PagedAttentionCacheView]

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_cache(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#

Initializes the entire PagedAttentionCache for all layers.

Creates a list of PagedAttentionCacheView instances, one for each layer specified in the metadata, by calling PagedAttentionCacheView.init for each layer.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • dtype (jnp.dtype) – The data type for the cache pages.

  • metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.

  • partition_manager (es.PartitionManager) – Manages tensor sharding.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.

Returns

An initialized cache object containing views for all layers.

Return type

PagedAttentionCache
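The one-view-per-layer construction described above can be pictured with a small, framework-free sketch. `LayerView` and `Cache` are hypothetical stand-ins rather than EasyDeL classes, and the sketch leaves out the mesh, dtype, sharding, and quantizer handling that the real method performs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LayerView:
    """Hypothetical stand-in for PagedAttentionCacheView."""

    layer_index: int

    @classmethod
    def init(cls, layer_index: int) -> "LayerView":
        # The real init also allocates key/value pages; omitted in this sketch.
        return cls(layer_index=layer_index)


@dataclass
class Cache:
    """Hypothetical stand-in for PagedAttentionCache."""

    views: List[LayerView]

    @classmethod
    def init_cache(cls, num_hidden_layers: int) -> "Cache":
        # One view per layer, built by calling the view's init in layer order.
        return cls(views=[LayerView.init(i) for i in range(num_hidden_layers)])


cache = Cache.init_cache(num_hidden_layers=4)
```

In this sketch, `cache.views[i]` is the view for layer `i`, which mirrors how the `views` list is described.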

init_empty(*args, **kwargs)[source]#

Not typically used for PagedAttentionCache; returns None.

replace(**kwargs)#

Creates a new instance with specified fields replaced.
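As a rough analogy for this immutable-update pattern, the standard library's `dataclasses.replace` behaves similarly. This illustrates the pattern only, not the library's implementation, and `Config` is a hypothetical class.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Config:
    """Hypothetical frozen record used only for illustration."""

    page_size: int
    num_kv_heads: int


original = Config(page_size=16, num_kv_heads=8)
# replace() returns a new object; the original is left untouched.
updated = replace(original, page_size=32)
```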

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

views: List[PagedAttentionCacheView]#

class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheMetaData(batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float)[source]#

Bases: BaseCacheMetadata

Metadata holding configuration parameters for the Paged Attention KV cache.

This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.

batch_size#

The maximum number of sequences processed concurrently during decoding.

Type

int

num_hidden_layers#

The total number of transformer layers in the model.

Type

int

num_pages_per_layer#

The total number of physical memory pages allocated for the KV cache per layer across all sequences. This is calculated based on available memory and hbm_utilization.

Type

int

num_pages_per_sequence#

The maximum number of pages a single sequence can occupy, determined by max_sequences and page_size.

Type

int

max_sequences#

The maximum sequence length, in tokens, supported by the cache allocation.

Type

int

page_size#

The number of tokens stored per page in the KV cache.

Type

int

num_kv_heads#

The number of key/value heads in the attention mechanism.

Type

int

kv_head_dim_size#

The dimension size of each key/value head.

Type

int

hbm_utilization#

The target fraction of available High Bandwidth Memory (HBM) to be utilized for the KV cache pages. Should be between 0.0 and 1.0.

Type

float

batch_size: int#

classmethod create(mesh: Mesh, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: dtype = jnp.bfloat16) PagedAttentionCacheMetaData[source]#

Factory method to create and initialize a PagedAttentionCacheMetaData instance.

Calculates derived values like num_pages_per_layer and num_pages_per_sequence based on the provided parameters and estimated available memory.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • batch_size (int) – Maximum concurrent sequences for decode.

  • num_hidden_layers (int) – Number of transformer layers.

  • max_sequences (int) – Maximum supported sequence length.

  • page_size (int) – Number of tokens per cache page.

  • num_kv_heads (int) – Number of KV heads.

  • kv_head_dim_size (int) – Dimension of each KV head.

  • hbm_utilization (float) – Target HBM utilization fraction (0.0 to 1.0).

  • dtype (jnp.dtype) – Data type used for cache size calculation.

Returns

An initialized metadata object.

Return type

PagedAttentionCacheMetaData

Raises

ValueError – If input parameters are invalid (e.g., non-positive dimensions, invalid utilization factor).
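The derived values can be illustrated with a short arithmetic sketch. It rests on several assumptions that this page does not confirm: `num_pages_per_sequence` is taken as the ceiling of `max_sequences / page_size`, the page budget is split evenly across layers with separate key and value pages, and the utilization bound is treated as `0.0 < hbm_utilization <= 1.0`. The `available_bytes` argument is hypothetical, since the real method estimates memory from the mesh.

```python
import math
from typing import Tuple


def derive_page_counts(
    available_bytes: int,  # hypothetical input: total memory budget in bytes
    hbm_utilization: float,
    num_hidden_layers: int,
    max_sequences: int,  # maximum sequence length, as documented above
    page_size: int,
    num_kv_heads: int,
    kv_head_dim_size: int,
    dtype_bytes: int = 2,  # bfloat16 occupies 2 bytes per element
) -> Tuple[int, int]:
    """Return (num_pages_per_layer, num_pages_per_sequence) under the stated assumptions."""
    dims = (num_hidden_layers, max_sequences, page_size, num_kv_heads, kv_head_dim_size)
    if any(d <= 0 for d in dims):
        raise ValueError("all dimensions must be positive")
    if not 0.0 < hbm_utilization <= 1.0:
        raise ValueError("hbm_utilization must lie in (0.0, 1.0]")

    # Enough pages to cover the longest sequence, rounded up.
    num_pages_per_sequence = math.ceil(max_sequences / page_size)

    # One page stores page_size tokens for every KV head.
    bytes_per_page = page_size * num_kv_heads * kv_head_dim_size * dtype_bytes
    budget = int(available_bytes * hbm_utilization)
    # Keys and values each need their own pages, hence the factor of 2.
    num_pages_per_layer = budget // (2 * bytes_per_page * num_hidden_layers)
    return num_pages_per_layer, num_pages_per_sequence


counts = derive_page_counts(
    available_bytes=16 * 2**30,  # 16 GiB
    hbm_utilization=0.5,
    num_hidden_layers=32,
    max_sequences=2048,
    page_size=128,
    num_kv_heads=8,
    kv_head_dim_size=128,
)
```

With these example numbers each page takes 256 KiB, so an 8 GiB budget yields 512 pages per layer and 16 pages per sequence.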

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hbm_utilization: float#
kv_head_dim_size: int#
max_sequences: int#
num_hidden_layers: int#
num_kv_heads: int#
num_pages_per_layer: int#
num_pages_per_sequence: int#
page_size: int#

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray])[source]#

Bases: BaseCacheView

Represents the view of the Paged Attention KV cache for a single transformer layer.

It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.

metadata#

The static configuration metadata for the entire paged cache.

Type

PagedAttentionCacheMetaData

layer_index#

The index of the transformer layer this view corresponds to.

Type

int

key_pages#

The tensor holding all key pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.

Type

tp.Union[cx.Array, ImplicitArray]

value_pages#

The tensor holding all value pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.

Type

tp.Union[cx.Array, ImplicitArray]
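To make the documented page layout concrete, the following sketch computes the element count, the byte footprint, and the token capacity for one pages tensor. The example dimensions are arbitrary, and the 2-byte element size assumes an unquantized bfloat16 cache.

```python
from math import prod

# Example dimensions (arbitrary values chosen for illustration).
num_kv_heads = 8
num_pages_per_layer = 512
page_size = 128
kv_head_dim_size = 128

# Layout documented for key_pages and value_pages.
shape = (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size)

elements = prod(shape)
bytes_per_tensor = elements * 2  # assumes bfloat16, 2 bytes per element
token_capacity = num_pages_per_layer * page_size  # tokens one layer can hold
```

Since each layer holds both `key_pages` and `value_pages`, the per-layer footprint under these assumptions is twice `bytes_per_tensor`.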

concatenate_to_cache(*args, **kwargs)[source]#

Concatenation is not applicable for Paged Attention. Raises NotImplementedError.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#

Initializes the PagedAttentionCacheView for a specific layer.

Allocates the key_pages and value_pages tensors with the appropriate shape, dtype, and sharding based on the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • dtype (jnp.dtype) – The data type for the cache pages (e.g., jnp.bfloat16).

  • metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.

  • layer_index (int) – The index of the layer this view is for.

  • partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.

Returns

An initialized cache view for the specified layer.

Return type

PagedAttentionCacheView

key_pages: Union[Array, ndarray, bool, number, ImplicitArray]#
layer_index: int#
metadata: PagedAttentionCacheMetaData#

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

value_pages: Union[Array, ndarray, bool, number, ImplicitArray]#

write_decodes_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#

Writes the key/value pairs from a decode step into the appropriate cache pages.

Uses the decodes_position and decodes_page_table from the runtime metadata to calculate the exact page index and offset within that page where the new key/value pair for each sequence in the batch should be written. It reshapes the cache pages and input keys/values for efficient scattered updates using .at[…].set(…).

Parameters
  • key (cx.Array) – Key tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).

  • value (cx.Array) – Value tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).

  • metadata (PagedAttentionMetadata) – Runtime metadata containing decodes_position and decodes_page_table.

Returns

Returns self after updating the pages.

Return type

PagedAttentionCacheView
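The page index and offset calculation described above can be sketched in plain Python. This is an illustration under assumptions, not the library's code: it assumes the logical page is `position // page_size`, the offset is `position % page_size`, and the page and slot axes are flattened into a single index. The helper name `decode_write_slots` is hypothetical.

```python
from typing import List


def decode_write_slots(
    positions: List[int],  # decodes_position: one entry per sequence
    page_table: List[List[int]],  # decodes_page_table: [batch][logical_page]
    page_size: int,
) -> List[int]:
    """Return, per sequence, the flat slot index physical_page * page_size + offset."""
    slots = []
    for pos, table in zip(positions, page_table):
        logical_page = pos // page_size  # which logical page the token falls in
        offset = pos % page_size  # slot inside that page
        physical_page = table[logical_page]  # page table lookup
        slots.append(physical_page * page_size + offset)
    return slots


slots = decode_write_slots(
    positions=[5, 17],
    page_table=[[3, 0], [7, 2]],
    page_size=16,
)
```

Here the first sequence writes to slot 5 of physical page 3, and the second to slot 1 of physical page 2. Indices of this kind are what a scattered update such as `.at[…].set(…)` would consume once the page and slot axes have been merged.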

write_prefill_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#

Writes the key/value pairs from a prefill step into the appropriate cache pages.

Uses the prefill_page_table from the runtime metadata to determine which physical pages (key_pages, value_pages) correspond to the logical pages of the prefill sequence. It transposes and reshapes the input key/value tensors and uses jax.lax.dynamic_update_slice_in_dim within a while_loop to update the relevant pages.

Parameters
  • key (cx.Array) – Key tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).

  • value (cx.Array) – Value tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).

  • metadata (PagedAttentionMetadata) – Runtime metadata containing the prefill_length and prefill_page_table.

Returns

Returns self after updating the pages.

Return type

PagedAttentionCacheView
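The page-by-page prefill write can be pictured with a plain Python sketch that uses lists in place of tensors. It assumes that `ceil(prefill_length / page_size)` whole pages are written and that the padded input length is a multiple of `page_size`. The function `write_prefill` is hypothetical, and an ordinary `while` loop stands in for `while_loop`.

```python
from typing import List


def write_prefill(
    pages: List[List[int]],  # pages[physical_page] holds page_size slots
    tokens: List[int],  # padded prefill entries
    prefill_length: int,
    prefill_page_table: List[int],  # logical page -> physical page
    page_size: int,
) -> List[List[int]]:
    """Copy page-sized chunks of `tokens` into their physical pages."""
    pages = [list(p) for p in pages]  # copy, to mimic functional updates
    num_pages = -(-prefill_length // page_size)  # ceiling division
    i = 0
    while i < num_pages:  # plays the role of the while_loop
        chunk = tokens[i * page_size:(i + 1) * page_size]
        pages[prefill_page_table[i]] = chunk
        i += 1
    return pages


empty = [[0] * 4 for _ in range(4)]
out = write_prefill(
    empty,
    tokens=[1, 2, 3, 4, 5, 6, 0, 0],
    prefill_length=6,
    prefill_page_table=[2, 0],
    page_size=4,
)
```

In this example the first chunk lands in physical page 2 and the second, partly padded chunk lands in physical page 0, while the input `empty` is left unchanged.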

class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionMetadata(prefill_length: Array, prefill_position: Array, prefill_page_table: Array, decodes_position: Array, decodes_page_table: Array)[source]#

Bases: object

Runtime metadata required for performing a Paged Attention computation step.

This object holds the necessary information for a single forward pass of the paged attention mechanism, distinguishing between prefill and decode steps and providing the mappings (page tables) from logical sequence positions to physical cache pages.

prefill_length#

Scalar JAX array containing the actual length of the prompt being processed in a prefill step. Shape (). Set to 0 if not in prefill.

Type

jax.Array

prefill_position#

JAX array of positions for the prefill tokens. Shape (padded_prompt_length,). Empty shape () if not in prefill.

Type

jax.Array

prefill_page_table#

JAX array mapping logical page indices of the prefill sequence to physical page indices in the KV cache. Shape (num_pages_for_prefill,). Empty shape () if not in prefill.

Type

jax.Array

decodes_position#

JAX array containing the current sequence position (or length - 1) for each sequence in the decode batch. Shape (batch_size,). Empty shape () if not in decode.

Type

jax.Array

decodes_page_table#

JAX array mapping logical page indices to physical page indices for each sequence in the decode batch. Shape (batch_size, num_pages_per_sequence). Empty shape () if not in decode.

Type

jax.Array

decodes_page_table: Array#
decodes_position: Array#

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_empty()[source]#

Creates an initial or placeholder PagedAttentionMetadata object. This is an internal helper method.

Returns

An instance with scalar placeholder values.

Return type

PagedAttentionMetadata

is_decode_mode() bool[source]#

is_prefill_mode() bool[source]#

Checks if the current metadata represents a prefill-only step.

Returns

True if only prefill information is present (decode arrays have empty shape), False otherwise.

Return type

bool
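A shape-based check of this kind might look like the sketch below. It is an assumption about the approach rather than the library's code: `MetadataShapes` is a hypothetical record that stores only array shapes as tuples, with `()` marking a scalar placeholder.

```python
from dataclasses import dataclass
from typing import Tuple

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class MetadataShapes:
    """Hypothetical record of array shapes only, not the real metadata class."""

    prefill_page_table: Shape
    decodes_position: Shape
    decodes_page_table: Shape

    def is_prefill_mode(self) -> bool:
        # Prefill-only: both decode arrays are scalar placeholders.
        return self.decodes_position == () and self.decodes_page_table == ()


prefill_step = MetadataShapes(
    prefill_page_table=(4,),
    decodes_position=(),
    decodes_page_table=(),
)
mixed_step = MetadataShapes(
    prefill_page_table=(4,),
    decodes_position=(8,),
    decodes_page_table=(8, 16),
)
```

Under this sketch, `prefill_step.is_prefill_mode()` is true and `mixed_step.is_prefill_mode()` is false.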

prefill_length: Array#
prefill_page_table: Array#
prefill_position: Array#

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.