# Source code for easydel.inference.esurge.scheduler.scheduler

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import itertools
import typing
from collections import defaultdict
from collections.abc import Iterable

from eformer.loggings import get_logger

from ..config import Config
from ..core.interface import CacheGroupsConfig
from ..core.manager import CacheManager
from ..engine_types import EngineCoreOutput, EngineCoreOutputs
from ..metrics import get_metrics_collector
from ..outputs import ModelRunnerOutput
from ..request import EngineRequest, EngineRequestStatus
from .interface import SchedulerInterface
from .output import CachedRequestData, NewRequestData, SchedulerOutput
from .request_queue import SchedulingPolicy, create_request_queue
from .token_budget import TokenBudgetManager
from .utils import check_stop

if typing.TYPE_CHECKING:
    from ..runners.model_runner import eSurgeRunner

logger = get_logger("eSurgeScheduler")


class Scheduler(SchedulerInterface):
    """Token-budgeted request scheduler for the eSurge engine.

    The scheduler owns a waiting queue and a list of running requests. Each
    call to :meth:`schedule` decides how many tokens every request may process
    in the next step, allocates KV-cache pages for them through
    ``CacheManager`` and returns a ``SchedulerOutput``. After the model has
    run, :meth:`update_from_output` folds the sampled tokens back into the
    requests and releases the ones that finished.
    """

    def __init__(
        self,
        config: Config,
        kv_cache_config: CacheGroupsConfig,
        include_finished_set: bool = False,
        max_num_seq_buckets: list[int] | None = None,
    ) -> None:
        """Build the scheduler state from the engine configuration.

        Args:
            config: Engine configuration; ``scheduler_config``, ``cache_config``
                and ``speculative_config`` are read from it.
            kv_cache_config: KV-cache group layout handed to the cache manager.
            include_finished_set: When true, finished request ids are also
                tracked per ``client_index`` and attached to the outputs of
                :meth:`update_from_output`.
            max_num_seq_buckets: Optional override for the running-request
                buckets; falls back to ``scheduler_config.max_num_seq_buckets``.

        Raises:
            ValueError: If ``scheduler_config.policy`` is neither ``"priority"``
                nor ``"fcfs"``.
        """
        self.config = config
        self.scheduler_config = config.scheduler_config
        self.cache_config = config.cache_config
        self.kv_cache_config = kv_cache_config
        self.finished_req_ids_dict: dict[int, set[str]] | None = defaultdict(set) if include_finished_set else None

        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.scheduler_config.max_model_len

        num_pages = self.cache_config.num_pages
        assert num_pages is not None and num_pages > 0
        self.page_size = self.cache_config.page_size

        # A missing safety margin disables the budget manager; the plain
        # ``max_num_scheduled_tokens`` counter is used instead.
        safety_margin = self.scheduler_config.token_safety_margin
        if safety_margin is None:
            self._token_budget_manager = None
        else:
            self._token_budget_manager = TokenBudgetManager(
                max_batch_tokens=self.max_num_scheduled_tokens,
                page_size=self.page_size,
                safety_margin_tokens=safety_margin,
            )

        # request_id -> request, for every request that has not been freed yet.
        self.requests: dict[str, EngineRequest] = {}

        if self.scheduler_config.policy == "priority":
            self.policy = SchedulingPolicy.PRIORITY
        elif self.scheduler_config.policy == "fcfs":
            self.policy = SchedulingPolicy.FCFS
        else:
            raise ValueError(f"Unknown scheduling policy: {self.scheduler_config.policy}")

        self.waiting = create_request_queue(self.policy)
        self.running: list[EngineRequest] = []
        # Ids finished since the previous ``schedule()`` call.
        self.finished_req_ids: set[str] = set()
        self.finished_recving_kv_req_ids: set[str] = set()

        speculative_config = config.speculative_config
        self.use_eagle = False
        self.num_spec_tokens = self.num_lookahead_tokens = 0
        if speculative_config:
            self.num_spec_tokens = speculative_config.num_speculative_tokens
            if speculative_config.use_eagle():
                self.use_eagle = True
                self.num_lookahead_tokens = self.num_spec_tokens

        self.kv_cache_manager = CacheManager(
            num_pages=num_pages,
            kv_cache_groups=kv_cache_config.kv_cache_groups,
            max_model_len=self.max_model_len,
            enable_caching=self.cache_config.enable_prefix_caching,
            use_eagle=self.use_eagle,
        )

        # Normalise the buckets: positive, unique, ascending, never above the
        # running-request cap, and always ending with the cap itself. Buckets
        # larger than the cap are dropped because ``_ensure_capacity`` would
        # otherwise admit more than ``max_num_seqs`` running requests.
        cap = self.max_num_running_reqs
        candidates = max_num_seq_buckets or list(self.scheduler_config.max_num_seq_buckets or ())
        buckets = sorted({int(b) for b in candidates if 0 < int(b) <= cap})
        if not buckets or buckets[-1] != cap:
            buckets.append(cap)
        self.max_num_seq_buckets = buckets
        self._current_seq_bucket = self._select_seq_bucket(0)

    @classmethod
    def from_runner(
        cls,
        runner: eSurgeRunner,
        max_num_batched_tokens: int | None = None,
        enable_prefix_caching: bool = True,
    ) -> Scheduler:
        """Create a Scheduler instance from an eSurgeRunner.

        This method automatically detects the model's attention types
        (full, sliding window, chunked) from the model config and creates
        appropriate cache specifications.

        Args:
            runner: The eSurgeRunner instance.
            max_num_batched_tokens: Maximum tokens per batch. Defaults to max_model_len.
            enable_prefix_caching: Enable prefix caching for faster inference.

        Returns:
            A scheduler of type ``cls`` configured from the runner's metadata.
        """
        from ..config import CacheConfig, SchedulerConfig
        from ..core.interface import create_kv_cache_specs_from_config

        metadata = runner.metadata
        model_config = runner.model.config

        if max_num_batched_tokens is None:
            max_num_batched_tokens = runner.max_model_len

        kv_cache_groups = create_kv_cache_specs_from_config(
            config=model_config,
            page_size=metadata.page_size,
            num_kv_heads=metadata.num_kv_heads,
            head_size=metadata.k_headdim,
            dtype=runner.executor_manager.kv_pages.views[-1].kv_pages.dtype,
            use_mla=False,
        )

        # ``cls`` rather than ``Scheduler`` so subclasses get their own type back.
        return cls(
            config=Config(
                scheduler_config=SchedulerConfig(
                    max_num_seqs=runner.max_num_seqs,
                    max_num_batched_tokens=max_num_batched_tokens,
                    max_model_len=runner.max_model_len,
                    max_num_seq_buckets=tuple(runner.max_num_seq_buckets),
                ),
                cache_config=CacheConfig(
                    num_pages=metadata.num_pages,
                    page_size=metadata.page_size,
                    enable_prefix_caching=enable_prefix_caching,
                ),
            ),
            kv_cache_config=CacheGroupsConfig(num_pages=metadata.num_pages, kv_cache_groups=kv_cache_groups),
        )

    def _select_seq_bucket(self, num_reqs: int) -> int:
        """Return the smallest bucket that holds ``num_reqs`` requests.

        Non-positive counts map to the first bucket; counts above every bucket
        map to the last one.
        """
        if num_reqs <= 0:
            return self.max_num_seq_buckets[0]
        for bucket in self.max_num_seq_buckets:
            if num_reqs <= bucket:
                return bucket
        return self.max_num_seq_buckets[-1]

    def _ensure_capacity(self, desired_running: int) -> bool:
        """Select the bucket for ``desired_running`` and report whether it fits.

        Always updates ``_current_seq_bucket``; returns ``False`` only when the
        desired count exceeds the chosen (i.e. the largest) bucket.
        """
        bucket = self._select_seq_bucket(desired_running)
        self._current_seq_bucket = bucket
        return desired_running <= bucket

    def schedule(self) -> SchedulerOutput:
        """Plan one engine step.

        Running requests are served first; if page allocation fails, requests
        are preempted back to the waiting queue. Waiting requests are admitted
        only when nothing was preempted in this step and token budget remains.

        Returns:
            The ``SchedulerOutput`` describing new and cached requests, the
            number of tokens scheduled per request and the suggested bucket.
        """
        import time

        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the wall clock is adjusted.
        schedule_start_time = time.perf_counter()

        scheduled_new_reqs: list[EngineRequest] = []
        scheduled_resumed_reqs: list[EngineRequest] = []
        scheduled_running_reqs: list[EngineRequest] = []
        preempted_reqs: list[EngineRequest] = []

        req_to_new_page_ids: dict[str, tuple[list[int], ...]] = {}
        num_scheduled_tokens: dict[str, int] = {}

        if self._token_budget_manager:
            token_budget = self._token_budget_manager.begin_cycle(self.kv_cache_manager, len(self.running))
        else:
            token_budget = self.max_num_scheduled_tokens

        scheduled_spec_decode_tokens: dict[str, list[int]] = {}

        # ---- Phase 1: requests that are already running. ----
        req_index = 0
        self._ensure_capacity(len(self.running))
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]

            num_new_tokens = request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
            if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
                num_new_tokens = self.scheduler_config.long_prefill_token_threshold
            num_new_tokens = min(num_new_tokens, token_budget)
            # Never schedule past the last position the model can attend to.
            num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)

            if num_new_tokens == 0:
                req_index += 1
                continue

            preemption_attempts = 0
            max_preemption_attempts = len(self.running) + 1  # Allow one full cycle plus one
            while True:
                new_pages = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens, num_lookahead_tokens=self.num_lookahead_tokens
                )
                if new_pages is None:
                    preemption_attempts += 1
                    if preemption_attempts >= max_preemption_attempts:
                        # Cannot allocate even after preempting all requests
                        logger.warning(
                            f"Cannot allocate {num_new_tokens} tokens for request {request.request_id} "
                            f"after {preemption_attempts} preemption attempts. Skipping."
                        )
                        can_schedule = False
                        break

                    if self.policy == SchedulingPolicy.PRIORITY:
                        preempted_req = max(
                            self.running,
                            key=lambda r: (r.priority, r.arrival_time),
                        )
                        victim_index = self.running.index(preempted_req)
                        del self.running[victim_index]
                        # Removing an entry in front of the cursor shifts the
                        # current request one slot to the left; follow it so
                        # the next running request is not skipped.
                        if victim_index < req_index:
                            req_index -= 1
                        # The victim may already have been scheduled in this
                        # step. Its pages are about to be freed, so it must not
                        # stay in the step's output.
                        if preempted_req in scheduled_running_reqs:
                            scheduled_running_reqs.remove(preempted_req)
                            req_to_new_page_ids.pop(preempted_req.request_id, None)
                            scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
                            released = num_scheduled_tokens.pop(preempted_req.request_id, 0)
                            # Only the plain counter can be refunded here; with
                            # a budget manager the tokens stay consumed, which
                            # is merely conservative.
                            if not self._token_budget_manager:
                                token_budget += released
                    else:
                        # FCFS: the tail of the list has not been scheduled yet
                        # in this step, so no bookkeeping needs to be undone.
                        preempted_req = self.running.pop()

                    self.kv_cache_manager.free(preempted_req)
                    preempted_req.status = EngineRequestStatus.PREEMPTED
                    preempted_req.num_computed_tokens = 0

                    self.waiting.prepend_request(preempted_req)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        # The request evicted itself; nothing left to schedule.
                        can_schedule = False
                        break
                else:
                    can_schedule = True
                    break

            if not can_schedule:
                break
            assert new_pages is not None

            scheduled_running_reqs.append(request)
            req_to_new_page_ids[request.request_id] = new_pages.get_page_ids()
            num_scheduled_tokens[request.request_id] = num_new_tokens

            if self._token_budget_manager:
                self._token_budget_manager.consume(num_new_tokens)
                token_budget = self._token_budget_manager.remaining
            else:
                token_budget -= num_new_tokens

            # NOTE(review): leaving the loop here skips the speculative-token
            # bookkeeping below for the request that used up the budget.
            # Kept as in the original; confirm this is intended.
            if token_budget <= 0:
                req_index += 1
                break

            req_index += 1

            if request.spec_token_ids:
                num_scheduled_spec_tokens = num_new_tokens + request.num_computed_tokens - request.num_tokens
                if num_scheduled_spec_tokens > 0:
                    # Trim the draft tokens to what was actually scheduled.
                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids

        # ---- Phase 2: admit waiting requests (only if nothing was preempted). ----
        skipped_waiting_requests = create_request_queue(self.policy)

        if not preempted_reqs:
            while self.waiting and token_budget > 0:
                if not self._ensure_capacity(len(self.running) + 1):
                    break

                request = self.waiting.peek_request()

                if request.status == EngineRequestStatus.WAITING_FOR_REMOTE_KVS:
                    request.status = EngineRequestStatus.WAITING

                if request.status == EngineRequestStatus.WAITING_FOR_FSM:
                    structured_output_req = request.structured_output_request
                    if structured_output_req and structured_output_req.grammar:
                        request.status = EngineRequestStatus.WAITING
                    else:
                        # Grammar not ready yet: set aside and try the next one.
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                num_external_computed_tokens = 0
                load_kv_async = False

                if request.num_computed_tokens == 0:
                    new_computed_pages, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_pages(request)
                    num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
                else:
                    new_computed_pages = self.kv_cache_manager.create_empty_page_list()
                    num_new_local_computed_tokens = 0
                    num_computed_tokens = request.num_computed_tokens

                if load_kv_async:
                    assert num_external_computed_tokens > 0
                    num_new_tokens = 0
                else:
                    num_new_tokens = request.num_tokens - num_computed_tokens
                    if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
                        num_new_tokens = self.scheduler_config.long_prefill_token_threshold

                    if not self.scheduler_config.chunked_prefill_enabled and num_new_tokens > token_budget:
                        # If the request is larger than the max batch size, we MUST allow it if the batch is empty,
                        # otherwise it will never run.
                        # If it's larger than available memory (reflected in token_budget via capacity),
                        # allocate_slots will fail anyway.
                        is_inherently_too_large = num_new_tokens > self.max_num_scheduled_tokens
                        is_batch_empty = (
                            len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) == 0
                        )
                        if is_inherently_too_large and is_batch_empty:
                            # Allow it to proceed to allocation
                            pass
                        else:
                            self.waiting.pop_request()
                            skipped_waiting_requests.prepend_request(request)
                            continue

                    num_new_tokens = min(num_new_tokens, token_budget)
                    assert num_new_tokens > 0

                new_pages = self.kv_cache_manager.allocate_slots(
                    request,
                    num_new_tokens + num_external_computed_tokens,
                    num_new_local_computed_tokens,
                    new_computed_pages,
                    num_lookahead_tokens=self.num_lookahead_tokens,
                    delay_cache_pages=load_kv_async,
                )
                if new_pages is None:
                    # Out of pages: stop admitting, the request stays queued.
                    break

                request = self.waiting.pop_request()
                if load_kv_async:
                    skipped_waiting_requests.prepend_request(request)
                    request.status = EngineRequestStatus.WAITING_FOR_REMOTE_KVS
                    continue

                req_index += 1
                self.running.append(request)

                if request.status == EngineRequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                elif request.status == EngineRequestStatus.PREEMPTED:
                    scheduled_resumed_reqs.append(request)
                else:
                    raise RuntimeError(f"Invalid request status: {request.status}")

                req_to_new_page_ids[request.request_id] = self.kv_cache_manager.get_page_ids(request.request_id)

                if self._token_budget_manager:
                    self._token_budget_manager.consume(num_new_tokens)
                    token_budget = self._token_budget_manager.remaining
                else:
                    token_budget -= num_new_tokens

                num_scheduled_tokens[request.request_id] = num_new_tokens
                request.status = EngineRequestStatus.RUNNING
                request.num_computed_tokens = num_computed_tokens
                # A negative value marks "not recorded yet"; set it only once.
                if request.num_cached_tokens < 0:
                    request.num_cached_tokens = num_computed_tokens

                if token_budget <= 0:
                    break

        # Requests set aside above go back to the head of the waiting queue.
        if skipped_waiting_requests:
            self.waiting.prepend_requests(skipped_waiting_requests)

        self._ensure_capacity(len(self.running))

        # ---- Sanity checks on the step that was just planned. ----
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self._current_seq_bucket
        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)

        num_common_prefix_pages = [0] * len(self.kv_cache_config.kv_cache_groups)
        scheduled_req_count = len(num_scheduled_tokens)
        if scheduled_req_count > 0:
            if scheduled_running_reqs:
                representative_req = scheduled_running_reqs[0]
            elif scheduled_resumed_reqs:
                representative_req = scheduled_resumed_reqs[0]
            elif scheduled_new_reqs:
                representative_req = scheduled_new_reqs[0]
            else:
                representative_req = None

            if representative_req is not None:
                num_common_prefix_pages = self.kv_cache_manager.get_num_common_prefix_pages(
                    representative_req, scheduled_req_count
                )

        new_reqs_data = [
            NewRequestData.from_request(req, req_to_new_page_ids[req.request_id]) for req in scheduled_new_reqs
        ]
        cached_reqs_data = self._make_cached_request_data(
            scheduled_running_reqs,
            scheduled_resumed_reqs,
            num_scheduled_tokens,
            scheduled_spec_decode_tokens,
            req_to_new_page_ids,
        )

        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            num_common_prefix_pages=num_common_prefix_pages,
            finished_req_ids=self.finished_req_ids,
            suggested_bucket=self._current_seq_bucket,  # Hint for runner's buffer selection
            async_scheduling=self.scheduler_config.async_scheduling,  # Pass async config to runner
        )

        self._update_after_schedule(scheduler_output)

        # Log scheduler metrics
        schedule_time = time.perf_counter() - schedule_start_time
        metrics_collector = get_metrics_collector()
        if metrics_collector:
            metrics_collector.record_scheduler_metrics(
                num_waiting=len(self.waiting),
                num_running=len(self.running),
                num_scheduled_tokens=scheduler_output.total_num_scheduled_tokens,
                num_preempted=len(preempted_reqs),
                batch_size=len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs),
                schedule_time=schedule_time,
            )

            # Log cache metrics
            cache_manager = self.kv_cache_manager
            total_pages = cache_manager.num_pages
            used_pages = total_pages - cache_manager.page_pool.get_num_free_pages()
            num_cached_pages = len(cache_manager.page_pool.cached_page_hash_to_page)
            # NOTE(review): this is the share of pages that carry a cached
            # hash, reported under the name "cache_hit_rate".
            cache_hit_rate = num_cached_pages / max(total_pages, 1)
            metrics_collector.record_cache_metrics(
                total_pages=total_pages,
                used_pages=used_pages,
                cache_hit_rate=cache_hit_rate,
            )

        return scheduler_output

    def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
        """Advance ``num_computed_tokens`` by what was just scheduled.

        Also replaces ``finished_req_ids`` with a fresh set, so the set that
        was handed to ``scheduler_output`` is not mutated afterwards.
        """
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
        for req_id, num_scheduled_token in num_scheduled_tokens.items():
            request = self.requests[req_id]
            request.num_computed_tokens += num_scheduled_token
        self.finished_req_ids = set()

    def _make_cached_request_data(
        self,
        running_reqs: list[EngineRequest],
        resumed_reqs: list[EngineRequest],
        num_scheduled_tokens: dict[str, int],
        spec_decode_tokens: dict[str, list[int]],
        req_to_new_page_ids: dict[str, tuple[list[int], ...]],
    ) -> CachedRequestData:
        """Assemble the per-step payload for requests the runner already knows.

        Running requests come first, followed by requests resumed from
        preemption; ``resumed_from_preemption`` flags them in that same order.
        Speculative tokens are excluded from ``new_token_ids``.
        """
        req_ids: list[str] = []
        new_token_ids: list[list[int]] = []
        new_page_ids: list[tuple[list[int], ...]] = []
        num_computed_tokens: list[int] = []

        for req in itertools.chain(running_reqs, resumed_reqs):
            req_id = req.request_id
            req_ids.append(req_id)
            num_tokens = num_scheduled_tokens[req_id] - len(spec_decode_tokens.get(req_id, ()))
            token_ids = req.all_token_ids[req.num_computed_tokens : req.num_computed_tokens + num_tokens]
            new_token_ids.append(token_ids)
            new_page_ids.append(req_to_new_page_ids[req_id])
            num_computed_tokens.append(req.num_computed_tokens)

        resumed_from_preemption = [False] * len(running_reqs)
        resumed_from_preemption += [True] * len(resumed_reqs)

        return CachedRequestData(
            req_ids=req_ids,
            resumed_from_preemption=resumed_from_preemption,
            new_token_ids=new_token_ids,
            new_page_ids=new_page_ids,
            num_computed_tokens=num_computed_tokens,
        )

    def update_from_output(
        self,
        scheduler_output: SchedulerOutput,
        model_runner_output: ModelRunnerOutput,
    ) -> dict[int, EngineCoreOutputs]:
        """Apply the model runner's results to the scheduled requests.

        Sampled tokens are appended to each request, stop conditions are
        checked, and stopped requests are freed and dropped from the running
        list or the waiting queue.

        Args:
            scheduler_output: The output of the matching :meth:`schedule` call.
            model_runner_output: Sampled tokens and auxiliary data from the runner.

        Returns:
            Engine outputs grouped by ``client_index``. When the finished set
            is tracked, finished request ids are attached per client as well.
        """
        sampled_token_ids = model_runner_output.sampled_token_ids
        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
        num_nans_in_logits = model_runner_output.num_nans_in_logits

        outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
        stopped_running_reqs: set[EngineRequest] = set()
        stopped_preempted_reqs: set[EngineRequest] = set()

        for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
            assert num_tokens_scheduled > 0
            request = self.requests.get(req_id)
            if request is None:
                # Already freed (e.g. aborted) while the step was in flight.
                continue

            req_index = model_runner_output.req_id_to_index[req_id]
            generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []

            stopped = False
            new_token_ids = generated_token_ids
            # Captured before the stop check so the request can be removed
            # from the container it lived in.
            status_before_stop = request.status

            if new_token_ids:
                new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)

            if stopped:
                self._free_request(request)
                if status_before_stop == EngineRequestStatus.RUNNING:
                    stopped_running_reqs.add(request)
                else:
                    stopped_preempted_reqs.add(request)

            if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                request.num_nans_in_logits = num_nans_in_logits[req_id]

            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if new_token_ids:
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        num_cached_tokens=request.num_cached_tokens,
                    )
                )

            # NOTE(review): indentation was lost in the scraped source and no
            # ``else`` is present, so this check is applied to every request;
            # confirm against the upstream file.
            assert not prompt_logprobs_tensors

        if stopped_running_reqs:
            self.running = [req for req in self.running if req not in stopped_running_reqs]
        if stopped_preempted_reqs:
            self.waiting.remove_requests(stopped_preempted_reqs)

        engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()}

        finished_req_ids = self.finished_req_ids_dict
        if finished_req_ids:
            for client_index, finished_set in finished_req_ids.items():
                if (eco := engine_core_outputs.get(client_index)) is not None:
                    eco.finished_requests = finished_set
                else:
                    engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
            finished_req_ids.clear()

        return engine_core_outputs

    def _update_request_with_output(
        self,
        request: EngineRequest,
        new_token_ids: list[int],
    ) -> tuple[list[int], bool]:
        """Append sampled tokens to ``request`` until a stop condition fires.

        ``new_token_ids`` is truncated in place right after the token that
        triggered the stop.

        Returns:
            The (possibly truncated) token list and whether the request stopped.
        """
        stopped = False
        for num_new, output_token_id in enumerate(new_token_ids, 1):
            request.append_output_token_ids(output_token_id)
            stopped = check_stop(request, self.max_model_len)
            if stopped:
                del new_token_ids[num_new:]
                break
        return new_token_ids, stopped

    def get_request_counts(self) -> tuple[int, int]:
        """Returns (num_running_reqs, num_waiting_reqs)."""
        return len(self.running), len(self.waiting)

    def add_request(self, request: EngineRequest) -> None:
        """Queue ``request`` for scheduling and register it by its id."""
        self.waiting.add_request(request)
        self.requests[request.request_id] = request

    def finish_requests(
        self,
        request_ids: str | Iterable[str],
        finished_status: EngineRequestStatus,
    ) -> None:
        """Handles the finish signal from outside the scheduler.

        For example, the API server can abort a request when the client
        disconnects.

        Args:
            request_ids: One request id or an iterable of ids. Unknown ids are
                ignored.
            finished_status: The finished status to assign to each request.
        """
        assert EngineRequestStatus.is_finished(finished_status)
        if isinstance(request_ids, str):
            request_ids = (request_ids,)
        else:
            # De-duplicate so a request is never freed twice.
            request_ids = set(request_ids)

        running_requests_to_remove = []
        waiting_requests_to_remove = []
        valid_requests = []

        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None:
                continue

            valid_requests.append(request)
            if request.status == EngineRequestStatus.RUNNING:
                running_requests_to_remove.append(request)
            else:
                waiting_requests_to_remove.append(request)

        for request in running_requests_to_remove:
            self.running.remove(request)
        if waiting_requests_to_remove:
            self.waiting.remove_requests(waiting_requests_to_remove)

        # The status is set only after removal: the loop above relies on the
        # old status to pick the right container.
        for request in valid_requests:
            request.status = finished_status
            self._free_request(request)

    def _free_request(self, request: EngineRequest):
        """Record a finished request's id and release its resources."""
        assert request.is_finished()

        request_id = request.request_id
        self.finished_req_ids.add(request_id)
        if self.finished_req_ids_dict is not None:
            self.finished_req_ids_dict[request.client_index].add(request_id)

        self._free_pages(request)

    def _free_pages(self, request: EngineRequest):
        """Free the request's KV pages and page hashes, then forget the request."""
        assert request.is_finished()
        self.kv_cache_manager.free(request)
        self.kv_cache_manager.free_page_hashes(request)
        del self.requests[request.request_id]

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests that are waiting or running."""
        return len(self.waiting) + len(self.running)

    def has_finished_requests(self) -> bool:
        """Return whether any request finished since the last ``schedule()`` call."""
        return len(self.finished_req_ids) > 0

    def reset_prefix_cache(self) -> bool:
        """Delegate to the cache manager's ``reset_prefix_cache`` and return its result."""
        return self.kv_cache_manager.reset_prefix_cache()

    def shutdown(self) -> None:
        """Shut the scheduler down; nothing to release in this implementation."""
        ...