Source code for easydel.inference.esurge.config

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Literal


[docs]@dataclass class SchedulerConfig: """Configuration for the request scheduler. Controls how requests are scheduled and batched for processing. Attributes: max_num_seqs: Maximum number of sequences running simultaneously. max_num_batched_tokens: Maximum tokens processed in a single batch. max_model_len: Maximum input length the model can handle. policy: Scheduling policy ('fcfs' for first-come-first-served, 'priority' for priority-based). long_prefill_token_threshold: Token count threshold for identifying long prefill requests. chunked_prefill_enabled: Enable chunked processing of long prefill requests. Example: >>> config = SchedulerConfig( ... max_num_seqs=16, ... max_num_batched_tokens=2048, ... max_model_len=8192, ... policy="priority" ... ) """ max_num_seqs: int """The maximum number of sequences running at the same time.""" max_num_batched_tokens: int """The maximum number of tokens to be processed in a single batch.""" max_model_len: int """The maximum length of the model's input.""" policy: Literal["priority", "fcfs"] = "fcfs" """The scheduling policy to use, such as 'priority' or 'fcfs'.""" long_prefill_token_threshold: int = 256 """A token threshold for handling long prefill requests.""" chunked_prefill_enabled: bool = False """A flag to enable or disable chunked prefilling.""" token_safety_margin: int | None = None """Reserved tokens per running request to prevent over-allocation.""" max_num_seq_buckets: tuple[int, ...] 
| None = None """Optional explicit request-capacity buckets (e.g., (8, 16, 32, 64)).""" async_scheduling: bool = False """Enable async token sampling to overlap with next forward pass (30-40% latency reduction).""" def __post_init__(self): """Validate configuration parameters.""" if self.max_num_seqs <= 0: raise ValueError(f"max_num_seqs must be positive, got {self.max_num_seqs}") if self.max_num_batched_tokens <= 0: raise ValueError(f"max_num_batched_tokens must be positive, got {self.max_num_batched_tokens}") if self.max_model_len <= 0: raise ValueError(f"max_model_len must be positive, got {self.max_model_len}") if self.max_num_batched_tokens > self.max_model_len: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) cannot exceed " f"max_model_len ({self.max_model_len})" ) if self.long_prefill_token_threshold < 0: raise ValueError( f"long_prefill_token_threshold must be non-negative, got {self.long_prefill_token_threshold}" ) if self.token_safety_margin is not None and self.token_safety_margin < 0: raise ValueError(f"token_safety_margin must be non-negative, got {self.token_safety_margin}") if self.max_num_seq_buckets is not None: if not self.max_num_seq_buckets: raise ValueError("max_num_seq_buckets cannot be empty") if any(b <= 0 for b in self.max_num_seq_buckets): raise ValueError(f"All bucket sizes must be positive, got {self.max_num_seq_buckets}")
[docs]@dataclass class CacheConfig: """Configuration for the KV (key-value) cache. Manages memory allocation and caching strategies for attention mechanisms. Attributes: num_pages: Number of GPU pages allocated for cache (None for automatic). page_size: Size of each cache page in tokens. enable_prefix_caching: Enable caching of common prefixes across requests. Example: >>> config = CacheConfig( ... num_pages=1000, ... page_size=16, ... enable_prefix_caching=True ... ) Note: Page-based allocation allows efficient memory management and sharing of cache blocks between sequences. """ num_pages: int | None """The number of GPU pages allocated for the cache.""" page_size: int """The size of each cache page.""" enable_prefix_caching: bool """A flag to enable or disable prefix caching.""" def __post_init__(self): """Validate configuration parameters.""" if self.page_size <= 0: raise ValueError(f"page_size must be positive, got {self.page_size}") if self.num_pages is not None and self.num_pages <= 0: raise ValueError(f"num_pages must be positive when specified, got {self.num_pages}")
[docs]@dataclass class SpeculativeConfig: """Configuration for speculative decoding. Attributes: num_speculative_tokens: Number of speculative tokens to generate. speculative_model: Path to the speculative model (e.g., Eagle model). """ num_speculative_tokens: int = 0 speculative_model: str | None = None
[docs] def use_eagle(self) -> bool: """Check if Eagle speculative decoding is enabled.""" return self.num_speculative_tokens > 0 and self.speculative_model is not None
@dataclass
class Config:
    """Top-level configuration bundle for the eSurge engine.

    Groups the scheduler, KV-cache and (optional) speculative-decoding
    settings into a single object.

    Attributes:
        scheduler_config: Request-scheduling settings.
        cache_config: KV-cache management settings.
        speculative_config: Speculative-decoding settings; defaults to ``None``.

    Example:
        >>> config = Config(
        ...     scheduler_config=SchedulerConfig(
        ...         max_num_seqs=16, max_num_batched_tokens=2048, max_model_len=8192
        ...     ),
        ...     cache_config=CacheConfig(num_pages=None, page_size=16, enable_prefix_caching=True),
        ...     speculative_config=SpeculativeConfig(num_speculative_tokens=5),
        ... )
    """

    scheduler_config: SchedulerConfig
    """Scheduler settings (batching limits, scheduling policy, ...)."""

    cache_config: CacheConfig
    """KV-cache settings (paging and prefix caching)."""

    speculative_config: SpeculativeConfig | None = None
    """Speculative-decoding settings, or ``None`` when not provided."""