easydel.layers.caching.paged_attention.paged_attention_cache#
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCache(views: List[PagedAttentionCacheView])[source]#
Bases: BaseCache
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- classmethod init_cache(mesh: Mesh, metadata: PagedAttentionCacheMetaData, quantizer: Optional[object] = None, kv_pages_sharding: Optional[PartitionSpec] = None)[source]#
Initialize a complete cache with views for all layers.
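- Parameters
mesh (Mesh) – JAX device mesh
metadata (PagedAttentionCacheMetaData) – Configuration metadata
quantizer (Optional[object]) – Optional quantizer; defaults to None
kv_pages_sharding (Optional[PartitionSpec]) – Optional partition spec for the KV pages; defaults to None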
- Returns
Fully initialized cache instance
- init_empty(*args, **kwargs)[source]#
Initialize an empty cache container.
- Parameters
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns
Cache instance with uninitialized views
- replace(**kwargs)#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- views: List[PagedAttentionCacheView]#
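The relationship between the cache and its views can be pictured with a small stand-in. The sketch below does not import EasyDeL; `ToyView` and `ToyCache` are hypothetical names, and the only behavior taken from the reference above is that `init_cache` produces views for all layers and stores them in `views`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyView:
    """Hypothetical stand-in for a per-layer cache view."""

    layer_index: int


@dataclass
class ToyCache:
    """Hypothetical stand-in for a cache that holds one view per layer."""

    views: List[ToyView]

    @classmethod
    def init_cache(cls, num_hidden_layers: int) -> "ToyCache":
        # Build one view for every layer, in layer order.
        return cls(views=[ToyView(layer_index=i) for i in range(num_hidden_layers)])


cache = ToyCache.init_cache(num_hidden_layers=4)
layer_ids = [view.layer_index for view in cache.views]
```

In this toy version a layer would pick up its own view with `cache.views[layer_index]`; whether the real class is indexed the same way is an assumption.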
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheManager(metadata: PagedAttentionCacheMetaData)[source]#
Bases: object
Logical KV Cache Manager
- property current_page_index#
Returns the dummy page index (0).
- property page_size#
Returns the page size in the number of per-token kv cache items.
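Since `page_size` counts per-token KV cache items, the bookkeeping a logical page manager performs can be sketched with plain integer arithmetic. The helpers `pages_needed` and `locate` are hypothetical names used for illustration only, not part of the EasyDeL API.

```python
from typing import Tuple


def pages_needed(num_tokens: int, page_size: int) -> int:
    # Ceiling division: a partially filled page still occupies a whole page.
    return -(-num_tokens // page_size)


def locate(position: int, page_size: int) -> Tuple[int, int]:
    # Split a token position into (logical page number, offset inside that page).
    return divmod(position, page_size)


needed = pages_needed(300, page_size=128)
page, offset = locate(300, page_size=128)
```

With a page size of 128, a 300-token sequence spans three pages, and the token at position 300 would land at offset 44 of the third page (logical page 2).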
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheMetaData(partition_axis: PartitionAxis, dtype: dtype, batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, hbm_bytes: float)[source]#
Bases: BaseCacheMetadata
Metadata for Paged Attention KV cache configuration.
- batch_size: int#
- classmethod create(mesh: Mesh, partition_axis: PartitionAxis, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: dtype = jax.numpy.bfloat16) -> PagedAttentionCacheMetaData[source]#
Factory method to create validated metadata instance.
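- Parameters
mesh (Mesh) – JAX device mesh
partition_axis (PartitionAxis) – Partition axis configuration
batch_size (int) – Batch size
num_hidden_layers (int) – Number of hidden layers
max_sequences (int) – Maximum number of sequences
page_size (int) – Page size
num_kv_heads (int) – Number of KV heads
kv_head_dim_size (int) – Dimension of each KV head
hbm_utilization (float) – HBM utilization setting
dtype (dtype) – Data type; defaults to jax.numpy.bfloat16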
- Returns
Instance of concrete metadata implementation
- Raises
ValueError – If any validation checks fail
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- hbm_bytes: float#
- hbm_utilization: float#
- kv_head_dim_size: int#
- max_sequences: int#
- num_kv_heads: int#
- num_pages_per_layer: int#
- num_pages_per_sequence: int#
- page_size: int#
- partition_axis: PartitionAxis#
- replace(**kwargs)#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
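The serialization helpers listed above (`to_dict`, `from_dict`, `to_json`, `from_json`, `replace`) follow a familiar round-trip pattern. The sketch below reproduces that pattern on a hypothetical frozen dataclass, `ToyMetadata`, with two of the integer fields from the reference. It is not EasyDeL's implementation, and how the real class encodes non-JSON fields such as `dtype` or `partition_axis` is not covered here.

```python
import dataclasses
import json


@dataclasses.dataclass(frozen=True)
class ToyMetadata:
    """Hypothetical stand-in; not EasyDeL's PagedAttentionCacheMetaData."""

    page_size: int
    num_kv_heads: int

    def to_dict(self) -> dict:
        # Plain dict of field names to values, suitable for json.dumps.
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "ToyMetadata":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ToyMetadata":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "ToyMetadata":
        # Return a copy with the given fields changed; the original is untouched.
        return dataclasses.replace(self, **kwargs)


meta = ToyMetadata(page_size=128, num_kv_heads=8)
restored = ToyMetadata.from_json(meta.to_json())
bigger = meta.replace(page_size=256)
```

Because the stand-in is frozen, `replace` is the only way to obtain a modified copy, which mirrors how immutable configuration objects are commonly handled.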
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray], kv_pages_sharding: NamedSharding)[source]#
Bases: BaseCacheView
Minimal view for a layer within the PagedAttentionCache.
- concatenate_to_cache(*args, **kwargs)[source]#
Update cache with new states.
- Parameters
*args – Typically includes new tensors
**kwargs – Additional parameters for cache update
- Returns
anything
- Return type
Tuple containing
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- classmethod init(mesh: Mesh, metadata: PagedAttentionCacheMetaData, layer_index: int, quantizer: Optional[object] = None, kv_pages_sharding: Optional[PartitionSpec] = None)[source]#
Initialize a new cache view instance.
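- Parameters
mesh (Mesh) – JAX device mesh
metadata (PagedAttentionCacheMetaData) – Configuration metadata for the cache
layer_index (int) – Index of the layer this view belongs to
quantizer (Optional[object]) – Optional quantizer; defaults to None
kv_pages_sharding (Optional[PartitionSpec]) – Optional partition spec for the KV pages; defaults to None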
- Returns
Initialized cache view instance
- kv_pages_sharding: NamedSharding#
- layer_index: int#
- metadata: PagedAttentionCacheMetaData#
- replace(**kwargs)#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionMetadata(prefill_length: jax.Array, prefill_pos: jax.Array, prefill_page_table: jax.Array, generate_pos: jax.Array, generate_page_table: jax.Array)[source]#
Bases: object
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- replace(**kwargs)#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
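The field names `generate_pos` and `generate_page_table` suggest a per-sequence lookup from token position to physical page. The following is a minimal sketch of that idea in plain Python, under several assumptions that the reference does not confirm: the page table holds one row of physical page ids per sequence, unused entries are padded with 0, and a position is split into page number and offset by `page_size`. The helper `physical_slot` is a hypothetical name.

```python
from typing import List, Tuple


def physical_slot(page_table: List[int], position: int, page_size: int) -> Tuple[int, int]:
    # Logical page number and offset inside the page for this token position.
    logical_page, offset = divmod(position, page_size)
    # The page table maps the logical page to a physical page id.
    return page_table[logical_page], offset


# Assumed layout: one row per sequence, padded with 0 where no page is assigned.
generate_page_table = [[7, 2, 9], [4, 0, 0]]
generate_pos = [130, 5]

slots = [
    physical_slot(table, pos, page_size=64)
    for table, pos in zip(generate_page_table, generate_pos)
]
```

Here the first sequence writes its next token to physical page 9 at offset 2, and the second to physical page 4 at offset 5. In the actual class these fields are `jax.Array` values, so the same lookup would be expressed as an array gather rather than a Python loop.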