Source code for easydel.inference.esurge.core.interface

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import defaultdict
from dataclasses import dataclass
from math import prod
from typing import TYPE_CHECKING, Self

from jax import numpy as jnp

from ..utils import cdiv, get_dtype_size

if TYPE_CHECKING:
    from easydel.infra.base_config import EasyDeLBaseConfig


[docs]@dataclass class CacheSpec: """ A base class for specifying the KV cache format of one layer. """ page_size: int @property def type_id(self) -> str: """ The type identifier of this KV cache. Return different strings for layers with different KV cache type (e.g., different number of tokens like full attention vs sliding window attention, different KV cache size per token like layers with different number of heads) Returns: The type identifier of this KV cache. """ raise NotImplementedError @property def page_size_bytes(self) -> int: """ The size of a page with `page_size` tokens in bytes. Returns: The page size """ raise NotImplementedError
[docs] def max_memory_usage_bytes(self, *args, **kwargs) -> int: """ The maximum possible memory usage of this KV cache in bytes. Returns: The KV cache size in bytes """ raise NotImplementedError
[docs] @classmethod def merge(cls, specs: list[Self]) -> Self: """ Merge a list of CacheSpec objects into a single CacheSpec object. """ assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), ( "All layers in the same KV cache group must share the same type_id." ) return copy.deepcopy(specs[0])
@dataclass
class AttentionSpec(CacheSpec):
    """Cache spec shared by the attention-based layer types.

    Attributes:
        num_kv_heads: Number of key-value heads.
        head_size: Size of each head.
        dtype: Data type of the cached values.
        use_mla: Selects the per-page multiplier used by `page_size_bytes`
            (1 when set, 2 otherwise).
    """

    num_kv_heads: int
    head_size: int
    dtype: jnp.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        """Size in bytes of one page holding `page_size` tokens."""
        # NOTE(review): the factor of 2 presumably covers separate key and
        # value buffers, with MLA storing a single one — confirm.
        if self.use_mla:
            multiplier = 1
        else:
            multiplier = 2
        element_count = multiplier * self.page_size * self.num_kv_heads * self.head_size
        return element_count * get_dtype_size(self.dtype)
