easydel.layers.caching._specs

Specification classes for different caching strategies in EasyDeL.

This module defines specification dataclasses that describe the memory layout, size requirements, and behavior of different cache types. These specifications are used to:
  • Calculate memory requirements before allocation

  • Configure cache initialization parameters

  • Optimize memory layout for specific attention patterns

  • Enable hybrid caching strategies

The specifications follow a hierarchy:
  • KVCacheSpec: Base specification for all KV cache types

      • AttentionSpec: Base for attention-based caches

          • FullAttentionSpec: Standard full attention caching

          • SlidingWindowSpec: Sliding window attention caching

          • ChunkedLocalAttentionSpec: Chunked local attention

      • MambaSpec: State-space model caching

Key Concepts:
  • Page Size: Number of tokens per memory page

  • Type ID: Unique identifier for cache compatibility

  • Memory Budget: Maximum memory usage calculations

  • Hybrid Allocation: Mixing different cache types

Example

>>> import jax.numpy as jnp
>>> from easydel.layers.caching._specs import FullAttentionSpec
>>> spec = FullAttentionSpec(
...     page_size=128,
...     num_kv_heads=8,
...     head_size=64,
...     dtype=jnp.bfloat16,
...     use_mla=False
... )
>>> memory_bytes = spec.max_memory_usage_bytes(max_model_len=2048)
class easydel.layers.caching._specs.AttentionSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool)

Bases: KVCacheSpec

Base specification for attention-based cache formats.

Extends KVCacheSpec with attention-specific parameters needed for transformer-based models. This includes head configuration, data types, and optimization flags.

num_kv_heads

Number of key-value attention heads. May differ from query heads in multi-query/grouped-query attention.

Type

int

head_size

Dimension of each attention head.

Type

int

dtype

Data type for cache tensors. Common choices: bfloat16, float16, float32.

Type

jax.numpy.dtype

use_mla

Whether the cache uses Multi-head Latent Attention (MLA). With MLA, keys and values are stored as one combined representation instead of separately, which reduces cache memory.

Type

bool

dtype: dtype
head_size: int
num_kv_heads: int
property page_size_bytes: int

Calculate page size for attention cache in bytes.

Computes memory needed for one page of key-value pairs:
  • Without MLA: stores both keys and values (coef=2)

  • With MLA: stores combined representation (coef=1)

Formula:

bytes = coef * page_size * num_kv_heads * head_size * dtype_bytes

Returns

Size of one attention cache page in bytes.

Return type

int
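
As a worked instance of the formula above, the following standalone sketch evaluates it for the configuration used in the module example. The helper name attention_page_size_bytes and the dtype_bytes parameter are hypothetical; the real property reads these values from the spec's fields.

```python
def attention_page_size_bytes(
    page_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype_bytes: int,
    use_mla: bool,
) -> int:
    """Evaluate coef * page_size * num_kv_heads * head_size * dtype_bytes."""
    # coef=2 stores keys and values separately; coef=1 stores one combined tensor.
    coef = 1 if use_mla else 2
    return coef * page_size * num_kv_heads * head_size * dtype_bytes


# bfloat16 and float16 both occupy 2 bytes per element.
kv_bytes = attention_page_size_bytes(128, 8, 64, dtype_bytes=2, use_mla=False)
mla_bytes = attention_page_size_bytes(128, 8, 64, dtype_bytes=2, use_mla=True)
print(kv_bytes)   # 262144
print(mla_bytes)  # 131072
```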

use_mla: bool
class easydel.layers.caching._specs.ChunkedLocalAttentionSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool, attention_chunk_size: int)

Bases: AttentionSpec

Specification for chunked local attention caching.

Optimizes memory usage for models that use local attention patterns where tokens only attend within fixed-size chunks. This significantly reduces memory requirements compared to full attention.

Memory allocation is based on chunk size rather than full sequence length, making it suitable for very long sequences.

attention_chunk_size

Size of attention chunks. Tokens can only attend within their chunk boundaries.

Type

int

attention_chunk_size: int
max_memory_usage_bytes(max_model_len: int, max_num_batched_tokens: int, **kwargs) → int

Calculate maximum memory for chunked attention cache.

Memory is bounded by chunk size plus current batch size, not the full sequence length.

Parameters
  • max_model_len (int) – Maximum sequence length (upper bound).

  • max_num_batched_tokens (int) – Maximum tokens processed per batch.

  • **kwargs – Additional arguments (unused).

Returns

Maximum memory in bytes.

Based on min(chunk_size + batch_tokens, max_model_len).

Return type

int
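
The bound above can be sketched as follows. The token count follows the documented min(chunk_size + batch_tokens, max_model_len); converting tokens to whole pages with ceiling division is an assumption borrowed from the full attention formula, and the function name is hypothetical.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division using integer arithmetic only."""
    return -(a // -b)


def chunked_max_memory_bytes(
    max_model_len: int,
    max_num_batched_tokens: int,
    attention_chunk_size: int,
    page_size: int,
    page_size_bytes: int,
) -> int:
    # Tokens that may need to be resident at once, capped by the model length.
    num_tokens = min(attention_chunk_size + max_num_batched_tokens, max_model_len)
    # Assumption: memory is allocated in whole pages.
    return cdiv(num_tokens, page_size) * page_size_bytes


# 512-token chunks, 256 batched tokens, 128-token pages of 262144 bytes each.
long_context = chunked_max_memory_bytes(8192, 256, 512, 128, 262144)
short_context = chunked_max_memory_bytes(512, 256, 512, 128, 262144)
print(long_context)   # 1572864 (6 pages)
print(short_context)  # 1048576 (4 pages, capped by max_model_len)
```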

property type_id: str

Unique identifier for this cache specification type.

The type ID is used to determine cache compatibility when mixing different cache types in a model. Caches with the same type_id can share memory pools and be managed together.

Different type IDs should be returned for:
  • Different attention patterns (full vs sliding window)

  • Different cache sizes per token (varying head counts)

  • Different memory layouts (paged vs continuous)

The ID typically encodes:
  • Cache strategy name

  • Key configuration parameters

  • Memory layout information

Returns

A unique string identifier for this cache type.

Format typically: “{strategy}_{params}_{size}”

Return type

str

Example

“full_attention_128_16384” for full attention with page_size=128 and page_size_bytes=16384

class easydel.layers.caching._specs.FullAttentionSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool, sliding_window: int | None = None, attention_chunk_size: int | None = None)

Bases: AttentionSpec

Specification for full attention caching.

Represents standard transformer attention where each token can attend to all previous tokens. This is the most common and memory-intensive cache type.

When hybrid allocation is disabled, this spec can also represent sliding window or chunked attention layers by storing the window/chunk parameters while allocating full cache space. This simplifies memory management at the cost of over-allocation.

sliding_window

Optional sliding window size. When set, attention computation uses sliding window but cache allocation remains full-sized. None for standard full attention.

Type

int | None

attention_chunk_size

Optional chunk size for chunked attention. Similar to sliding_window but for chunked patterns. None for standard full attention.

Type

int | None

Note

Only one of sliding_window or attention_chunk_size should be set. Both being non-None is an error.

attention_chunk_size: int | None = None
max_memory_usage_bytes(max_model_len: int, **kwargs) → int

Calculate maximum memory for full attention cache.

Memory scales linearly with maximum sequence length since all tokens need to be cached.

Parameters
  • max_model_len (int) – Maximum sequence length supported.

  • **kwargs – Additional arguments (unused).

Returns

Maximum memory in bytes.

Formula: ceil(max_model_len / page_size) * page_size_bytes

Return type

int
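
A short worked example of the formula above, written as a standalone function rather than the method itself. The function name is hypothetical, and the page size in bytes (262144) corresponds to page_size=128, num_kv_heads=8, head_size=64 with a 2-byte dtype and no MLA.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division using integer arithmetic only."""
    return -(a // -b)


def full_attention_max_bytes(max_model_len: int, page_size: int, page_size_bytes: int) -> int:
    # ceil(max_model_len / page_size) pages, each costing page_size_bytes.
    return cdiv(max_model_len, page_size) * page_size_bytes


exact_fit = full_attention_max_bytes(2048, page_size=128, page_size_bytes=262144)
one_over = full_attention_max_bytes(2049, page_size=128, page_size_bytes=262144)
print(exact_fit)  # 4194304 (16 pages)
print(one_over)   # 4456448 (17 pages: one extra token forces a new page)
```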

classmethod merge(specs: list[Self]) → Self

Merge a list of FullAttentionSpec objects into a single FullAttentionSpec object.

classmethod merge_window_sizes(window_sizes: set[int]) → int | None

Merge sliding window sizes from multiple layers.

Ensures all layers in a cache group use the same window size for consistent memory allocation.

Parameters

window_sizes (set[int]) – Set of window sizes from different layers.

Returns

The single window size if consistent, None if no windows.

Return type

int | None

Raises

ValueError – If layers have different window sizes.
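
The documented contract (empty set gives None, one size is returned as-is, several sizes raise ValueError) can be sketched as a free function. This is an illustration of the described behavior, not the library's source; the error message is made up.

```python
from typing import Optional, Set


def merge_window_sizes(window_sizes: Set[int]) -> Optional[int]:
    """Collapse per-layer window sizes into the single size shared by a group."""
    if len(window_sizes) == 0:
        # No layer in the group uses a window.
        return None
    if len(window_sizes) > 1:
        raise ValueError(f"layers disagree on window size: {sorted(window_sizes)}")
    # Exactly one distinct size: every windowed layer agrees.
    return next(iter(window_sizes))


print(merge_window_sizes(set()))   # None
print(merge_window_sizes({4096}))  # 4096
```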

sliding_window: int | None = None
property type_id: str

Unique identifier for this cache specification type.

The type ID is used to determine cache compatibility when mixing different cache types in a model. Caches with the same type_id can share memory pools and be managed together.

Different type IDs should be returned for:
  • Different attention patterns (full vs sliding window)

  • Different cache sizes per token (varying head counts)

  • Different memory layouts (paged vs continuous)

The ID typically encodes:
  • Cache strategy name

  • Key configuration parameters

  • Memory layout information

Returns

A unique string identifier for this cache type.

Format typically: “{strategy}_{params}_{size}”

Return type

str

Example

“full_attention_128_16384” for full attention with page_size=128 and page_size_bytes=16384

class easydel.layers.caching._specs.KVCacheSpec(page_size: int)

Bases: object

Base specification for key-value cache formats.

This abstract base class defines the interface that all cache specifications must implement. It provides methods for calculating memory requirements and identifying cache types for compatibility.

The specification pattern allows:
  • Pre-allocation memory budgeting

  • Cache type compatibility checking

  • Hybrid cache configuration

  • Memory optimization strategies

page_size

Number of tokens stored per cache page. Pages are the basic unit of cache allocation and help reduce memory fragmentation.

Type

int

Abstract Properties:
  • type_id: Unique identifier for this cache type

  • page_size_bytes: Size of one page in bytes

Abstract Methods:
  • max_memory_usage_bytes: Calculate maximum memory needed

  • merge: Combine multiple specs of the same type

max_memory_usage_bytes(*args, **kwargs) → int

Calculate maximum memory required for this cache configuration.

Computes the worst-case memory usage for the cache based on the maximum sequence length and other parameters. This is used for memory budgeting and allocation planning.

Parameters
  • *args – Implementation-specific arguments.

  • **kwargs – Implementation-specific keyword arguments. Common kwargs include:

      • max_model_len: Maximum sequence length

      • max_num_batched_tokens: Maximum tokens per batch

      • max_num_reqs: Maximum concurrent requests

Returns

Maximum memory usage in bytes.

Return type

int

Note

Different cache types calculate this differently:
  • Full attention: O(max_length)

  • Sliding window: O(window_size)

  • Chunked: O(chunk_size + batch_size)

classmethod merge(specs: list[Self]) → Self

Merge multiple cache specifications into a single specification.

Combines specifications from multiple layers that share the same cache type. This is used when multiple layers can share a cache pool for memory efficiency.

The merge process:
  1. Validates all specs have compatible type_ids

  2. Combines configuration parameters

  3. Returns a unified specification

Parameters

specs (list[Self]) – List of specifications to merge. All must have the same type_id.

Returns

A merged specification representing the combined requirements of all input specifications.

Return type

Self

Raises

AssertionError – If specs have incompatible type_ids.

Note

The default implementation returns a copy of the first spec. Subclasses may override to merge specific parameters.
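
A minimal sketch of the default behavior described in this note (validate matching type_ids, then return a copy of the first spec). ToySpec and its fields are hypothetical stand-ins used only for illustration; they are not EasyDeL classes.

```python
import copy
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical stand-in for a cache spec; not an EasyDeL class."""

    page_size: int
    strategy: str = "toy"

    @property
    def type_id(self) -> str:
        return f"{self.strategy}_{self.page_size}"

    @classmethod
    def merge(cls, specs):
        # Step 1: every spec must share the first spec's type_id.
        assert all(s.type_id == specs[0].type_id for s in specs[1:]), (
            "all specs in a group must have the same type_id"
        )
        # Steps 2-3: nothing extra to combine here, so return a copy of the first spec.
        return copy.deepcopy(specs[0])


layer_specs = [ToySpec(page_size=128), ToySpec(page_size=128)]
merged = ToySpec.merge(layer_specs)
print(merged.type_id)  # toy_128
```

A subclass that carries extra per-layer parameters would replace the last step with logic that reconciles those parameters across the list.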

page_size: int
property page_size_bytes: int

Calculate the memory size of a single cache page in bytes.

This property computes the total memory required to store page_size tokens worth of cache data, accounting for:
  • Number of heads (key and value)

  • Head dimensions

  • Data type size

  • Any padding or alignment requirements

The calculation typically follows:

bytes = page_size * num_heads * head_dim * dtype_bytes * 2

where the factor 2 accounts for both keys and values.

Returns

Size of one cache page in bytes.

Return type

int

Note

Implementations may include padding for memory alignment or hardware-specific optimizations.

property type_id: str

Unique identifier for this cache specification type.

The type ID is used to determine cache compatibility when mixing different cache types in a model. Caches with the same type_id can share memory pools and be managed together.

Different type IDs should be returned for:
  • Different attention patterns (full vs sliding window)

  • Different cache sizes per token (varying head counts)

  • Different memory layouts (paged vs continuous)

The ID typically encodes:
  • Cache strategy name

  • Key configuration parameters

  • Memory layout information

Returns

A unique string identifier for this cache type.

Format typically: “{strategy}_{params}_{size}”

Return type

str

Example

“full_attention_128_16384” for full attention with page_size=128 and page_size_bytes=16384
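
To illustrate how such identifiers might be built and used for compatibility grouping, here is a small sketch. Both helper functions, the layer names, and the "sliding_window" strategy string are hypothetical; only the "full_attention_128_16384" shape comes from the example above.

```python
from collections import defaultdict


def make_type_id(strategy: str, page_size: int, page_size_bytes: int) -> str:
    """Build an ID in the style of the documented example."""
    return f"{strategy}_{page_size}_{page_size_bytes}"


def group_by_type_id(layer_type_ids: dict) -> dict:
    """Map each type_id to the layers that could share a memory pool."""
    groups = defaultdict(list)
    for layer_name, type_id in layer_type_ids.items():
        groups[type_id].append(layer_name)
    return dict(groups)


full_id = make_type_id("full_attention", 128, 16384)
window_id = make_type_id("sliding_window", 128, 16384)
groups = group_by_type_id(
    {"layer_0": full_id, "layer_1": window_id, "layer_2": full_id}
)
print(full_id)          # full_attention_128_16384
print(groups[full_id])  # ['layer_0', 'layer_2']
```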

class easydel.layers.caching._specs.MambaSpec(page_size: int, shapes: tuple[tuple[int, ...], ...], dtype: dtype, page_size_padded: int | None = None)

Bases: KVCacheSpec

Specification for Mamba state-space model caching.

Mamba models use state-space representations instead of attention, requiring different cache structures for hidden states and convolutional states.

The cache stores multiple state tensors with different shapes, all packed into a single page-based allocation.

shapes

Shapes of state tensors to cache. Each inner tuple defines one state tensor’s shape.

Type

tuple[tuple[int, …], …]

dtype

Data type for state tensors.

Type

jax.numpy.dtype

page_size_padded

Optional padded page size for alignment. If set, pages are padded to this size.

Type

int | None

num_elements

Total number of elements across all shapes. Calculated automatically in __post_init__.

Type

int

dtype: dtype
max_memory_usage_bytes(*args, **kwargs) → int

Calculate maximum memory for Mamba state cache.

Mamba caches have fixed size per layer regardless of sequence length, as they maintain a constant-size state representation.

Parameters
  • *args – Unused (for compatibility).

  • **kwargs – Unused (for compatibility).

Returns

Maximum memory in bytes (equals page_size_bytes).

Return type

int

property page_size_bytes: int

Calculate page size for Mamba state cache in bytes.

Computes the memory needed to store all state tensors, optionally with padding for alignment.

Returns

Size of one state cache page in bytes.

Uses page_size_padded if specified, otherwise exact size based on num_elements * dtype_size.

Return type

int

Raises

AssertionError – If page_size_padded is less than required size.
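
The computation described above can be sketched as a standalone function. The function name and the dtype_bytes parameter are hypothetical; the logic follows the documented rules (num_elements summed over all shapes, padding honored if large enough, an assertion otherwise).

```python
import math
from typing import Optional, Tuple


def mamba_page_size_bytes(
    shapes: Tuple[Tuple[int, ...], ...],
    dtype_bytes: int,
    page_size_padded: Optional[int] = None,
) -> int:
    # num_elements: total element count across every state tensor.
    num_elements = sum(math.prod(shape) for shape in shapes)
    real_bytes = num_elements * dtype_bytes
    if page_size_padded is not None:
        # Padding may only grow the page, never truncate the state.
        assert page_size_padded >= real_bytes
        return page_size_padded
    return real_bytes


# Two illustrative state tensors: 4*16 + 8*32 = 320 elements, 2 bytes each.
state_shapes = ((4, 16), (8, 32))
exact = mamba_page_size_bytes(state_shapes, dtype_bytes=2)
padded = mamba_page_size_bytes(state_shapes, dtype_bytes=2, page_size_padded=1024)
print(exact)   # 640
print(padded)  # 1024
```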

page_size_padded: int | None = None
shapes: tuple[tuple[int, ...], ...]
property type_id: str

Unique identifier for this cache specification type.

The type ID is used to determine cache compatibility when mixing different cache types in a model. Caches with the same type_id can share memory pools and be managed together.

Different type IDs should be returned for:
  • Different attention patterns (full vs sliding window)

  • Different cache sizes per token (varying head counts)

  • Different memory layouts (paged vs continuous)

The ID typically encodes:
  • Cache strategy name

  • Key configuration parameters

  • Memory layout information

Returns

A unique string identifier for this cache type.

Format typically: “{strategy}_{params}_{size}”

Return type

str

Example

“full_attention_128_16384” for full attention with page_size=128 and page_size_bytes=16384

class easydel.layers.caching._specs.SlidingWindowSpec(page_size: int, num_kv_heads: int, head_size: int, dtype: dtype, use_mla: bool, sliding_window: int)

Bases: AttentionSpec

Specification for sliding window attention caching.

Implements a fixed-size sliding window where tokens can only attend to a limited number of previous tokens. This provides a good balance between memory efficiency and model capability.

The cache maintains a rolling buffer of the most recent tokens, discarding older tokens beyond the window size.

sliding_window

Size of the sliding attention window. Each token attends to at most this many previous tokens.

Type

int

Constraints:
  • MLA optimization is not compatible with sliding windows

max_memory_usage_bytes(max_model_len: int, max_num_batched_tokens: int, **kwargs) → int

Calculate maximum memory for sliding window cache.

Memory is bounded by window size plus current batch, with an extra page for boundary conditions.

Parameters
  • max_model_len (int) – Maximum sequence length (upper bound).

  • max_num_batched_tokens (int) – Maximum tokens processed per batch.

  • **kwargs – Additional arguments (unused).

Returns

Maximum memory in bytes.

Includes extra page for window boundary handling.

Return type

int
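
One way to read the description above is sketched below: cap window plus batch tokens at the model length, round up to whole pages, then add one page for the window boundary. The exact token count used by the library (for example, whether the window contributes sliding_window or sliding_window - 1 tokens) is not stated here, so treat the arithmetic and the function name as assumptions.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division using integer arithmetic only."""
    return -(a // -b)


def sliding_window_max_bytes(
    max_model_len: int,
    max_num_batched_tokens: int,
    sliding_window: int,
    page_size: int,
    page_size_bytes: int,
) -> int:
    # Assumed token bound: window plus in-flight batch, capped by model length.
    num_tokens = min(sliding_window + max_num_batched_tokens, max_model_len)
    # One extra page covers a window that straddles a page boundary.
    return (cdiv(num_tokens, page_size) + 1) * page_size_bytes


window_bytes = sliding_window_max_bytes(8192, 256, 1024, 128, 262144)
print(window_bytes)  # 2883584 (10 pages for 1280 tokens, plus 1 boundary page)
```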

sliding_window: int
property type_id: str

Unique identifier for this cache specification type.

The type ID is used to determine cache compatibility when mixing different cache types in a model. Caches with the same type_id can share memory pools and be managed together.

Different type IDs should be returned for:
  • Different attention patterns (full vs sliding window)

  • Different cache sizes per token (varying head counts)

  • Different memory layouts (paged vs continuous)

The ID typically encodes:
  • Cache strategy name

  • Key configuration parameters

  • Memory layout information

Returns

A unique string identifier for this cache type.

Format typically: “{strategy}_{params}_{size}”

Return type

str

Example

“full_attention_128_16384” for full attention with page_size=128 and page_size_bytes=16384

easydel.layers.caching._specs.cdiv(a: int, b: int) → int

Ceiling division: divide a by b and round up.

Computes the ceiling of a/b using integer arithmetic to avoid floating point operations. This is commonly used for calculating the number of pages needed for a given number of tokens.

Parameters
  • a (int) – Numerator (e.g., number of tokens)

  • b (int) – Denominator (e.g., page size)

Returns

The ceiling of a/b

Return type

int

Example

>>> cdiv(10, 3)  # 10 tokens, 3 per page: need 4 pages
4
>>> cdiv(9, 3)   # 9 tokens, 3 per page: need 3 pages
3
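
For reference, a common integer-only way to implement ceiling division is shown below. This is a sketch of the technique, not necessarily the exact body of the library function.

```python
def cdiv(a: int, b: int) -> int:
    """Return ceil(a / b) without going through floating point."""
    # Floor division rounds toward negative infinity, so negating the
    # numerator's quotient and negating again rounds up instead.
    return -(a // -b)


print(cdiv(10, 3))     # 4
print(cdiv(9, 3))      # 3
print(cdiv(2049, 128)) # 17
```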