easydel.layers.caching.paged_attention.paged_attention_cache

Contents

easydel.layers.caching.paged_attention.paged_attention_cache

class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCache(views: 'tp.List[PagedAttentionCacheView]')[source]#

Bases: BaseCache

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

classmethod init_cache(mesh: Mesh, metadata: PagedAttentionCacheMetaData, quantizer: Optional[object] = None, kv_pages_sharding: Optional[PartitionSpec] = None)[source]#

Initialize a complete cache with views for all layers.

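Parameters
  • mesh – Device mesh for the cache arrays

  • metadata – Configuration metadata

  • quantizer – Optional quantizer for the cache (default: None)

  • kv_pages_sharding – Optional partition spec for the key/value pages (default: None)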

Returns

Fully initialized cache instance

init_empty(*args, **kwargs)[source]#

Initialize an empty cache container.

Parameters
  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns

Cache instance with uninitialized views

replace(**kwargs)#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

views: List[PagedAttentionCacheView]#
class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheManager(metadata: PagedAttentionCacheMetaData)[source]#

Bases: object

Logical KV cache manager that allocates and frees KV cache pages in HBM.

alloc_hbm_pages(n: int) → list[int][source]#

Allocate n HBM pages and return them as a list of integer page indices.

alloc_prefill_hbm_pages(prompt_len) → list[int][source]#

Allocate the HBM pages needed to prefill a prompt of length prompt_len.

property current_page_index#

Returns the dummy page index (0).

free_hbm_pages(pages: list[int])[source]#

Frees the given HBM pages.

property page_size#

Returns the page size, i.e. the number of per-token KV cache entries a single page holds.

class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheMetaData(partition_axis: PartitionAxis, dtype: dtype, batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, hbm_bytes: float)[source]#

Bases: BaseCacheMetadata

Metadata for Paged Attention KV cache configuration.

batch_size: int#
classmethod create(mesh: Mesh, partition_axis: PartitionAxis, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: numpy.dtype = jax.numpy.bfloat16) → PagedAttentionCacheMetaData[source]#

Factory method to create validated metadata instance.

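Parameters
  • mesh – Device mesh for the cache

  • partition_axis – Partitioning axis configuration

  • batch_size – Batch size

  • num_hidden_layers – Number of hidden layers (the cache holds one view per layer)

  • max_sequences – Maximum number of sequences

  • page_size – Page size, in tokens

  • num_kv_heads – Number of key/value heads

  • kv_head_dim_size – Dimension of each key/value head

  • hbm_utilization – HBM utilization factor

  • dtype – Data type of the cache (default: jax.numpy.bfloat16)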

Returns

Instance of concrete metadata implementation

Raises

ValueError – If any validation checks fail

dtype: dtype#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

hbm_bytes: float#
hbm_utilization: float#
kv_head_dim_size: int#
max_sequences: int#
num_hidden_layers: int#
num_kv_heads: int#
num_pages_per_layer: int#
num_pages_per_sequence: int#
page_size: int#
partition_axis: PartitionAxis#
replace(**kwargs)#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray], kv_pages_sharding: NamedSharding)[source]#

Bases: BaseCacheView

Minimal view for a layer within the PagedAttentionCache.

concatenate_to_cache(*args, **kwargs)[source]#

Update cache with new states.

Parameters
  • *args – Typically includes new tensors

  • **kwargs – Additional parameters for cache update

Returns

A tuple containing the results of the cache update (its exact contents are implementation-specific).

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

classmethod init(mesh: Mesh, metadata: PagedAttentionCacheMetaData, layer_index: int, quantizer: Optional[object] = None, kv_pages_sharding: Optional[PartitionSpec] = None)[source]#

Initialize a new cache view instance.

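Parameters
  • mesh – Device mesh for the cache arrays

  • metadata – Configuration metadata for the cache

  • layer_index – Index of the layer this view belongs to

  • quantizer – Optional quantizer for the cache (default: None)

  • kv_pages_sharding – Optional partition spec for the key/value pages (default: None)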

Returns

Initialized cache view instance

key_pages: Union[Array, ndarray, bool, number, ImplicitArray]#
kv_pages_sharding: NamedSharding#
layer_index: int#
metadata: PagedAttentionCacheMetaData#
replace(**kwargs)#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

value_pages: Union[Array, ndarray, bool, number, ImplicitArray]#
write_generate_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#
write_prefill_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#
class easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionMetadata(prefill_length: 'jax.Array', prefill_pos: 'jax.Array', prefill_page_table: 'jax.Array', generate_pos: 'jax.Array', generate_page_table: 'jax.Array')[source]#

Bases: object

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

generate_page_table: Array#
generate_pos: Array#
prefill_length: Array#
prefill_page_table: Array#
prefill_pos: Array#
replace(**kwargs)#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.