@dataclass
class FullAttentionSpec(AttentionSpec):
    """Cache spec for layers whose pages are allocated for all tokens.

    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding window
    attention are regarded as full attention in KV cache manager (pages are
    allocated for all tokens), while computed as sliding window attention in
    model runner. In this case, we use FullAttentionSpec and record the
    sliding window size.

    Attributes:
        sliding_window: Recorded sliding window size; None (the default) when
            sliding window attention is not used.
        attention_chunk_size: Recorded chunk size; None (the default) when no
            chunk size is recorded.
    """

    sliding_window: int | None = None
    attention_chunk_size: int | None = None

    @property
    def type_id(self) -> str:
        """Identifier built from the page size in tokens and in bytes."""
        return f"full_attention_{self.page_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, max_model_len: int, **kwargs) -> int:
        """Return the bytes needed to hold `max_model_len` tokens.

        Args:
            max_model_len: Maximum sequence length to budget for.
            **kwargs: Ignored; accepted so all specs can be called uniformly.

        Returns:
            Number of pages covering `max_model_len` (rounded up) times the
            page size in bytes.
        """
        return cdiv(max_model_len, self.page_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """Collapse the distinct window sizes of a group into a single value.

        Args:
            window_sizes: Distinct, non-None window sizes found in the group.
                The set is not modified.

        Returns:
            None when the set is empty, otherwise its only element.

        Raises:
            ValueError: If the set holds more than one distinct size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) > 1:
            raise ValueError("All attention layers in the same KV cache group must have the same window size.")
        # Read the sole element without `set.pop()`, which would empty the
        # caller's set as a side effect.
        return next(iter(window_sizes))

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge a list of FullAttentionSpec objects into a single one.

        Args:
            specs: Specs of the layers being grouped.

        Returns:
            A merged copy whose `sliding_window` and `attention_chunk_size`
            are the single value shared by the inputs (or None).

        Raises:
            ValueError: If the inputs carry more than one distinct sliding
                window size or more than one distinct chunk size.
            AssertionError: If the merged spec would have both a sliding
                window and an attention chunk size.
        """
        merged_spec = super().merge(specs)
        sliding_windows = {spec.sliding_window for spec in specs if spec.sliding_window is not None}
        chunk_sizes = {spec.attention_chunk_size for spec in specs if spec.attention_chunk_size is not None}
        merged_spec.sliding_window = cls.merge_window_sizes(sliding_windows)
        merged_spec.attention_chunk_size = cls.merge_window_sizes(chunk_sizes)
        # Booleans add up as 0/1, so this counts how many of the two are set.
        assert (merged_spec.sliding_window is not None) + (merged_spec.attention_chunk_size is not None) <= 1, (
            "Model with both sliding window layers and chunked local attention layers is not supported."
        )
        return merged_spec
@dataclass
class ChunkedLocalAttentionSpec(AttentionSpec):
    """Cache spec for chunked local attention layers.

    Attributes:
        attention_chunk_size: Chunk size used in the type identifier and in
            the memory bound.
    """

    attention_chunk_size: int

    @property
    def type_id(self) -> str:
        """Identifier built from the chunk size and the page size in tokens and bytes."""
        return f"local_attention_{self.attention_chunk_size}_{self.page_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, max_model_len, max_num_batched_tokens, **kwargs) -> int:
        """Return the upper bound, in bytes, on this cache's memory usage.

        Args:
            max_model_len: Cap on the number of tokens considered.
            max_num_batched_tokens: Added to the chunk size before capping.
            **kwargs: Ignored.

        Returns:
            Pages needed for the capped token count (rounded up) times the
            page size in bytes.
        """
        token_budget = self.attention_chunk_size + max_num_batched_tokens
        if token_budget > max_model_len:
            token_budget = max_model_len
        page_count = cdiv(token_budget, self.page_size)
        return page_count * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):
    """Cache spec for sliding window attention layers.

    Attributes:
        sliding_window: Window size used in the type identifier and in the
            memory bound.
    """

    sliding_window: int

    def __post_init__(self):
        """Reject the unsupported combination of MLA and sliding window."""
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        """Identifier built from the window size and the page size in tokens and bytes."""
        return f"sliding_window_{self.sliding_window}_{self.page_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, max_model_len, max_num_batched_tokens, **kwargs) -> int:
        """Return the upper bound, in bytes, on this cache's memory usage.

        Args:
            max_model_len: Cap on the number of tokens considered.
            max_num_batched_tokens: Added to `sliding_window - 1` before capping.
            **kwargs: Ignored.

        Returns:
            One page more than the pages needed for the capped token count,
            times the page size in bytes.
        """
        token_budget = self.sliding_window - 1 + max_num_batched_tokens
        if token_budget > max_model_len:
            token_budget = max_model_len
        # NOTE(review): the extra page presumably covers a window that does
        # not start on a page boundary — confirm.
        page_count = cdiv(token_budget, self.page_size) + 1
        return page_count * self.page_size_bytes
@dataclass
class MambaSpec(CacheSpec):
    """Cache spec whose page size is derived from a set of tensor shapes.

    Attributes:
        shapes: Shapes whose element counts are summed into `num_elements`.
        dtype: Data type used to convert the element count to bytes.
        page_size_padded: Optional page size in bytes that overrides the
            computed one; it must be at least the computed size.
    """

    shapes: tuple[tuple[int, ...], ...]
    dtype: jnp.dtype
    page_size_padded: int | None = None

    def __post_init__(self):
        """Cache the total element count across all shapes."""
        total = 0
        for shape in self.shapes:
            total += prod(shape)
        self.num_elements = total

    @property
    def type_id(self) -> str:
        """Identifier built from the shapes and the dtype."""
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        """Page size in bytes: the padded size when set, else the computed one."""
        unpadded = self.num_elements * get_dtype_size(self.dtype)
        if self.page_size_padded is None:
            return unpadded
        assert self.page_size_padded >= unpadded
        return self.page_size_padded

    def max_memory_usage_bytes(self, **kwargs) -> int:
        """Return the page size in bytes; keyword arguments are ignored."""
        return self.page_size_bytes
@dataclass
class CacheGroupSpec:
    """A group of model layers sharing the same KV cache page table.

    The KV cache manager regards the layers of one group as a single layer.

    Attributes:
        kv_cache_spec: Cache spec that applies to the layers of the group.
        layer_names: Names of the layers in the group, or None (the default)
            when no names are recorded.
    """

    kv_cache_spec: CacheSpec
    layer_names: list[str] | None = None
@dataclass
class CacheGroupsConfig:
    """The KV cache configuration of a model.

    Attributes:
        num_pages: Number of cache pages.
        kv_cache_groups: Cache groups that make up the configuration.
    """

    num_pages: int
    kv_cache_groups: list[CacheGroupSpec]
