easydel.layers.caching.__init__#
- class easydel.layers.caching.__init__.HBMPageManager(metadata: PagedAttentionCacheMetaData)[source]#
Bases:
object
Manages the allocation and deallocation of physical HBM pages for the KV cache and keeps track of the available pages.
- _metadata#
Configuration for the paged cache.
- _current_page_index#
Index representing the initial dummy page.
- Type
int
- _available_hbm_pages#
Queue of free HBM page indices.
- Type
queue.SimpleQueue
- alloc_hbm_pages(n: int) list[int][source]#
Allocates a specific number of HBM pages.
- Parameters
n (int) – Number of pages to allocate.
- Returns
Allocated HBM page indices (empty if insufficient pages).
- Return type
list[int]
- alloc_prefill_hbm_pages(prompt_len) list[int][source]#
Allocates the required number of HBM pages for a prompt prefill based on its length.
- Parameters
prompt_len (int) – The length of the prompt (or chunk).
- Returns
List of allocated HBM page indices (empty if insufficient pages).
- Return type
list[int]
- property current_page_index#
Returns the dummy page index (usually 0).
- free_hbm_pages(pages: list[int])[source]#
Returns a list of HBM pages back to the available pool.
- Parameters
pages (list[int]) – HBM page indices to free (ignores dummy page).
- property metadata: PagedAttentionCacheMetaData#
Returns the cache metadata.
- property page_size#
Number of per-token KV cache items per page.
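The allocation behaviour described above (a queue of free page indices, all-or-nothing allocation, a reserved dummy page) can be pictured with a small stand-alone sketch. `ToyPageManager` and its method names are hypothetical and only mirror the documented behaviour; the ceiling division used for prefill sizing is an assumption, not something the reference states.

```python
import queue


class ToyPageManager:
    """Illustrative stand-in for a page allocator; not the EasyDeL class."""

    def __init__(self, num_pages: int, page_size: int, dummy_page: int = 0):
        self.page_size = page_size
        self.dummy_page = dummy_page
        self._free = queue.SimpleQueue()
        for page in range(num_pages):
            if page != dummy_page:  # the dummy page is never handed out
                self._free.put(page)

    def alloc(self, n: int) -> list:
        # All-or-nothing: return an empty list when fewer than n pages are free.
        if n <= 0 or self._free.qsize() < n:
            return []
        return [self._free.get() for _ in range(n)]

    def alloc_prefill(self, prompt_len: int) -> list:
        # Assumed sizing rule: ceil(prompt_len / page_size) pages.
        return self.alloc(-(-prompt_len // self.page_size))

    def free(self, pages: list) -> None:
        for page in pages:
            if page != self.dummy_page:  # freeing the dummy page is ignored
                self._free.put(page)


manager = ToyPageManager(num_pages=8, page_size=16)
prefill_pages = manager.alloc_prefill(40)  # 40 tokens need 3 pages of 16
too_many = manager.alloc(100)              # not enough pages left
manager.free(prefill_pages + [0])          # page 0 (dummy) is skipped
```

Returning an empty list instead of raising lets a scheduler treat "no pages available" as a normal condition and retry later.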
- class easydel.layers.caching.__init__.LightningCache(views: 'tp.List[tp.Optional[LightningCacheView]]')[source]#
Bases:
BaseCache
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(num_hidden_layers: int, metadata: LightningCacheMetaData)[source]#
Initialize a complete cache with views for all layers.
- Parameters
num_hidden_layers – Number of layers to create cache views for
metadata – Configuration metadata
- Returns
Fully initialized cache instance
- classmethod init_empty(num_hidden_layers)[source]#
Initialize an empty cache container.
- Parameters
num_hidden_layers – Number of layers (one view slot per layer)
- Returns
Cache instance with uninitialized views
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- views: List[Optional[LightningCacheView]]#
- class easydel.layers.caching.__init__.LightningCacheMetaData(partition_axis: PartitionAxis, batch_size: Optional[int], num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int])[source]#
Bases:
BaseCacheMetadata
Metadata for transformer cache configuration.
- batch_size: Optional[int]#
- classmethod create(partition_axis: PartitionAxis, batch_size: Optional[int] = None, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None) LightningCacheMetaData[source]#
Create a LightningCacheMetaData instance with validation.
- Returns
LightningCacheMetaData instance
- Raises
ValueError – If required parameters are missing or invalid.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- head_dim: Optional[int]#
- key_dim: Optional[int]#
- key_heads: Optional[int]#
- num_heads: Optional[int]#
- partition_axis: PartitionAxis#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- value_dim: Optional[int]#
- value_heads: Optional[int]#
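The metadata classes on this page share one pattern: a validating `create` factory plus `replace`, `to_dict`/`to_json` and `from_dict`/`from_json` helpers. The sketch below imitates that pattern with a plain frozen dataclass. `ToyCacheMetaData` is hypothetical, and the validation rule (reject non-positive values) is an assumption chosen for illustration.

```python
import json
from dataclasses import asdict, dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ToyCacheMetaData:
    """Hypothetical metadata record; field names echo the listing above."""

    batch_size: Optional[int] = None
    num_heads: Optional[int] = None
    head_dim: Optional[int] = None

    @classmethod
    def create(cls, batch_size=None, num_heads=None, head_dim=None):
        # Assumed rule: any value that is given must be a positive integer.
        fields = {"batch_size": batch_size, "num_heads": num_heads, "head_dim": head_dim}
        for name, value in fields.items():
            if value is not None and value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        return cls(**fields)

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, json_str: str):
        return cls(**json.loads(json_str))


meta = ToyCacheMetaData.create(batch_size=4, num_heads=8, head_dim=64)
wider = replace(meta, batch_size=16)                     # new instance, original untouched
restored = ToyCacheMetaData.from_json(meta.to_json())    # JSON round trip

try:
    ToyCacheMetaData.create(batch_size=0)
    rejected = False
except ValueError:
    rejected = True
```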
- class easydel.layers.caching.__init__.LightningCacheView(key_value: 'tp.Union[cx.Array, ImplicitArray]', metadata: 'LightningCacheMetaData', layer_index: 'tp.Optional[int]' = None)[source]#
Bases:
BaseCacheView
- concatenate_to_cache(query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], kv_sharding: PartitionSpec, quantizer: object, causal_mask: Optional[Union[Array, ndarray, bool, number, bool]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]][source]#
Updates the KV cache with new key/value states and adjusts the attention mask.
Internal helper function used when KV caching is enabled.
- Parameters
query (Array) – Current query states.
key (Array) – Current key states.
value (Array) – Current value states.
attention_mask (Array) – Base attention mask.
causal_mask (tp.Optional[Array], optional) – Causal mask. Defaults to None.
token_type_ids (tp.Optional[Array], optional) – Token type IDs for segment-based masking. Defaults to None.
- Returns
Updated key cache tensor.
Updated value cache tensor.
Updated attention mask reflecting the cached sequence length.
- Return type
tp.Tuple[Array, Array, Array]
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: LightningCacheMetaData, layer_index: Optional[int] = None)[source]#
Initialize a new cache view instance.
- Parameters
metadata – Configuration metadata for the cache
layer_index – Optional index of the layer this view belongs to
- Returns
Initialized cache view instance
- layer_index: Optional[int] = None#
- metadata: LightningCacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.__init__.LightningMetadata[source]#
Bases:
BaseRunTimeMetadata
- class easydel.layers.caching.__init__.Mamba2Cache(views: List[Optional[easydel.layers.caching.mamba2.mamba2_cache.Mamba2CacheView]])[source]#
Bases:
BaseCache
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(num_hidden_layers: int, metadata: Mamba2CacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#
Initialize a complete cache with views for all layers.
- Parameters
metadata – Configuration metadata
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns
Fully initialized cache instance
- classmethod init_empty(num_hidden_layers)[source]#
Initialize an empty cache container.
- Parameters
num_hidden_layers – Number of layers (one view slot per layer)
- Returns
Cache instance with uninitialized views
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() Mamba2Cache[source]#
Reset all cache views to their initial state.
- Returns
Reset Mamba2Cache
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) Mamba2Cache[source]#
Update the convolutional state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_conv_state – New state to be inserted
cache_position – Position in the cache to update
- Returns
Updated Mamba2Cache
- update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) Mamba2Cache[source]#
Update the SSM state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_ssm_state – New SSM state to replace the current one
- Returns
Updated Mamba2Cache
- views: List[Optional[Mamba2CacheView]]#
- class easydel.layers.caching.__init__.Mamba2CacheMetaData(partition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int)[source]#
Bases:
BaseCacheMetadata
Metadata for Mamba2 cache configuration.
- batch_size: int#
- conv_kernel_size: int#
- classmethod create(parition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int) Mamba2CacheMetaData[source]#
Create a Mamba2CacheMetaData instance with validation.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- head_dim: int#
- intermediate_size: int#
- n_groups: int#
- num_heads: int#
- partition_axis: PartitionAxis#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- state_size: int#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.__init__.Mamba2CacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], seqlen_offset: int, metadata: easydel.layers.caching.mamba2.mamba2_cache.Mamba2CacheMetaData, layer_index: Optional[int] = None)[source]#
Bases:
BaseCacheView
- concatenate_to_cache(*args, **kwargs)[source]#
Update cache with new states.
- Parameters
*args – Typically includes new tensors
**kwargs – Additional parameters for cache update
- Returns
anything
- Return type
Tuple containing
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: Mamba2CacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#
Initialize a new cache view instance.
- Parameters
metadata – Configuration metadata for the cache
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns
Initialized cache view instance
- layer_index: Optional[int] = None#
- metadata: Mamba2CacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() Mamba2CacheView[source]#
Reset both conv and ssm states to zeros.
- seqlen_offset: int#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.__init__.Mamba2Metadata[source]#
Bases:
BaseRunTimeMetadata
- class easydel.layers.caching.__init__.MambaCache(views: List[Optional[easydel.layers.caching.mamba.mamba_cache.MambaCacheView]])[source]#
Bases:
BaseCache
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(metadata: MambaCacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#
Initialize a complete cache with views for all layers.
- Parameters
metadata – Configuration metadata
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns
Fully initialized cache instance
- classmethod init_empty(num_hidden_layers)[source]#
Initialize an empty cache container.
- Parameters
num_hidden_layers – Number of layers (one view slot per layer)
- Returns
Cache instance with uninitialized views
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() MambaCache[source]#
Reset all cache views to their initial state.
- Returns
Reset MambaCache
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCache[source]#
Update the convolutional state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_conv_state – New state to be inserted
cache_position – Position in the cache to update
- Returns
Updated MambaCache
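The reference does not say how the convolutional state is updated internally. One scheme found in other Mamba implementations is a rolling window: shift the state left by one step, then write the newest column at a clamped position. The sketch below shows that scheme with NumPy. The function name, the `(batch, channels, kernel)` layout and the clamping are all assumptions, not the EasyDeL implementation.

```python
import numpy as np


def roll_conv_state(conv_state, new_column, cache_position):
    """Shift the window left by one step and write the newest column."""
    kernel_size = conv_state.shape[-1]
    # Clamp so positions past the window always write to the last slot.
    slot = min(max(int(cache_position), 0), kernel_size - 1)
    rolled = np.roll(conv_state, shift=-1, axis=-1)  # returns a new array
    rolled[..., slot] = new_column
    return rolled


state = np.zeros((1, 2, 4))  # (batch, channels, kernel) layout is assumed
state = roll_conv_state(state, np.array([[1.0, 1.0]]), cache_position=3)
state = roll_conv_state(state, np.array([[2.0, 2.0]]), cache_position=4)
```

Because the function returns a fresh array, it matches the functional style suggested by the "Updated MambaCache" return values above.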
- update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) MambaCache[source]#
Update the SSM state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_ssm_state – New SSM state to replace the current one
- Returns
Updated MambaCache
- views: List[Optional[MambaCacheView]]#
- class easydel.layers.caching.__init__.MambaCacheMetaData(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)[source]#
Bases:
BaseCacheMetadata
Metadata for Mamba cache configuration.
- batch_size: int#
- conv_kernel_size: int#
- classmethod create(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) MambaCacheMetaData[source]#
Create a MambaCacheMetaData instance with validation.
- Parameters
partition_axis – Partition axis.
batch_size – Size of the batch.
intermediate_size – Model’s intermediate size.
ssm_state_size – Model’s state size.
conv_kernel_size – Model’s convolution kernel size.
- Returns
MambaCacheMetaData instance
- Raises
ValueError – If required parameters are invalid
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- intermediate_size: int#
- partition_axis: PartitionAxis#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- ssm_state_size: int#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.__init__.MambaCacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], metadata: easydel.layers.caching.mamba.mamba_cache.MambaCacheMetaData, layer_index: Optional[int] = None)[source]#
Bases:
BaseCacheView
- concatenate_to_cache(*args, **kwargs)[source]#
Update cache with new states.
- Parameters
*args – Typically includes new tensors
**kwargs – Additional parameters for cache update
- Returns
anything
- Return type
Tuple containing
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#
Initialize a new cache view instance.
- Parameters
metadata – Configuration metadata for the cache
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns
Initialized cache view instance
- layer_index: Optional[int] = None#
- metadata: MambaCacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() MambaCacheView[source]#
Reset both conv and ssm states to zeros.
- Returns
Reset MambaCacheView
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCacheView[source]#
Update the convolutional state of the cache.
- Parameters
new_conv_state – New state to be inserted
cache_position – Position in the cache to update
- Returns
Updated MambaCacheView
- class easydel.layers.caching.__init__.MambaMetadata[source]#
Bases:
BaseRunTimeMetadata
- class easydel.layers.caching.__init__.PagedAttentionCache(views: List[PagedAttentionCacheView])[source]#
Bases:
BaseCache
Represents the complete Paged Attention KV cache for all layers of a model.
It holds a list of PagedAttentionCacheView objects, one for each layer. It inherits from BaseCache.
- views#
A list containing the cache view for each layer in the model.
- Type
tp.List[PagedAttentionCacheView]
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#
Initializes the entire PagedAttentionCache for all layers.
Creates a list of PagedAttentionCacheView instances, one for each layer specified in the metadata, by calling PagedAttentionCacheView.init for each layer.
- Parameters
mesh (Mesh) – The JAX device mesh.
dtype (jnp.dtype) – The data type for the cache pages.
metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.
partition_manager (es.PartitionManager) – Manages tensor sharding.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.
- Returns
An initialized cache object containing views for all layers.
- Return type
PagedAttentionCache
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- views: List[PagedAttentionCacheView]#
- class easydel.layers.caching.__init__.PagedAttentionCacheMetaData(batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float)[source]#
Bases:
BaseCacheMetadata
Metadata holding configuration parameters for the Paged Attention KV cache.
This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.
- batch_size#
The maximum number of sequences processed concurrently during decoding.
- Type
int
- num_hidden_layers#
The total number of transformer layers in the model.
- Type
int
- num_pages_per_layer#
The total number of physical memory pages allocated for the KV cache per layer across all sequences. This is calculated based on available memory and hbm_utilization.
- Type
int
- num_pages_per_sequence#
The maximum number of pages a single sequence can occupy, determined by max_sequences and page_size.
- Type
int
- max_sequences#
The maximum sequence length supported by the cache allocation.
- Type
int
- page_size#
The number of tokens stored per page in the KV cache.
- Type
int
- num_kv_heads#
The number of key/value heads in the attention mechanism.
- Type
int
- kv_head_dim_size#
The dimension size of each key/value head.
- Type
int
- hbm_utilization#
The target fraction of available High Bandwidth Memory (HBM) to be utilized for the KV cache pages. Should be between 0.0 and 1.0.
- Type
float
- batch_size: int#
- classmethod create(mesh: Mesh, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: dtype = jnp.bfloat16) PagedAttentionCacheMetaData[source]#
Factory method to create and initialize a PagedAttentionCacheMetaData instance.
Calculates derived values like num_pages_per_layer and num_pages_per_sequence based on the provided parameters and estimated available memory.
- Parameters
mesh (Mesh) – The JAX device mesh.
batch_size (int) – Maximum concurrent sequences for decode.
num_hidden_layers (int) – Number of transformer layers.
max_sequences (int) – Maximum supported sequence length.
page_size (int) – Number of tokens per cache page.
num_kv_heads (int) – Number of KV heads.
kv_head_dim_size (int) – Dimension of each KV head.
hbm_utilization (float) – Target HBM utilization fraction (0.0 to 1.0).
dtype (jnp.dtype) – Data type used for cache size calculation.
- Returns
An initialized metadata object.
- Return type
PagedAttentionCacheMetaData
- Raises
ValueError – If input parameters are invalid (e.g., non-positive dimensions, invalid utilization factor).
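To make the derived values concrete, here is a back-of-the-envelope sizing sketch. The formulas are assumptions for illustration: one key page and one value page per physical page, pages per layer as the memory budget divided across layers, and pages per sequence as a ceiling division. The actual calculation in `create` may differ. `estimate_page_counts` and `free_hbm_bytes` are hypothetical names.

```python
def estimate_page_counts(
    free_hbm_bytes: int,
    hbm_utilization: float,
    num_hidden_layers: int,
    max_sequences: int,
    page_size: int,
    num_kv_heads: int,
    kv_head_dim_size: int,
    bytes_per_element: int = 2,  # bfloat16
):
    """Return (num_pages_per_layer, num_pages_per_sequence) under the stated assumptions."""
    if not 0.0 < hbm_utilization <= 1.0:
        raise ValueError("hbm_utilization must be in (0.0, 1.0]")
    sizes = (num_hidden_layers, max_sequences, page_size, num_kv_heads, kv_head_dim_size)
    if min(sizes) <= 0:
        raise ValueError("all dimensions must be positive")

    # Ceiling division: a partially filled page still occupies a whole page.
    pages_per_sequence = -(-max_sequences // page_size)
    # One page holds keys and values (factor 2) for page_size tokens.
    bytes_per_page = 2 * page_size * num_kv_heads * kv_head_dim_size * bytes_per_element
    budget = int(free_hbm_bytes * hbm_utilization)
    pages_per_layer = budget // (num_hidden_layers * bytes_per_page)
    return pages_per_layer, pages_per_sequence


counts = estimate_page_counts(
    free_hbm_bytes=16 * 2**30,  # 16 GiB, an arbitrary example figure
    hbm_utilization=0.5,
    num_hidden_layers=32,
    max_sequences=4096,
    page_size=128,
    num_kv_heads=8,
    kv_head_dim_size=128,
)
```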
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- hbm_utilization: float#
- kv_head_dim_size: int#
- max_sequences: int#
- num_hidden_layers: int#
- num_kv_heads: int#
- num_pages_per_layer: int#
- num_pages_per_sequence: int#
- page_size: int#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.__init__.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray])[source]#
Bases:
BaseCacheView
Represents the view of the Paged Attention KV cache for a single transformer layer.
It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.
- metadata#
The static configuration metadata for the entire paged cache.
- layer_index#
The index of the transformer layer this view corresponds to.
- Type
int
- key_pages#
The tensor holding all key pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.
- Type
tp.Union[cx.Array, ImplicitArray]
- value_pages#
The tensor holding all value pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.
- Type
tp.Union[cx.Array, ImplicitArray]
- concatenate_to_cache(*args, **kwargs)[source]#
Concatenation is not applicable for Paged Attention. Raises NotImplementedError.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#
Initializes the PagedAttentionCacheView for a specific layer.
Allocates the key_pages and value_pages tensors with the appropriate shape, dtype, and sharding based on the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.
- Parameters
mesh (Mesh) – The JAX device mesh.
dtype (jnp.dtype) – The data type for the cache pages (e.g., jnp.bfloat16).
metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.
layer_index (int) – The index of the layer this view is for.
partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.
- Returns
An initialized cache view for the specified layer.
- Return type
PagedAttentionCacheView
- layer_index: int#
- metadata: PagedAttentionCacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- write_decodes_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#
Writes the key/value pairs from a decode step into the appropriate cache pages.
Uses the decodes_position and decodes_page_table from the runtime metadata to calculate the exact page index and offset within that page where the new key/value pair for each sequence in the batch should be written. It reshapes the cache pages and input keys/values for efficient scattered updates using .at[…].set(…).
- Parameters
key (cx.Array) – Key tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).
value (cx.Array) – Value tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).
metadata (PagedAttentionMetadata) – Runtime metadata containing decodes_position and decodes_page_table.
- Returns
Returns self after updating the pages.
- Return type
PagedAttentionCacheView
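The page-index and offset arithmetic for a decode write can be sketched as follows. NumPy in-place assignment on a copy stands in for the JAX `.at[…].set(…)` update mentioned above; `write_decode_tokens` is a hypothetical helper, and the flattening of pages into token slots is one plausible reading of the description rather than the library's exact code.

```python
import numpy as np


def write_decode_tokens(pages, new_kv, positions, page_table):
    """Scatter one token per sequence into its physical page slot.

    pages:      (num_kv_heads, num_pages, page_size, head_dim)
    new_kv:     (batch_size, num_kv_heads, head_dim)
    positions:  (batch_size,) current position of each sequence
    page_table: (batch_size, num_pages_per_sequence) logical -> physical
    """
    num_heads, num_pages, page_size, head_dim = pages.shape
    batch = np.arange(positions.shape[0])
    logical_page = positions // page_size
    offset = positions % page_size
    physical_page = page_table[batch, logical_page]
    # Flatten (page, offset) into a single token-slot index.
    flat_slot = physical_page * page_size + offset
    flat = pages.reshape(num_heads, num_pages * page_size, head_dim).copy()
    flat[:, flat_slot, :] = new_kv.transpose(1, 0, 2)  # (heads, batch, dim)
    return flat.reshape(pages.shape)


key_pages = np.zeros((2, 4, 4, 3))
new_keys = np.stack([np.full((2, 3), 1.0), np.full((2, 3), 2.0)])
positions = np.array([5, 2])
page_table = np.array([[1, 2], [3, 0]])
updated = write_decode_tokens(key_pages, new_keys, positions, page_table)
```

Here sequence 0 at position 5 lands in logical page 1, which maps to physical page 2 at offset 1; sequence 1 at position 2 lands in physical page 3 at offset 2.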
- write_prefill_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#
Writes the key/value pairs from a prefill step into the appropriate cache pages.
Uses the prefill_page_table from the runtime metadata to determine which physical pages (key_pages, value_pages) correspond to the logical pages of the prefill sequence. It transposes and reshapes the input key/value tensors and uses jax.lax.dynamic_update_slice_in_dim within a while_loop to update the relevant pages.
- Parameters
key (cx.Array) – Key tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).
value (cx.Array) – Value tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).
metadata (PagedAttentionMetadata) – Runtime metadata containing the prefill_length and prefill_page_table.
- Returns
Returns self after updating the pages.
- Return type
PagedAttentionCacheView
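A prefill write copies whole pages rather than single tokens. The sketch below uses a plain Python loop and NumPy slicing in place of `jax.lax.dynamic_update_slice_in_dim` inside a `while_loop`. `write_prefill_tokens` is hypothetical, and it assumes the padded prefill length is a multiple of the page size.

```python
import numpy as np


def write_prefill_tokens(pages, new_kv, prefill_length, page_table):
    """Copy a prefill sequence into its physical pages, one page at a time.

    pages:      (num_kv_heads, num_pages, page_size, head_dim)
    new_kv:     (padded_prefill_len, num_kv_heads, head_dim)
    page_table: (num_pages_for_prefill,) logical -> physical
    """
    num_heads, _, page_size, head_dim = pages.shape
    padded_len = new_kv.shape[0]
    assert padded_len % page_size == 0, "sketch assumes page-aligned padding"
    # (len, heads, dim) -> (heads, len, dim) -> (heads, logical_page, page_size, dim)
    chunks = new_kv.transpose(1, 0, 2).reshape(
        num_heads, padded_len // page_size, page_size, head_dim
    )
    updated = pages.copy()
    pages_needed = -(-int(prefill_length) // page_size)  # skip pure padding pages
    for logical_page in range(pages_needed):
        updated[:, page_table[logical_page]] = chunks[:, logical_page]
    return updated


value_pages = np.zeros((2, 5, 4, 3))
prefill_values = np.arange(8 * 2 * 3, dtype=float).reshape(8, 2, 3)
filled = write_prefill_tokens(value_pages, prefill_values, 6, np.array([3, 1]))
```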
- class easydel.layers.caching.__init__.PagedAttentionMetadata(prefill_length: Array, prefill_position: Array, prefill_page_table: Array, decodes_position: Array, decodes_page_table: Array)[source]#
Bases:
object
Runtime metadata required for performing a Paged Attention computation step.
This object holds the necessary information for a single forward pass of the paged attention mechanism, distinguishing between prefill and decode steps and providing the mappings (page tables) from logical sequence positions to physical cache pages.
- prefill_length#
Scalar JAX array containing the actual length of the prompt being processed in a prefill step. Shape (). Set to 0 if not in prefill.
- Type
Array
- prefill_position#
JAX array of positions for the prefill tokens. Shape (padded_prompt_length,). Empty shape () if not in prefill.
- Type
Array
- prefill_page_table#
JAX array mapping logical page indices of the prefill sequence to physical page indices in the KV cache. Shape (num_pages_for_prefill,). Empty shape () if not in prefill.
- Type
Array
- decodes_position#
JAX array containing the current sequence position (or length - 1) for each sequence in the decode batch. Shape (batch_size,). Empty shape () if not in decode.
- Type
Array
- decodes_page_table#
JAX array mapping logical page indices to physical page indices for each sequence in the decode batch. Shape (batch_size, num_pages_per_sequence). Empty shape () if not in decode.
- Type
Array
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- is_decode_mode() bool[source]#
Checks if the current metadata represents a decode step (the counterpart of is_prefill_mode).
- Returns
True if the metadata describes a decode step, False otherwise.
- Return type
bool
- is_prefill_mode() bool[source]#
Checks if the current metadata represents a prefill-only step.
- Returns
True if only prefill information is present (decode arrays have empty shape), False otherwise.
- Return type
bool
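The "empty shape" convention can be illustrated with zero-dimensional arrays as placeholders. `ToyPagedMetadata` and `is_prefill_only` are hypothetical; the check shown here (both decode arrays have shape `()`) is one reading of the description above, not the library's exact condition.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyPagedMetadata:
    """Hypothetical runtime record with the same field roles as above."""

    prefill_length: np.ndarray
    prefill_page_table: np.ndarray
    decodes_position: np.ndarray
    decodes_page_table: np.ndarray

    def is_prefill_only(self) -> bool:
        # A zero-dimensional array (shape ()) marks the decode side as unused.
        return self.decodes_position.shape == () and self.decodes_page_table.shape == ()


placeholder = np.array(0)  # shape ()
prefill_step = ToyPagedMetadata(
    prefill_length=np.array(6),
    prefill_page_table=np.array([3, 1]),
    decodes_position=placeholder,
    decodes_page_table=placeholder,
)
mixed_step = ToyPagedMetadata(
    prefill_length=np.array(6),
    prefill_page_table=np.array([3, 1]),
    decodes_position=np.array([9, 4]),
    decodes_page_table=np.array([[2, 5], [7, 0]]),
)
```

Encoding the mode in array shapes rather than in a Python flag keeps the record a pure collection of arrays.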
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.__init__.TransformerCache(views: 'tp.List[tp.Optional[TransformerCacheView]]')[source]#
Bases:
BaseCache
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(mesh: Mesh, metadata: TransformerCacheMetaData, partition_manager: PartitionManager, dtype: Optional[dtype] = None, starts: Optional[Array] = None, quantizer: Optional[object] = None)[source]#
Initialize a complete cache with views for all layers.
- Parameters
metadata – Configuration metadata
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns
Fully initialized cache instance
- classmethod init_empty(num_hidden_layers)[source]#
Initialize an empty cache container.
- Parameters
num_hidden_layers – Number of layers (one view slot per layer)
- Returns
Cache instance with uninitialized views
- insert(other: TransformerCache, slot: int, quantizer: object, partition_manager: PartitionManager)[source]#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- views: List[Optional[TransformerCacheView]]#
- class easydel.layers.caching.__init__.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int], update_causal_mask: bool, create_attention_bias: bool)[source]#
Bases:
BaseCacheMetadata
Metadata for transformer cache configuration.
- batch_size: int#
- classmethod create(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None, update_causal_mask: bool = True, create_attention_bias: bool = True) TransformerCacheMetaData[source]#
Create a TransformerCacheMetaData instance with validation.
- Parameters
batch_size – Size of the batch.
sequence_length – Length of the sequence.
num_hidden_layers – Number of hidden layers.
num_heads – Number of attention heads.
head_dim – Dimension of each head.
key_heads – Number of key heads.
value_heads – Number of value heads.
key_dim – Dimension of keys.
value_dim – Dimension of values.
update_causal_mask – Whether to update causal mask.
create_attention_bias – Whether to create attention bias.
- Returns
TransformerCacheMetaData instance
- Raises
ValueError – If required parameters are missing or invalid.
- create_attention_bias: bool#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- head_dim: Optional[int]#
- key_dim: Optional[int]#
- key_heads: Optional[int]#
- num_heads: Optional[int]#
- pad_token_id: int#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- sequence_length: int#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_causal_mask: bool#
- value_dim: Optional[int]#
- value_heads: Optional[int]#
- class easydel.layers.caching.__init__.TransformerCacheView(key: 'tp.Union[cx.Array, ImplicitArray]', value: 'tp.Union[cx.Array, ImplicitArray]', index: 'tp.Union[cx.Array, ImplicitArray]', starts: 'tp.Union[cx.Array, ImplicitArray]', metadata: 'TransformerCacheMetaData', layer_index: 'tp.Optional[int]' = None)[source]#
Bases:
BaseCacheView
- concatenate_to_cache(query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], quantizer: object, cache_metadata: Optional[TransformerMetadata], attention_mask: Union[Array, ndarray, bool, number], partition_manager: PartitionManager, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]][source]#
Updates the KV cache functionally and returns the updated tensors along with the appropriate attention mask.
- Parameters
query – Current query states.
key – Current key states to add to the cache.
value – Current value states to add to the cache.
cache_metadata – Optional metadata. If provided and contains slot/length info, enables pooled caching.
attention_mask – Base attention mask.
quantizer – Quantizer for the cache.
causal_mask – Optional causal mask.
token_type_ids – Optional token type IDs for segment masking.
- Returns
Updated key cache tensor (functional update).
Updated value cache tensor (functional update).
Final attention mask to be used (either original or calculated).
- Return type
Tuple[Array, Array, Array]
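A functional cache update returns new arrays instead of mutating the stored ones. The sketch below shows the idea with NumPy: write the new keys and values at the current index, then build a mask that hides the still-empty tail. `append_to_cache` is hypothetical, and the `(batch, sequence, heads, head_dim)` layout and the mask construction are assumptions rather than the method's actual logic.

```python
import numpy as np


def append_to_cache(cached_key, cached_value, index, key, value):
    """Return new cache arrays with `key`/`value` written at `index`."""
    new_tokens = key.shape[1]
    end = index + new_tokens
    updated_key = cached_key.copy()      # leave the inputs untouched
    updated_value = cached_value.copy()
    updated_key[:, index:end] = key
    updated_value[:, index:end] = value
    # Positions at or beyond `end` are still empty and must be masked out.
    attention_mask = np.arange(cached_key.shape[1]) < end
    return updated_key, updated_value, attention_mask


cache_k = np.zeros((1, 8, 2, 4))  # assumed layout: (batch, sequence, heads, head_dim)
cache_v = np.zeros((1, 8, 2, 4))
step_k = np.ones((1, 3, 2, 4))
step_v = np.full((1, 3, 2, 4), 2.0)
new_k, new_v, mask = append_to_cache(cache_k, cache_v, 2, step_k, step_v)
```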
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(mesh: Mesh, dtype: dtype, metadata: TransformerCacheMetaData, quantizer: object, partition_manager: PartitionManager, starts: Optional[Array] = None, layer_index: Optional[int] = None)[source]#
Initialize a new cache view instance.
- Parameters
metadata – Configuration metadata for the cache
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns
Initialized cache view instance
- property is_empty#
- layer_index: Optional[int] = None#
- metadata: TransformerCacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.__init__.TransformerMetadata(postpadded: bool = False, index: int | None = None)[source]#
Bases:
BaseRunTimeMetadata
Holds optional metadata for the attention runtime.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- postpadded: bool = False#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.