easydel.layers.caching.__init__

Contents

easydel.layers.caching.__init__#

class easydel.layers.caching.__init__.HBMPageManager(metadata: PagedAttentionCacheMetaData)[source]#

Bases: object

Manages the allocation and deallocation of physical HBM pages for the KV cache. It keeps track of available pages.

_metadata#

Configuration for the paged cache.

Type

PagedAttentionCacheMetaData

_current_page_index#

Index representing the initial dummy page.

Type

int

_available_hbm_pages#

Queue of free HBM page indices.

Type

queue.SimpleQueue

alloc_hbm_pages(n: int) list[int][source]#

Allocates a specific number of HBM pages.

Parameters

n (int) – Number of pages to allocate.

Returns

Allocated HBM page indices (empty if insufficient pages).

Return type

list[int]

alloc_prefill_hbm_pages(prompt_len) list[int][source]#

Allocates the required number of HBM pages for a prompt prefill based on its length.

Parameters

prompt_len (int) – The length of the prompt (or chunk).

Returns

List of allocated HBM page indices (empty if insufficient pages).

Return type

list[int]

property current_page_index#

Returns the dummy page index (usually 0).

free_hbm_pages(pages: list[int])[source]#

Returns a list of HBM pages back to the available pool.

Parameters

pages (list[int]) – HBM page indices to free (ignores dummy page).

property metadata: PagedAttentionCacheMetaData#

Returns the cache metadata.

property page_size#

Number of tokens (per-token KV cache entries) stored in each page.

class easydel.layers.caching.__init__.LightningCache(views: 'tp.List[tp.Optional[LightningCacheView]]')[source]#

Bases: BaseCache

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_cache(num_hidden_layers: int, metadata: LightningCacheMetaData)[source]#

Initialize a complete cache with views for all layers.

Parameters
  • metadata – Configuration metadata

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Fully initialized cache instance

classmethod init_empty(num_hidden_layers)[source]#

Initialize an empty cache container.

Parameters
  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Cache instance with uninitialized views

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

views: List[Optional[LightningCacheView]]#
class easydel.layers.caching.__init__.LightningCacheMetaData(partition_axis: PartitionAxis, batch_size: Optional[int], num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int])[source]#

Bases: BaseCacheMetadata

Metadata for Lightning cache configuration.

batch_size: Optional[int]#
classmethod create(partition_axis: PartitionAxis, batch_size: Optional[int] = None, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None) LightningCacheMetaData[source]#

Create a LightningCacheMetaData instance with validation.

Returns

LightningCacheMetaData instance

Raises

ValueError – If required parameters are missing or invalid.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

head_dim: Optional[int]#
key_dim: Optional[int]#
key_heads: Optional[int]#
num_heads: Optional[int]#
partition_axis: PartitionAxis#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

value_dim: Optional[int]#
value_heads: Optional[int]#
class easydel.layers.caching.__init__.LightningCacheView(key_value: 'tp.Union[cx.Array, ImplicitArray]', metadata: 'LightningCacheMetaData', layer_index: 'tp.Optional[int]' = None)[source]#

Bases: BaseCacheView

concatenate_to_cache(query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], kv_sharding: PartitionSpec, quantizer: object, causal_mask: Optional[Union[Array, ndarray, bool, number, bool]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]][source]#

Updates the KV cache with new key/value states and adjusts the attention mask.

Internal helper function used when KV caching is enabled.

Parameters
  • query (Array) – Current query states.

  • key (Array) – Current key states.

  • value (Array) – Current value states.

  • attention_mask (Array) – Base attention mask.

  • causal_mask (tp.Optional[Array], optional) – Causal mask. Defaults to None.

  • token_type_ids (tp.Optional[Array], optional) – Token type IDs for segment-based masking. Defaults to None.

Returns

  • Updated key cache tensor.

  • Updated value cache tensor.

  • Updated attention mask reflecting the cached sequence length.

Return type

tp.Tuple[Array, Array, Array]

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init(metadata: LightningCacheMetaData, layer_index: Optional[int] = None)[source]#

Initialize a new cache view instance.

Parameters
  • metadata – Configuration metadata for the cache

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Initialized cache view instance

key_value: Union[Array, ndarray, bool, number, ImplicitArray]#
layer_index: Optional[int] = None#
metadata: LightningCacheMetaData#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.__init__.LightningMetadata[source]#

Bases: BaseRunTimeMetadata

class easydel.layers.caching.__init__.Mamba2Cache(views: List[Optional[easydel.layers.caching.mamba2.mamba2_cache.Mamba2CacheView]])[source]#

Bases: BaseCache

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_cache(num_hidden_layers: int, metadata: Mamba2CacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#

Initialize a complete cache with views for all layers.

Parameters
  • metadata – Configuration metadata

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Fully initialized cache instance

classmethod init_empty(num_hidden_layers)[source]#

Initialize an empty cache container.

Parameters
  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Cache instance with uninitialized views

replace(**kwargs)#

Creates a new instance with specified fields replaced.

reset() Mamba2Cache[source]#

Reset all cache views to their initial state.

Returns

Reset Mamba2Cache

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) Mamba2Cache[source]#

Update the convolutional state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated Mamba2Cache

update_seq(num)[source]#
update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) Mamba2Cache[source]#

Update the SSM state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_ssm_state – New SSM state to replace the current one

Returns

Updated Mamba2Cache

views: List[Optional[Mamba2CacheView]]#
class easydel.layers.caching.__init__.Mamba2CacheMetaData(partition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int)[source]#

Bases: BaseCacheMetadata

Metadata for Mamba2 cache configuration.

batch_size: int#
conv_kernel_size: int#
classmethod create(parition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int) Mamba2CacheMetaData[source]#

Create a Mamba2CacheMetaData instance with validation.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

head_dim: int#
intermediate_size: int#
n_groups: int#
num_heads: int#
num_hidden_layers: int#
partition_axis: PartitionAxis#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

state_size: int#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.__init__.Mamba2CacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], seqlen_offset: int, metadata: easydel.layers.caching.mamba2.mamba2_cache.Mamba2CacheMetaData, layer_index: Optional[int] = None)[source]#

Bases: BaseCacheView

concatenate_to_cache(*args, **kwargs)[source]#

Update cache with new states.

Parameters
  • *args – Typically includes new tensors

  • **kwargs – Additional parameters for cache update

Returns

Tuple containing:

  • anything

conv_states: Union[Array, ndarray, bool, number, ImplicitArray]#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init(metadata: Mamba2CacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#

Initialize a new cache view instance.

Parameters
  • metadata – Configuration metadata for the cache

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Initialized cache view instance

layer_index: Optional[int] = None#
metadata: Mamba2CacheMetaData#
positions: Union[Array, ndarray, bool, number]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

reset() Mamba2CacheView[source]#

Reset both conv and ssm states to zeros.

seqlen_offset: int#
ssm_states: Union[Array, ndarray, bool, number, ImplicitArray]#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) Mamba2CacheView[source]#

Update the convolutional state of the cache.

update_ssm_state(new_ssm_state: Union[Array, ndarray, bool, number]) Mamba2CacheView[source]#

Update the SSM state of the cache.

class easydel.layers.caching.__init__.Mamba2Metadata[source]#

Bases: BaseRunTimeMetadata

class easydel.layers.caching.__init__.MambaCache(views: List[Optional[easydel.layers.caching.mamba.mamba_cache.MambaCacheView]])[source]#

Bases: BaseCache

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_cache(metadata: MambaCacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#

Initialize a complete cache with views for all layers.

Parameters
  • metadata – Configuration metadata

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Fully initialized cache instance

classmethod init_empty(num_hidden_layers)[source]#

Initialize an empty cache container.

Parameters
  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Cache instance with uninitialized views

replace(**kwargs)#

Creates a new instance with specified fields replaced.

reset() MambaCache[source]#

Reset all cache views to their initial state.

Returns

Reset MambaCache

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCache[source]#

Update the convolutional state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCache

update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) MambaCache[source]#

Update the SSM state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCache

views: List[Optional[MambaCacheView]]#
class easydel.layers.caching.__init__.MambaCacheMetaData(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)[source]#

Bases: BaseCacheMetadata

Metadata for Mamba cache configuration.

batch_size: int#
conv_kernel_size: int#
classmethod create(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) MambaCacheMetaData[source]#

Create a MambaCacheMetaData instance with validation.

Parameters

  • num_hidden_layers – Number of hidden layers.

  • partition_axis – Partition axis.

  • batch_size – Size of the batch.

  • intermediate_size – Model’s intermediate size.

  • ssm_state_size – Model’s state size.

  • conv_kernel_size – Model’s convolution kernel size.

Returns

MambaCacheMetaData instance

Raises

ValueError – If required parameters are invalid

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

intermediate_size: int#
num_hidden_layers: int#
partition_axis: PartitionAxis#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

ssm_state_size: int#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.__init__.MambaCacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], metadata: easydel.layers.caching.mamba.mamba_cache.MambaCacheMetaData, layer_index: Optional[int] = None)[source]#

Bases: BaseCacheView

concatenate_to_cache(*args, **kwargs)[source]#

Update cache with new states.

Parameters
  • *args – Typically includes new tensors

  • **kwargs – Additional parameters for cache update

Returns

Tuple containing:

  • anything

conv_states: Union[Array, ndarray, bool, number, ImplicitArray]#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#

Initialize a new cache view instance.

Parameters
  • metadata – Configuration metadata for the cache

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Initialized cache view instance

layer_index: Optional[int] = None#
metadata: MambaCacheMetaData#
positions: Union[Array, ndarray, bool, number]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

reset() MambaCacheView[source]#

Reset both conv and ssm states to zeros.

Returns

Reset MambaCacheView

ssm_states: Union[Array, ndarray, bool, number, ImplicitArray]#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCacheView[source]#

Update the convolutional state of the cache.

Parameters
  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCacheView

update_ssm_state(new_ssm_state: Union[Array, ndarray, bool, number]) MambaCacheView[source]#

Update the SSM state of the cache.

Parameters

new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCacheView

class easydel.layers.caching.__init__.MambaMetadata[source]#

Bases: BaseRunTimeMetadata

class easydel.layers.caching.__init__.PagedAttentionCache(views: List[PagedAttentionCacheView])[source]#

Bases: BaseCache

Represents the complete Paged Attention KV cache for all layers of a model.

It holds a list of PagedAttentionCacheView objects, one for each layer. It inherits from BaseCache.

views#

A list containing the cache view for each layer in the model.

Type

tp.List[PagedAttentionCacheView]

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_cache(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#

Initializes the entire PagedAttentionCache for all layers.

Creates a list of PagedAttentionCacheView instances, one for each layer specified in the metadata, by calling PagedAttentionCacheView.init for each layer.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • dtype (jnp.dtype) – The data type for the cache pages.

  • metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.

  • partition_manager (es.PartitionManager) – Manages tensor sharding.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.

Returns

An initialized cache object containing views for all layers.

Return type

PagedAttentionCache

init_empty(*args, **kwargs)[source]#

Not typically used for PagedAttentionCache; returns None.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

views: List[PagedAttentionCacheView]#
class easydel.layers.caching.__init__.PagedAttentionCacheMetaData(batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float)[source]#

Bases: BaseCacheMetadata

Metadata holding configuration parameters for the Paged Attention KV cache.

This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.

batch_size#

The maximum number of sequences processed concurrently during decoding.

Type

int

num_hidden_layers#

The total number of transformer layers in the model.

Type

int

num_pages_per_layer#

The total number of physical memory pages allocated for the KV cache per layer across all sequences. This is calculated based on available memory and hbm_utilization.

Type

int

num_pages_per_sequence#

The maximum number of pages a single sequence can occupy, determined by max_sequences and page_size.

Type

int

max_sequences#

The maximum sequence length supported by the cache allocation.

Type

int

page_size#

The number of tokens stored per page in the KV cache.

Type

int

num_kv_heads#

The number of key/value heads in the attention mechanism.

Type

int

kv_head_dim_size#

The dimension size of each key/value head.

Type

int

hbm_utilization#

The target fraction of available High Bandwidth Memory (HBM) to be utilized for the KV cache pages. Should be between 0.0 and 1.0.

Type

float

batch_size: int#
classmethod create(mesh: ~jax._src.mesh.Mesh, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: ~numpy.dtype = <class 'jax.numpy.bfloat16'>) PagedAttentionCacheMetaData[source]#

Factory method to create and initialize a PagedAttentionCacheMetaData instance.

Calculates derived values like num_pages_per_layer and num_pages_per_sequence based on the provided parameters and estimated available memory.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • batch_size (int) – Maximum concurrent sequences for decode.

  • num_hidden_layers (int) – Number of transformer layers.

  • max_sequences (int) – Maximum supported sequence length.

  • page_size (int) – Number of tokens per cache page.

  • num_kv_heads (int) – Number of KV heads.

  • kv_head_dim_size (int) – Dimension of each KV head.

  • hbm_utilization (float) – Target HBM utilization fraction (0.0 to 1.0).

  • dtype (jnp.dtype) – Data type used for cache size calculation.

Returns

An initialized metadata object.

Return type

PagedAttentionCacheMetaData

Raises

ValueError – If input parameters are invalid (e.g., non-positive dimensions, invalid utilization factor).

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hbm_utilization: float#
kv_head_dim_size: int#
max_sequences: int#
num_hidden_layers: int#
num_kv_heads: int#
num_pages_per_layer: int#
num_pages_per_sequence: int#
page_size: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.__init__.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray])[source]#

Bases: BaseCacheView

Represents the view of the Paged Attention KV cache for a single transformer layer.

It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.

metadata#

The static configuration metadata for the entire paged cache.

Type

PagedAttentionCacheMetaData

layer_index#

The index of the transformer layer this view corresponds to.

Type

int

key_pages#

The tensor holding all key pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.

Type

tp.Union[cx.Array, ImplicitArray]

value_pages#

The tensor holding all value pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.

Type

tp.Union[cx.Array, ImplicitArray]

concatenate_to_cache(*args, **kwargs)[source]#

Concatenation is not applicable for Paged Attention. Raises NotImplementedError.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#

Initializes the PagedAttentionCacheView for a specific layer.

Allocates the key_pages and value_pages tensors with the appropriate shape, dtype, and sharding based on the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • dtype (jnp.dtype) – The data type for the cache pages (e.g., jnp.bfloat16).

  • metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.

  • layer_index (int) – The index of the layer this view is for.

  • partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.

Returns

An initialized cache view for the specified layer.

Return type

PagedAttentionCacheView

key_pages: Union[Array, ndarray, bool, number, ImplicitArray]#
layer_index: int#
metadata: PagedAttentionCacheMetaData#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

value_pages: Union[Array, ndarray, bool, number, ImplicitArray]#
write_decodes_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#

Writes the key/value pairs from a decode step into the appropriate cache pages.

Uses the decodes_position and decodes_page_table from the runtime metadata to calculate the exact page index and offset within that page where the new key/value pair for each sequence in the batch should be written. It reshapes the cache pages and input keys/values for efficient scattered updates using .at[...].set(...).

Parameters
  • key (cx.Array) – Key tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).

  • value (cx.Array) – Value tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).

  • metadata (PagedAttentionMetadata) – Runtime metadata containing decodes_position and decodes_page_table.

Returns

Returns self after updating the pages.

Return type

PagedAttentionCacheView

write_prefill_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#

Writes the key/value pairs from a prefill step into the appropriate cache pages.

Uses the prefill_page_table from the runtime metadata to determine which physical pages (key_pages, value_pages) correspond to the logical pages of the prefill sequence. It transposes and reshapes the input key/value tensors and uses jax.lax.dynamic_update_slice_in_dim within a while_loop to update the relevant pages.

Parameters
  • key (cx.Array) – Key tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).

  • value (cx.Array) – Value tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).

  • metadata (PagedAttentionMetadata) – Runtime metadata containing the prefill_length and prefill_page_table.

Returns

Returns self after updating the pages.

Return type

PagedAttentionCacheView

class easydel.layers.caching.__init__.PagedAttentionMetadata(prefill_length: Array, prefill_position: Array, prefill_page_table: Array, decodes_position: Array, decodes_page_table: Array)[source]#

Bases: object

Runtime metadata required for performing a Paged Attention computation step.

This object holds the necessary information for a single forward pass of the paged attention mechanism, distinguishing between prefill and decode steps and providing the mappings (page tables) from logical sequence positions to physical cache pages.

prefill_length#

Scalar JAX array containing the actual length of the prompt being processed in a prefill step. Shape (). Set to 0 if not in prefill.

Type

jax.Array

prefill_position#

JAX array of positions for the prefill tokens. Shape (padded_prompt_length,). Empty shape () if not in prefill.

Type

jax.Array

prefill_page_table#

JAX array mapping logical page indices of the prefill sequence to physical page indices in the KV cache. Shape (num_pages_for_prefill,). Empty shape () if not in prefill.

Type

jax.Array

decodes_position#

JAX array containing the current sequence position (or length - 1) for each sequence in the decode batch. Shape (batch_size,). Empty shape () if not in decode.

Type

jax.Array

decodes_page_table#

JAX array mapping logical page indices to physical page indices for each sequence in the decode batch. Shape (batch_size, num_pages_per_sequence). Empty shape () if not in decode.

Type

jax.Array

decodes_page_table: Array#
decodes_position: Array#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

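classmethod init_empty()[source]#

Creates an initial or placeholder PagedAttentionMetadata object. (Internal helper method).

Returns

An instance with scalar placeholder values.

Return type

PagedAttentionMetadata

is_decode_mode() bool[source]#

Checks whether the current metadata represents a decode step (presumably the counterpart of is_prefill_mode; confirm against the source).

Return type

bool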

is_prefill_mode() bool[source]#

Checks if the current metadata represents a prefill-only step.

Returns

True if only prefill information is present (decode arrays have empty shape), False otherwise.

Return type

bool

prefill_length: Array#
prefill_page_table: Array#
prefill_position: Array#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.__init__.TransformerCache(views: 'tp.List[tp.Optional[TransformerCacheView]]')[source]#

Bases: BaseCache

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod from_pure(pure, metadata)[source]#
classmethod init_cache(mesh: Mesh, metadata: TransformerCacheMetaData, partition_manager: PartitionManager, dtype: Optional[dtype] = None, starts: Optional[Array] = None, quantizer: Optional[object] = None)[source]#

Initialize a complete cache with views for all layers.

Parameters
  • metadata – Configuration metadata

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Fully initialized cache instance

classmethod init_empty(num_hidden_layers)[source]#

Initialize an empty cache container.

Parameters
  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Cache instance with uninitialized views

insert(other: TransformerCache, slot: int, quantizer: object, partition_manager: PartitionManager)[source]#
insert_index(index, slot: int, partition_manager: PartitionManager)[source]#
insert_starts(starts, slot: int, partition_manager: PartitionManager)[source]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

to_pure()[source]#
views: List[Optional[TransformerCacheView]]#
class easydel.layers.caching.__init__.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int], update_causal_mask: bool, create_attention_bias: bool)[source]#

Bases: BaseCacheMetadata

Metadata for transformer cache configuration.

batch_size: int#
classmethod create(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None, update_causal_mask: bool = True, create_attention_bias: bool = True) TransformerCacheMetaData[source]#

Create a TransformerCacheMetaData instance with validation.

Parameters
  • batch_size – Size of the batch.

  • sequence_length – Length of the sequence.

  • num_hidden_layers – Number of hidden layers.

  • pad_token_id – ID of the padding token.

  • num_heads – Number of attention heads.

  • head_dim – Dimension of each head.

  • key_heads – Number of key heads.

  • value_heads – Number of value heads.

  • key_dim – Dimension of keys.

  • value_dim – Dimension of values.

  • update_causal_mask – Whether to update causal mask.

  • create_attention_bias – Whether to create attention bias.

Returns

TransformerCacheMetaData instance

Raises

ValueError – If required parameters are missing or invalid.

create_attention_bias: bool#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

head_dim: Optional[int]#
key_dim: Optional[int]#
key_heads: Optional[int]#
num_heads: Optional[int]#
num_hidden_layers: int#
pad_token_id: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

sequence_length: int#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

update_causal_mask: bool#
value_dim: Optional[int]#
value_heads: Optional[int]#
class easydel.layers.caching.__init__.TransformerCacheView(key: 'tp.Union[cx.Array, ImplicitArray]', value: 'tp.Union[cx.Array, ImplicitArray]', index: 'tp.Union[cx.Array, ImplicitArray]', starts: 'tp.Union[cx.Array, ImplicitArray]', metadata: 'TransformerCacheMetaData', layer_index: 'tp.Optional[int]' = None)[source]#

Bases: BaseCacheView

concatenate_to_cache(query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], quantizer: object, cache_metadata: Optional[TransformerMetadata], attention_mask: Union[Array, ndarray, bool, number], partition_manager: PartitionManager, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]][source]#

Updates the KV cache functionally and returns the updated tensors along with the appropriate attention mask.

Parameters
  • query – Current query states.

  • key – Current key states to add to the cache.

  • value – Current value states to add to the cache.

  • cache_metadata – Optional metadata. If provided and contains slot/length info, enables pooled caching.

  • attention_mask – Base attention mask.

  • quantizer – Quantizer for the cache.

  • causal_mask – Optional causal mask.

  • token_type_ids – Optional token type IDs for segment masking.

Returns

  • Updated key cache tensor (functional update).

  • Updated value cache tensor (functional update).

  • Final attention mask to be used (either original or calculated).

Return type

Tuple[Array, Array, Array]

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

index: Union[Array, ndarray, bool, number, ImplicitArray]#
classmethod init(mesh: Mesh, dtype: dtype, metadata: TransformerCacheMetaData, quantizer: object, partition_manager: PartitionManager, starts: Optional[Array] = None, layer_index: Optional[int] = None)[source]#

Initialize a new cache view instance.

Parameters
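  • mesh – Mesh instance.

  • dtype – dtype argument for the cache view.

  • metadata – Configuration metadata for the cache

  • quantizer – Quantizer object.

  • partition_manager – PartitionManager instance.

  • starts – Optional Array (defaults to None).

  • layer_index – Optional int (defaults to None).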

Returns

Initialized cache view instance

property is_empty#
key: Union[Array, ndarray, bool, number, ImplicitArray]#
layer_index: Optional[int] = None#
metadata: TransformerCacheMetaData#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

starts: Union[Array, ndarray, bool, number, ImplicitArray]#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

value: Union[Array, ndarray, bool, number, ImplicitArray]#
class easydel.layers.caching.__init__.TransformerMetadata(postpadded: bool = False, index: int | None = None)[source]#

Bases: BaseRunTimeMetadata

Holds optional metadata for the attention runtime.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

index: int | None = None#
postpadded: bool = False#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.