easydel.layers.caching._specs
Specification classes for different caching strategies in EasyDeL.
This module defines specification dataclasses that describe the memory layout, size requirements, and behavior of different cache types. These specifications are used to:
- Calculate memory requirements before allocation
- Configure cache initialization parameters
- Optimize memory layout for specific attention patterns
- Enable hybrid caching strategies
The specifications follow a hierarchy:
- KVCacheSpec: Base specification for all KV cache types
  - AttentionSpec: Base for attention-based caches
    - FullAttentionSpec: Standard full attention caching
    - SlidingWindowSpec: Sliding window attention caching
    - ChunkedLocalAttentionSpec: Chunked local attention
  - MambaSpec: State-space model caching
- Key Concepts:
Page Size: Number of tokens per memory page
Type ID: Unique identifier for cache compatibility
Memory Budget: Maximum memory usage calculations
Hybrid Allocation: Mixing different cache types
Example
>>> import jax.numpy as jnp
>>> from easydel.layers.caching._specs import FullAttentionSpec
>>> spec = FullAttentionSpec(
... page_size=128,
... num_kv_heads=8,
... head_size=64,
... dtype=jnp.bfloat16,
... use_mla=False
... )
>>> memory_bytes = spec.max_memory_usage_bytes(max_model_len=2048)
- class easydel.layers.caching._specs.AttentionSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool)
Bases: KVCacheSpec
Base specification for attention-based cache formats.
Extends KVCacheSpec with attention-specific parameters needed for transformer-based models. This includes head configuration, data types, and optimization flags.
- num_kv_heads
Number of key-value attention heads. This can be smaller than the number of query heads when the model uses multi-query or grouped-query attention.
- Type
int
- head_size
Dimension of each attention head.
- Type
int
- dtype
Data type for cache tensors. Common choices: bfloat16, float16, float32.
- Type
dtype
- use_mla
Whether the layer uses Multi-head Latent Attention (MLA). MLA caches a single compressed latent representation per token instead of separate keys and values, which reduces cache memory.
- Type
bool
- head_size: int
- num_kv_heads: int
- property page_size_bytes: int
Calculate page size for attention cache in bytes.
Computes the memory needed for one page of key-value pairs:
- Without MLA: stores both keys and values (coef=2)
- With MLA: stores a single combined representation (coef=1)
- Formula:
bytes = coef * page_size * num_kv_heads * head_size * dtype_bytes
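The formula can be checked with a small standalone sketch. It does not import EasyDeL: the helper name attention_page_size_bytes is hypothetical, and the element width is passed as a plain integer dtype_bytes (2 for bfloat16) instead of a JAX dtype.

```python
def attention_page_size_bytes(
    page_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype_bytes: int,
    use_mla: bool,
) -> int:
    """Hypothetical re-statement of the AttentionSpec.page_size_bytes formula."""
    # Without MLA both keys and values are cached (coef=2);
    # with MLA a single combined representation is cached (coef=1).
    coef = 1 if use_mla else 2
    return coef * page_size * num_kv_heads * head_size * dtype_bytes


# Values from the module-level example: 128-token pages, 8 KV heads,
# head_size 64, bfloat16 (2 bytes per element).
standard = attention_page_size_bytes(128, 8, 64, 2, use_mla=False)
latent = attention_page_size_bytes(128, 8, 64, 2, use_mla=True)
print(standard, latent)  # 262144 131072
```

With JAX installed, the byte width can be read from the dtype itself, for example jnp.dtype(jnp.bfloat16).itemsize, which is 2.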
- Returns
Size of one attention cache page in bytes.
- Return type
int
- use_mla: bool
- class easydel.layers.caching._specs.ChunkedLocalAttentionSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool, attention_chunk_size: int)
Bases: AttentionSpec
Specification for chunked local attention caching.
Optimizes memory usage for models that use local attention patterns where tokens only attend within fixed-size chunks. This significantly reduces memory requirements compared to full attention.
Memory allocation is based on chunk size rather than full sequence length, making it suitable for very long sequences.
- attention_chunk_size
Size of attention chunks. Tokens can only attend within their chunk boundaries.
- Type
int
- attention_chunk_size: int
- max_memory_usage_bytes(max_model_len: int, max_num_batched_tokens: int, **kwargs) → int
Calculate maximum memory for chunked attention cache.
Memory is bounded by chunk size plus current batch size, not the full sequence length.
- Parameters
max_model_len (int) – Maximum sequence length (upper bound).
max_num_batched_tokens (int) – Maximum tokens processed per batch.
**kwargs – Additional arguments (unused).
- Returns
Maximum memory in bytes, computed from a token count of min(attention_chunk_size + max_num_batched_tokens, max_model_len).
- Return type
int
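A minimal sketch of this bound follows. It assumes the token count is converted to whole pages with cdiv, the same way FullAttentionSpec does; the function name chunked_max_memory_bytes and all configuration values are hypothetical.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division using integers only.
    return -(a // -b)


def chunked_max_memory_bytes(
    max_model_len: int,
    max_num_batched_tokens: int,
    attention_chunk_size: int,
    page_size: int,
    page_size_bytes: int,
) -> int:
    # Tokens that may need to be resident at once: one chunk plus the
    # tokens of the current batch, never more than the model length.
    num_tokens = min(attention_chunk_size + max_num_batched_tokens, max_model_len)
    # Assumption: the token count is rounded up to whole pages.
    return cdiv(num_tokens, page_size) * page_size_bytes


# Hypothetical configuration: 8192-token chunks, 2048 batched tokens,
# 131072-token context, 128-token pages of 262144 bytes each.
chunked = chunked_max_memory_bytes(131072, 2048, 8192, 128, 262144)
full = cdiv(131072, 128) * 262144  # same layer with a full-attention cache
print(chunked, full)  # 20971520 268435456
```

In this configuration the chunked cache needs 80 pages instead of 1024, which is the saving the class description refers to.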
- property type_id: str
Unique identifier for this cache specification type; see KVCacheSpec.type_id for the general contract.
- class easydel.layers.caching._specs.FullAttentionSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool, sliding_window: int | None = None, attention_chunk_size: int | None = None)
Bases: AttentionSpec
Specification for full attention caching.
Represents standard transformer attention where each token can attend to all previous tokens. This is the most common and memory-intensive cache type.
When hybrid allocation is disabled, this spec can also represent sliding window or chunked attention layers by storing the window/chunk parameters while allocating full cache space. This simplifies memory management at the cost of over-allocation.
- sliding_window
Optional sliding window size. When set, attention computation uses sliding window but cache allocation remains full-sized. None for standard full attention.
- Type
int | None
- attention_chunk_size
Optional chunk size for chunked attention. Similar to sliding_window but for chunked patterns. None for standard full attention.
- Type
int | None
Note
Only one of sliding_window or attention_chunk_size should be set. Both being non-None is an error.
- max_memory_usage_bytes(max_model_len: int, **kwargs) → int
Calculate maximum memory for full attention cache.
Memory scales linearly with maximum sequence length since all tokens need to be cached.
- Parameters
max_model_len (int) – Maximum sequence length supported.
**kwargs – Additional arguments (unused).
- Returns
Maximum memory in bytes, computed as ceil(max_model_len / page_size) * page_size_bytes.
- Return type
int
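The formula can be illustrated with the values from the module-level example. This is a standalone sketch with a hypothetical function name; it reproduces the documented formula, not EasyDeL's code.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division using integers only.
    return -(a // -b)


def full_attention_max_memory_bytes(max_model_len: int, page_size: int, page_size_bytes: int) -> int:
    # Every token up to max_model_len is cached, so the page count grows
    # linearly with sequence length; a partially filled last page is still allocated.
    return cdiv(max_model_len, page_size) * page_size_bytes


# Module-level example spec: 128-token pages of 262144 bytes
# (8 KV heads x 64 dims x 2 bytes x 2 for keys and values).
print(full_attention_max_memory_bytes(2048, 128, 262144))  # 4194304 (16 pages)
print(full_attention_max_memory_bytes(2049, 128, 262144))  # 4456448 (17 pages)
```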
- classmethod merge(specs: list[Self]) → Self
Merge a list of FullAttentionSpec objects into a single FullAttentionSpec object.
- classmethod merge_window_sizes(window_sizes: set[int]) → int | None
Merge sliding window sizes from multiple layers.
Ensures all layers in a cache group use the same window size for consistent memory allocation.
- Parameters
window_sizes (set[int]) – Set of window sizes from different layers.
- Returns
The shared window size if all layers use the same one, or None if the set is empty (no layer uses a sliding window).
- Return type
int | None
- Raises
ValueError – If layers have different window sizes.
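The documented contract (empty set, single size, or conflicting sizes) fits in a few lines. This is a standalone sketch of that contract, not the EasyDeL classmethod itself; the error message wording is an assumption.

```python
from typing import Optional


def merge_window_sizes(window_sizes: set[int]) -> Optional[int]:
    """Standalone sketch of the documented merge_window_sizes behavior."""
    if not window_sizes:
        # No layer in the group uses a sliding window.
        return None
    if len(window_sizes) == 1:
        # All layers agree: return the single shared size.
        return next(iter(window_sizes))
    raise ValueError(
        f"layers in one cache group must share a window size, got {sorted(window_sizes)}"
    )


print(merge_window_sizes(set()))   # None
print(merge_window_sizes({4096}))  # 4096
try:
    merge_window_sizes({1024, 4096})
except ValueError as err:
    print(err)  # layers in one cache group must share a window size, got [1024, 4096]
```

Taking a set rather than a list means duplicate sizes from many layers collapse automatically, so only genuinely different sizes trigger the error.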
- property type_id: str
Unique identifier for this cache specification type; see KVCacheSpec.type_id for the general contract.
- class easydel.layers.caching._specs.KVCacheSpec(page_size: int)
Bases: object
Base specification for key-value cache formats.
This abstract base class defines the interface that all cache specifications must implement. It provides methods for calculating memory requirements and identifying cache types for compatibility.
The specification pattern allows:
- Pre-allocation memory budgeting
- Cache type compatibility checking
- Hybrid cache configuration
- Memory optimization strategies
- page_size
Number of tokens stored per cache page. Pages are the basic unit of cache allocation and help reduce memory fragmentation.
- Type
int
- Abstract Properties:
type_id: Unique identifier for this cache type
page_size_bytes: Size of one page in bytes
- Abstract Methods:
max_memory_usage_bytes: Calculate maximum memory needed
merge: Combine multiple specs of the same type (a default implementation is provided)
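The interface can be summarized as an abstract base class. The sketch below is a hypothetical stand-in (KVCacheSpecSketch and ToyFixedWidthSpec are illustrative names, not EasyDeL classes) that only shows which members a concrete spec has to provide.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class KVCacheSpecSketch(ABC):
    """Hypothetical stand-in mirroring the documented KVCacheSpec interface."""

    page_size: int  # tokens per cache page

    @property
    @abstractmethod
    def type_id(self) -> str:
        """Identifier used to decide which specs can share a memory pool."""

    @property
    @abstractmethod
    def page_size_bytes(self) -> int:
        """Bytes needed to store one page of page_size tokens."""

    @abstractmethod
    def max_memory_usage_bytes(self, *args, **kwargs) -> int:
        """Worst-case bytes for this cache configuration."""


@dataclass(frozen=True)
class ToyFixedWidthSpec(KVCacheSpecSketch):
    """Hypothetical concrete spec: every token costs bytes_per_token bytes."""

    bytes_per_token: int

    @property
    def type_id(self) -> str:
        # Follows the "{strategy}_{params}_{size}" pattern described for type_id.
        return f"toy_{self.page_size}_{self.page_size_bytes}"

    @property
    def page_size_bytes(self) -> int:
        return self.page_size * self.bytes_per_token

    def max_memory_usage_bytes(self, max_model_len: int, **kwargs) -> int:
        num_pages = -(max_model_len // -self.page_size)  # ceiling division
        return num_pages * self.page_size_bytes


spec = ToyFixedWidthSpec(page_size=16, bytes_per_token=4)
print(spec.type_id, spec.max_memory_usage_bytes(max_model_len=100))  # toy_16_64 448
```

Instantiating KVCacheSpecSketch directly raises TypeError because its abstract members are not implemented.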
- max_memory_usage_bytes(*args, **kwargs) → int
Calculate maximum memory required for this cache configuration.
Computes the worst-case memory usage for the cache based on the maximum sequence length and other parameters. This is used for memory budgeting and allocation planning.
- Parameters
*args – Implementation-specific arguments.
**kwargs – Implementation-specific keyword arguments. Common kwargs include:
- max_model_len: Maximum sequence length
- max_num_batched_tokens: Maximum tokens per batch
- max_num_reqs: Maximum concurrent requests
- Returns
Maximum memory usage in bytes.
- Return type
int
Note
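Different cache types calculate this differently:
- Full attention: O(max_model_len)
- Sliding window: O(sliding_window + max_num_batched_tokens)
- Chunked: O(attention_chunk_size + max_num_batched_tokens)
- Mamba: constant (one state page per layer, independent of sequence length)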
- classmethod merge(specs: list[Self]) → Self
Merge multiple cache specifications into a single specification.
Combines specifications from multiple layers that share the same cache type. This is used when multiple layers can share a cache pool for memory efficiency.
The merge process:
1. Validates that all specs have compatible type_ids
2. Combines configuration parameters
3. Returns a unified specification
- Parameters
specs (list[Self]) – List of specifications to merge. All must have the same type_id.
- Returns
A merged specification representing the combined requirements of all input specifications.
- Return type
Self
- Raises
AssertionError – If specs have incompatible type_ids.
Note
The default implementation returns a copy of the first spec. Subclasses may override to merge specific parameters.
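The default behavior described above (validate type_ids, then return a copy of the first spec) can be sketched with a frozen dataclass. ToySpec and merge_specs are hypothetical names, and type_id is stored as a plain field here only to keep the example short.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical spec; type_id is a plain field for brevity."""

    page_size: int
    type_id: str


def merge_specs(specs: list[ToySpec]) -> ToySpec:
    """Sketch of the documented default merge: validate, then copy the first spec."""
    assert specs, "merge needs at least one spec"
    first = specs[0]
    # Step 1: every spec must report the same type_id.
    assert all(s.type_id == first.type_id for s in specs), "incompatible type_ids"
    # Steps 2-3: the default keeps the first spec's parameters and returns a copy.
    return replace(first)


a = ToySpec(128, "full_attention_128_262144")
b = ToySpec(128, "full_attention_128_262144")
merged = merge_specs([a, b])
print(merged == a, merged is a)  # True False
```

Returning a copy rather than the first object itself keeps the merged spec independent of any one layer's spec.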
- page_size: int
- property page_size_bytes: int
Calculate the memory size of a single cache page in bytes.
This property computes the total memory required to store page_size tokens' worth of cache data, accounting for:
- Number of heads (key and value)
- Head dimensions
- Data type size
- Any padding or alignment requirements
The calculation typically follows: bytes = page_size * num_heads * head_dim * dtype_bytes * 2 (where 2 accounts for both keys and values)
- Returns
Size of one cache page in bytes.
- Return type
int
Note
Implementations may include padding for memory alignment or hardware-specific optimizations.
- property type_id: str
Unique identifier for this cache specification type.
The type ID is used to determine cache compatibility when mixing different cache types in a model. Caches with the same type_id can share memory pools and be managed together.
Different type IDs should be returned for:
- Different attention patterns (full vs sliding window)
- Different cache sizes per token (varying head counts)
- Different memory layouts (paged vs continuous)
The ID typically encodes:
- Cache strategy name
- Key configuration parameters
- Memory layout information
- Returns
A unique string identifier for this cache type, typically formatted as “{strategy}_{params}_{size}”.
- Return type
str
Example
“full_attention_128_16384” for full attention with page_size=128 and page_size_bytes=16384
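Because specs with equal type_id values can share a memory pool, a hybrid allocator might begin by bucketing layers by type_id. The following is a hedged sketch of that grouping step only; the function name and the example IDs are hypothetical, and EasyDeL's actual grouping logic is not shown in this reference.

```python
from collections import defaultdict


def group_layers_by_type_id(layer_type_ids: dict[str, str]) -> dict[str, list[str]]:
    """Group layer names whose specs report the same type_id."""
    groups = defaultdict(list)
    for layer_name, type_id in layer_type_ids.items():
        # Layers with identical type_ids can be served from one pool.
        groups[type_id].append(layer_name)
    return dict(groups)


# Hypothetical type_id strings in the "{strategy}_{params}_{size}" format.
layers = {
    "layer_0": "full_attention_128_262144",
    "layer_1": "sliding_window_128_262144",
    "layer_2": "full_attention_128_262144",
}
print(group_layers_by_type_id(layers))
# {'full_attention_128_262144': ['layer_0', 'layer_2'], 'sliding_window_128_262144': ['layer_1']}
```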
- class easydel.layers.caching._specs.MambaSpec(page_size: int, shapes: tuple[tuple[int, ...], ...], dtype: dtype, page_size_padded: int | None = None)
Bases: KVCacheSpec
Specification for Mamba state-space model caching.
Mamba models use state-space representations instead of attention, requiring different cache structures for hidden states and convolutional states.
The cache stores multiple state tensors with different shapes, all packed into a single page-based allocation.
- shapes
Shapes of state tensors to cache. Each inner tuple defines one state tensor’s shape.
- Type
tuple[tuple[int, …], …]
- dtype
Data type for state tensors.
- Type
dtype
- page_size_padded
Optional padded page size for alignment. If set, pages are padded to this size.
- Type
int | None
- num_elements
Total number of elements across all shapes. Calculated automatically in __post_init__.
- Type
int
- max_memory_usage_bytes(*args, **kwargs) → int
Calculate maximum memory for Mamba state cache.
Mamba caches have fixed size per layer regardless of sequence length, as they maintain a constant-size state representation.
- Parameters
*args – Unused (for compatibility).
**kwargs – Unused (for compatibility).
- Returns
Maximum memory in bytes (equals page_size_bytes).
- Return type
int
- property page_size_bytes: int
Calculate page size for Mamba state cache in bytes.
Computes the memory needed to store all state tensors, optionally with padding for alignment.
- Returns
Size of one state cache page in bytes: page_size_padded if specified, otherwise the exact size num_elements * dtype_size.
- Return type
int
- Raises
AssertionError – If page_size_padded is less than required size.
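A standalone sketch of the page-size calculation, including the padding check. The state shapes below are hypothetical placeholders, not the shapes of any particular Mamba model, and dtype_bytes again stands in for the JAX dtype. Because max_memory_usage_bytes equals page_size_bytes, this single number is the whole per-layer budget regardless of sequence length.

```python
import math
from typing import Optional


def mamba_page_size_bytes(
    shapes: tuple,
    dtype_bytes: int,
    page_size_padded: Optional[int] = None,
) -> int:
    # num_elements: total element count across all cached state tensors.
    num_elements = sum(math.prod(shape) for shape in shapes)
    exact = num_elements * dtype_bytes
    if page_size_padded is None:
        return exact
    # Padding may only grow the page, never shrink it.
    assert page_size_padded >= exact, "page_size_padded is smaller than the state size"
    return page_size_padded


# Hypothetical per-layer states: a (3, 1024) conv state and a (16, 1024) SSM state,
# stored in bfloat16 (2 bytes per element).
shapes = ((3, 1024), (16, 1024))
print(mamba_page_size_bytes(shapes, 2))         # 38912
print(mamba_page_size_bytes(shapes, 2, 65536))  # 65536
```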
- shapes: tuple[tuple[int, ...], ...]
- property type_id: str
Unique identifier for this cache specification type; see KVCacheSpec.type_id for the general contract.
- class easydel.layers.caching._specs.SlidingWindowSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool, sliding_window: int)
Bases: AttentionSpec
Specification for sliding window attention caching.
Implements a fixed-size sliding window where tokens can only attend to a limited number of previous tokens. This provides a good balance between memory efficiency and model capability.
The cache maintains a rolling buffer of the most recent tokens, discarding older tokens beyond the window size.
- sliding_window
Size of the sliding attention window. Each token attends to at most this many previous tokens.
- Type
int
- Constraints:
MLA optimization is not compatible with sliding windows
- max_memory_usage_bytes(max_model_len: int, max_num_batched_tokens: int, **kwargs) → int
Calculate maximum memory for sliding window cache.
Memory is bounded by the window size plus the tokens of the current batch (capped at max_model_len), with one extra page because the window generally does not start on a page boundary.
- Parameters
max_model_len (int) – Maximum sequence length (upper bound).
max_num_batched_tokens (int) – Maximum tokens processed per batch.
**kwargs – Additional arguments (unused).
- Returns
Maximum memory in bytes, including one extra page for window boundary handling.
- Return type
int
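A sketch of this bound under stated assumptions: the token count is taken literally as window plus batch (some implementations count sliding_window - 1 tokens instead), and it is rounded up to pages with cdiv before the extra page is added. The function name and configuration values are hypothetical.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division using integers only.
    return -(a // -b)


def sliding_window_max_memory_bytes(
    max_model_len: int,
    max_num_batched_tokens: int,
    sliding_window: int,
    page_size: int,
    page_size_bytes: int,
) -> int:
    # Tokens that can be live at once: the window plus the current batch,
    # capped by the maximum model length.
    num_tokens = min(sliding_window + max_num_batched_tokens, max_model_len)
    # One extra page: the oldest token in the window rarely sits on a page
    # boundary, so the window can straddle one more page than cdiv suggests.
    return (cdiv(num_tokens, page_size) + 1) * page_size_bytes


# Hypothetical configuration: 4096-token window, 2048 batched tokens,
# 32768-token context, 128-token pages of 262144 bytes.
print(sliding_window_max_memory_bytes(32768, 2048, 4096, 128, 262144))  # 12845056 (49 pages)
```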
- sliding_window: int
- property type_id: str
Unique identifier for this cache specification type; see KVCacheSpec.type_id for the general contract.
- easydel.layers.caching._specs.cdiv(a: int, b: int) → int
Ceiling division: divide a by b and round up.
Computes the ceiling of a/b using integer arithmetic to avoid floating point operations. This is commonly used for calculating the number of pages needed for a given number of tokens.
- Parameters
a (int) – Numerator (e.g., number of tokens)
b (int) – Denominator (e.g., page size)
- Returns
The ceiling of a/b
- Return type
int
Example
>>> cdiv(10, 3)  # 10 tokens, 3 per page: need 4 pages
4
>>> cdiv(9, 3)  # 9 tokens, 3 per page: need 3 pages
3
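One common integer-only way to compute this ceiling is to floor-divide by the negated divisor and negate the result. This is a sketch of the technique and may differ from the module's actual implementation; for non-negative a it is equivalent to (a + b - 1) // b.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling of a / b for positive b, using integer arithmetic only."""
    # Python's // rounds toward negative infinity, so a // -b == floor(-a / b)
    # and negating it gives ceil(a / b) without any float conversion.
    return -(a // -b)


print(cdiv(10, 3), cdiv(9, 3), cdiv(0, 128))  # 4 3 0
```

Staying in integer arithmetic keeps the result exact for arbitrarily large token counts, where converting to float could lose precision.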