def _require_uniform(values: list, setting: str, layer_names: list[str]):
    """Return the value shared by every entry of `values`, raising ValueError if they differ."""
    first = values[0]
    if any(value != first for value in values[1:]):
        raise ValueError(
            f"All layers in the same KV cache group must have the same {setting}, "
            f"got {values} for layers {layer_names}."
        )
    return first


def create_kv_cache_specs_from_config(
    config: "EasyDeLBaseConfig",
    page_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: jnp.dtype,
    use_mla: bool = False,
) -> list[CacheGroupSpec]:
    """Convert model config's get_mask_details() to CacheGroupSpec list.

    This function reads the attention mask details from the model
    configuration and creates appropriate cache specifications for each
    attention type. Layers with the same attention type are grouped together
    and named ``layer.<index>``.

    Args:
        config: Model configuration with get_mask_details() method.
        page_size: Number of tokens per cache page.
        num_kv_heads: Number of key-value attention heads.
        head_size: Dimension of each attention head.
        dtype: Data type for cache tensors.
        use_mla: Whether to use Multi-head Latent Attention. Always forced to
            False for sliding window groups, which do not support it.

    Returns:
        List of CacheGroupSpec, one per attention type found in the config.
        Falls back to a single FullAttentionSpec (with no layer names) if no
        mask details are available.

    Raises:
        ValueError: If a sliding or chunked group lacks its window/chunk
            size, if the layers of such a group disagree on that size, or if
            an unknown attention mask type is encountered.
    """
    from easydel.infra.utils import AttnMaskType

    def full_attention_group(layer_names: list[str] | None) -> CacheGroupSpec:
        # Single place that builds the full-attention group, used for both
        # the fallback and the FULL mask type.
        return CacheGroupSpec(
            kv_cache_spec=FullAttentionSpec(
                page_size=page_size,
                num_kv_heads=num_kv_heads,
                head_size=head_size,
                dtype=dtype,
                use_mla=use_mla,
            ),
            layer_names=layer_names,
        )

    mask_details = config.get_mask_details() if hasattr(config, "get_mask_details") else None
    if not mask_details:
        return [full_attention_group(None)]

    # Bucket layers by mask type, keeping (index, window size, chunk size).
    groups: dict[AttnMaskType, list[tuple[int, int | None, int | None]]] = defaultdict(list)
    for layer_idx, detail in mask_details.items():
        groups[detail.mask_type].append((layer_idx, detail.size, detail.chunks))

    # `mask_details` is non-empty here, so every group below either appends a
    # spec or raises; the result can never be empty.
    specs: list[CacheGroupSpec] = []
    for mask_type, layers in groups.items():
        layer_names = [f"layer.{idx}" for idx, _, _ in layers]

        if mask_type == AttnMaskType.FULL:
            specs.append(full_attention_group(layer_names))
        elif mask_type == AttnMaskType.SLIDING:
            if layers[0][1] is None:
                raise ValueError(f"Sliding window size is required for sliding attention layers: {layer_names}")
            # One spec describes the whole group, so a differing size on any
            # other layer must not be silently dropped.
            sliding_window = _require_uniform([size for _, size, _ in layers], "window size", layer_names)
            specs.append(
                CacheGroupSpec(
                    kv_cache_spec=SlidingWindowSpec(
                        page_size=page_size,
                        num_kv_heads=num_kv_heads,
                        head_size=head_size,
                        dtype=dtype,
                        use_mla=False,  # MLA is not supported for sliding window
                        sliding_window=sliding_window,
                    ),
                    layer_names=layer_names,
                )
            )
        elif mask_type == AttnMaskType.CHUNK:
            if layers[0][2] is None:
                raise ValueError(f"Chunk size is required for chunked attention layers: {layer_names}")
            chunk_size = _require_uniform([chunks for _, _, chunks in layers], "chunk size", layer_names)
            specs.append(
                CacheGroupSpec(
                    kv_cache_spec=ChunkedLocalAttentionSpec(
                        page_size=page_size,
                        num_kv_heads=num_kv_heads,
                        head_size=head_size,
                        dtype=dtype,
                        use_mla=use_mla,
                        attention_chunk_size=chunk_size,
                    ),
                    layer_names=layer_names,
                )
            )
        else:
            raise ValueError(f"Unknown attention mask type: {mask_type}")

    return specs