# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specification classes for different caching strategies in EasyDeL.
This module defines specification dataclasses that describe the memory layout,
size requirements, and behavior of different cache types. These specifications
are used to:
- Calculate memory requirements before allocation
- Configure cache initialization parameters
- Optimize memory layout for specific attention patterns
- Enable hybrid caching strategies
The specifications follow a hierarchy:
    - KVCacheSpec: Base specification for all KV cache types
        - AttentionSpec: Base for attention-based caches
            - FullAttentionSpec: Standard full attention caching
            - SlidingWindowSpec: Sliding window attention caching
            - ChunkedLocalAttentionSpec: Chunked local attention
        - MambaSpec: State-space model caching
Key Concepts:
- Page Size: Number of tokens per memory page
- Type ID: Unique identifier for cache compatibility
- Memory Budget: Maximum memory usage calculations
- Hybrid Allocation: Mixing different cache types
Example:
>>> spec = FullAttentionSpec(
... page_size=128,
... num_kv_heads=8,
... head_size=64,
... dtype=jnp.bfloat16,
... use_mla=False
... )
>>> memory_bytes = spec.max_memory_usage_bytes(max_model_len=2048)
"""
import copy
from dataclasses import dataclass
from math import prod
from typing import Self
import jax
def cdiv(a: int, b: int) -> int:
    """Ceiling division: divide ``a`` by ``b`` and round up.

    Computes the ceiling of ``a / b`` using integer arithmetic only, so no
    floating point rounding is involved. This is used for calculating the
    number of pages needed for a given number of tokens.

    Args:
        a (int): Numerator (e.g., number of tokens).
        b (int): Denominator (e.g., page size). Must be non-zero.

    Returns:
        int: The smallest integer greater than or equal to ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.

    Example:
        >>> cdiv(10, 3)  # 10 tokens, 3 per page -> 4 pages
        4
        >>> cdiv(9, 3)  # 9 tokens, 3 per page -> 3 pages
        3
    """
    # Floor-dividing by the negated divisor and negating the result rounds
    # towards +infinity for either sign of ``b``. The ``(a + b - 1) // b``
    # form is only a true ceiling when ``b`` is positive.
    return -(a // -b)
[docs]@dataclass
class KVCacheSpec:
"""Base specification for key-value cache formats.
This abstract base class defines the interface that all cache
specifications must implement. It provides methods for calculating
memory requirements and identifying cache types for compatibility.
The specification pattern allows:
- Pre-allocation memory budgeting
- Cache type compatibility checking
- Hybrid cache configuration
- Memory optimization strategies
Attributes:
page_size (int): Number of tokens stored per cache page.
Pages are the basic unit of cache allocation and help
reduce memory fragmentation.
Abstract Properties:
type_id: Unique identifier for this cache type
page_size_bytes: Size of one page in bytes
Abstract Methods:
max_memory_usage_bytes: Calculate maximum memory needed
merge: Combine multiple specs of the same type
"""
page_size: int
@property
def type_id(self) -> str:
"""Unique identifier for this cache specification type.
The type ID is used to determine cache compatibility when mixing
different cache types in a model. Caches with the same type_id
can share memory pools and be managed together.
Different type IDs should be returned for:
- Different attention patterns (full vs sliding window)
- Different cache sizes per token (varying head counts)
- Different memory layouts (paged vs continuous)
The ID typically encodes:
- Cache strategy name
- Key configuration parameters
- Memory layout information
Returns:
str: A unique string identifier for this cache type.
Format typically: "{strategy}_{params}_{size}"
Example:
"full_attention_128_16384" for full attention with
page_size=128 and page_size_bytes=16384
"""
raise NotImplementedError
@property
def page_size_bytes(self) -> int:
"""Calculate the memory size of a single cache page in bytes.
This property computes the total memory required to store
`page_size` tokens worth of cache data, accounting for:
- Number of heads (key and value)
- Head dimensions
- Data type size
- Any padding or alignment requirements
The calculation typically follows:
bytes = page_size * num_heads * head_dim * dtype_bytes * 2
(where 2 accounts for both keys and values)
Returns:
int: Size of one cache page in bytes.
Note:
Implementations may include padding for memory alignment
or hardware-specific optimizations.
"""
raise NotImplementedError
[docs] def max_memory_usage_bytes(self, *args, **kwargs) -> int:
"""Calculate maximum memory required for this cache configuration.
Computes the worst-case memory usage for the cache based on
the maximum sequence length and other parameters. This is used
for memory budgeting and allocation planning.
Args:
*args: Implementation-specific arguments.
**kwargs: Implementation-specific keyword arguments.
Common kwargs include:
- max_model_len: Maximum sequence length
- max_num_batched_tokens: Max tokens per batch
- max_num_reqs: Maximum concurrent requests
Returns:
int: Maximum memory usage in bytes.
Note:
Different cache types calculate this differently:
- Full attention: O(max_length)
- Sliding window: O(window_size)
- Chunked: O(chunk_size + batch_size)
"""
raise NotImplementedError
[docs] @classmethod
def merge(cls, specs: list[Self]) -> Self:
"""Merge multiple cache specifications into a single specification.
Combines specifications from multiple layers that share the same
cache type. This is used when multiple layers can share a cache
pool for memory efficiency.
The merge process:
1. Validates all specs have compatible type_ids
2. Combines configuration parameters
3. Returns a unified specification
Args:
specs (list[Self]): List of specifications to merge.
All must have the same type_id.
Returns:
Self: A merged specification representing the combined
requirements of all input specifications.
Raises:
AssertionError: If specs have incompatible type_ids.
Note:
The default implementation returns a copy of the first spec.
Subclasses may override to merge specific parameters.
"""
assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
"All layers in the same KV cache group must share the same type_id."
)
return copy.deepcopy(specs[0])
@dataclass
class AttentionSpec(KVCacheSpec):
    """Base specification for attention-based cache formats.

    Extends ``KVCacheSpec`` with the head configuration, element dtype and
    MLA flag that determine how many bytes one page of cache occupies.

    Attributes:
        num_kv_heads (int): Number of key-value attention heads.
        head_size (int): Dimension of each attention head.
        dtype (jax.numpy.dtype): Data type for cache tensors.
        use_mla (bool): Whether the cache uses the MLA layout. When set,
            a page is sized for one tensor per token instead of separate
            key and value tensors.
            NOTE(review): "MLA" presumably stands for Multi-head Latent
            Attention — confirm against the attention implementation.
    """

    num_kv_heads: int
    head_size: int
    dtype: jax.numpy.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        """Calculate page size for attention cache in bytes.

        Formula:
            bytes = coef * page_size * num_kv_heads * head_size * dtype_bytes

        where ``coef`` is 2 (keys and values) without MLA and 1 with MLA.

        Returns:
            int: Size of one attention cache page in bytes.
        """
        coef = 1 if self.use_mla else 2
        # ``itemsize`` works for any dtype. ``finfo`` only accepts floating
        # point types, so it would raise for e.g. an int8 cache, and its
        # ``bits // 8`` truncates to 0 for sub-byte float formats.
        dtype_bytes = jax.numpy.dtype(self.dtype).itemsize
        return coef * self.page_size * self.num_kv_heads * self.head_size * dtype_bytes
@dataclass
class FullAttentionSpec(AttentionSpec):
    """Specification for full attention caching.

    Memory is budgeted for every token up to the maximum model length.
    The optional ``sliding_window`` / ``attention_chunk_size`` fields let a
    spec record that a layer uses windowed or chunked attention while the
    cache is still sized as full attention.

    Attributes:
        sliding_window (int | None): Optional sliding window size, or
            None for standard full attention.
        attention_chunk_size (int | None): Optional chunk size for chunked
            attention, or None for standard full attention.

    Note:
        Only one of ``sliding_window`` or ``attention_chunk_size`` should be
        set. ``merge`` asserts that the merged spec does not carry both;
        the constructor itself does not check this.
    """

    sliding_window: int | None = None
    attention_chunk_size: int | None = None

    @property
    def type_id(self) -> str:
        """Identifier built from the page size and the page size in bytes."""
        return f"full_attention_{self.page_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, max_model_len: int, **kwargs) -> int:
        """Calculate maximum memory for full attention cache.

        Args:
            max_model_len (int): Maximum sequence length supported.
            **kwargs: Additional arguments (unused).

        Returns:
            int: Maximum memory in bytes, computed as
                ``ceil(max_model_len / page_size) * page_size_bytes``.
        """
        return cdiv(max_model_len, self.page_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """Reduce the window sizes seen across layers to a single value.

        The input set is not modified.

        Args:
            window_sizes (set[int]): Distinct window sizes from the layers
                being merged.

        Returns:
            int | None: The single window size if exactly one is present,
                None if the set is empty.

        Raises:
            ValueError: If the set contains more than one window size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) > 1:
            raise ValueError("All attention layers in the same KV cache group must have the same window size.")
        # Unpack instead of ``pop()`` so the caller's set is left untouched.
        (window_size,) = window_sizes
        return window_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge a list of FullAttentionSpec objects into a single spec.

        Starts from the parent class's merge result, then sets
        ``sliding_window`` and ``attention_chunk_size`` on it from the
        non-None values found across ``specs``.

        Args:
            specs (list[Self]): Specifications to merge.

        Returns:
            Self: The merged specification.

        Raises:
            AssertionError: If the specs' ``type_id`` values differ, or if
                the merged spec would have both a sliding window and an
                attention chunk size.
            ValueError: If the specs disagree on the sliding window size or
                on the attention chunk size.
        """
        merged_spec = super().merge(specs)
        sliding_windows = {spec.sliding_window for spec in specs if spec.sliding_window is not None}
        chunk_sizes = {spec.attention_chunk_size for spec in specs if spec.attention_chunk_size is not None}
        merged_spec.sliding_window = cls.merge_window_sizes(sliding_windows)
        merged_spec.attention_chunk_size = cls.merge_window_sizes(chunk_sizes)
        assert merged_spec.sliding_window is None or merged_spec.attention_chunk_size is None, (
            "Model with both sliding window layers and chunked local attention layers is not supported."
        )
        return merged_spec
@dataclass
class ChunkedLocalAttentionSpec(AttentionSpec):
    """Cache specification for chunked local attention layers.

    The memory budget is derived from the chunk size plus the number of
    tokens in a batch, capped at the maximum model length, rather than
    from the maximum model length alone.

    Attributes:
        attention_chunk_size (int): Size of attention chunks.
    """

    attention_chunk_size: int

    @property
    def type_id(self) -> str:
        """Identifier built from the chunk size, page size and page bytes."""
        chunk = self.attention_chunk_size
        return f"local_attention_{chunk}_{self.page_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(
        self,
        max_model_len: int,
        max_num_batched_tokens: int,
        **kwargs,
    ) -> int:
        """Worst-case memory for a chunked attention cache.

        Args:
            max_model_len (int): Maximum sequence length; upper bound on
                the number of cached tokens.
            max_num_batched_tokens (int): Maximum tokens processed per batch.
            **kwargs: Additional arguments (unused).

        Returns:
            int: Maximum memory in bytes, i.e. the number of pages needed
                for ``min(chunk_size + batch_tokens, max_model_len)``
                tokens times the page size in bytes.
        """
        tokens_needed = self.attention_chunk_size + max_num_batched_tokens
        # Never budget for more tokens than the model can hold.
        if tokens_needed > max_model_len:
            tokens_needed = max_model_len
        pages = cdiv(tokens_needed, self.page_size)
        return pages * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):
    """Cache specification for sliding window attention layers.

    The memory budget is derived from the window size plus the number of
    tokens in a batch, capped at the maximum model length, with one
    additional page on top.

    Attributes:
        sliding_window (int): Size of the sliding attention window.

    Constraints:
        - ``use_mla`` must be False; construction asserts this.
    """

    sliding_window: int

    def __post_init__(self):
        """Reject configurations that enable MLA.

        Raises:
            AssertionError: If ``use_mla`` is True.
        """
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        """Identifier built from the window size, page size and page bytes."""
        window = self.sliding_window
        return f"sliding_window_{window}_{self.page_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(
        self,
        max_model_len: int,
        max_num_batched_tokens: int,
        **kwargs,
    ) -> int:
        """Worst-case memory for a sliding window cache.

        Args:
            max_model_len (int): Maximum sequence length; upper bound on
                the number of cached tokens.
            max_num_batched_tokens (int): Maximum tokens processed per batch.
            **kwargs: Additional arguments (unused).

        Returns:
            int: Maximum memory in bytes: the pages needed for
                ``min(sliding_window - 1 + batch_tokens, max_model_len)``
                tokens, plus one extra page, times the page size in bytes.
        """
        window_tokens = self.sliding_window - 1 + max_num_batched_tokens
        capped_tokens = min(window_tokens, max_model_len)
        # One page beyond the rounded-up count; presumably this covers a
        # window that does not start on a page boundary — confirm with the
        # cache manager.
        total_pages = cdiv(capped_tokens, self.page_size) + 1
        return total_pages * self.page_size_bytes
@dataclass
class MambaSpec(KVCacheSpec):
    """Specification for Mamba state-space model caching.

    Instead of per-token keys and values, this spec describes a set of
    state tensors given by ``shapes``. One page holds all of them, so the
    memory requirement does not depend on sequence length.

    Attributes:
        shapes (tuple[tuple[int, ...], ...]): Shapes of state tensors
            to cache. Each inner tuple defines one state tensor's shape.
        dtype (jax.numpy.dtype): Data type for state tensors.
        page_size_padded (int | None): Optional padded page size in bytes.
            If set, it is reported as the page size and must be at least
            the exact size of the state tensors.
        num_elements (int): Total number of elements across all shapes.
            Set in ``__post_init__``; it is not a dataclass field.
    """

    shapes: tuple[tuple[int, ...], ...]
    dtype: jax.numpy.dtype
    page_size_padded: int | None = None

    def __post_init__(self):
        """Set ``num_elements`` to the sum of the products of each shape."""
        self.num_elements = sum(prod(shape) for shape in self.shapes)

    @property
    def type_id(self) -> str:
        """Identifier built from the state shapes and the dtype."""
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        """Calculate page size for Mamba state cache in bytes.

        Returns:
            int: ``page_size_padded`` if it is set, otherwise the exact
                size ``num_elements * dtype_bytes``.

        Raises:
            AssertionError: If ``page_size_padded`` is less than the exact
                size.
        """
        # ``itemsize`` works for any dtype. ``finfo`` only accepts floating
        # point types, so it would raise for integer state dtypes, and its
        # ``bits // 8`` truncates to 0 for sub-byte float formats.
        required_bytes = self.num_elements * jax.numpy.dtype(self.dtype).itemsize
        if self.page_size_padded is None:
            return required_bytes
        assert self.page_size_padded >= required_bytes, (
            f"page_size_padded ({self.page_size_padded}) is smaller than the "
            f"required state size ({required_bytes} bytes)."
        )
        return self.page_size_padded

    def max_memory_usage_bytes(self, *args, **kwargs) -> int:
        """Calculate maximum memory for Mamba state cache.

        Args:
            *args: Unused (for compatibility).
            **kwargs: Unused (for compatibility).

        Returns:
            int: Maximum memory in bytes (equals ``page_size_bytes``).
        """
        return self.page_size_bytes