# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Request management for the eSurge engine.

Defines the core request structures and status tracking for managing
inference requests throughout their lifecycle.

Classes:
    EngineRequest: Main request object for tracking generation.
    EngineRequestStatus: Enum of request statuses.

Example:
    >>> request = EngineRequest(
    ...     request_id="req_123",
    ...     prompt_token_ids=[1, 2, 3],
    ...     sampling_params=params,
    ...     eos_token_id=2,
    ... )
    >>> request.status = EngineRequestStatus.RUNNING
"""
import enum
import time
from typing import Any
from ..sampling_params import SamplingParams
from .engine_types import EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason
from .utils import ConstantList
class EngineRequest:
    """Request object for tracking generation through the engine.

    Holds the state and metadata of a single inference request: its prompt
    and generated tokens, sampling parameters, scheduling metadata and
    execution status.

    Attributes:
        request_id: Unique identifier for the request.
        prompt_token_ids: Input token IDs (the list object passed by the caller).
        sampling_params: Parameters controlling generation, or None.
        eos_token_id: End-of-sequence token ID, or None.
        client_index: Index of the client making the request.
        arrival_time: Timestamp when the request arrived.
        priority: Request priority for scheduling.
        parent_request_id: ID of parent request for n>1 sampling (None for n=1).
        sample_index: Index of this sample (0 to n-1) for n>1 sampling.
        status: Current request status; starts as ``EngineRequestStatus.WAITING``.
        events: Events recorded during processing and not yet taken.
        stop_reason: Reason for stopping generation; starts as None.
        kv_transfer_params: Value of ``extra_args["kv_transfer_params"]`` from
            the sampling params when present, otherwise None.
        max_tokens: Copied from ``sampling_params.max_tokens``. Only assigned
            when ``sampling_params`` is not None.
        num_prompt_tokens: Length of ``prompt_token_ids`` at construction.
        output_token_ids: ``ConstantList`` wrapper around the generated tokens.
        all_token_ids: ``ConstantList`` wrapper around prompt + generated tokens.
        spec_token_ids: Extra token IDs counted by ``num_tokens_with_spec``.
        num_computed_tokens: Counter initialised to 0.
        num_output_placeholders: Counter initialised to 0.
        num_cached_tokens: Initialised to -1.
        num_nans_in_logits: Counter initialised to 0; any positive value makes
            ``is_output_corrupted`` true.

    Example:
        >>> request = EngineRequest(
        ...     request_id="req_123",
        ...     prompt_token_ids=[1, 2, 3],
        ...     sampling_params=sampling_params,
        ...     eos_token_id=2,
        ... )
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int],
        sampling_params: SamplingParams | None,
        eos_token_id: int | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        priority: int = 0,
        parent_request_id: str | None = None,
        sample_index: int = 0,
    ) -> None:
        """Initialize EngineRequest.

        Args:
            request_id: Unique request identifier.
            prompt_token_ids: Input token IDs.
            sampling_params: Generation parameters. When given, its
                ``max_tokens`` must already be set.
            eos_token_id: End-of-sequence token.
            client_index: Client index.
            arrival_time: Request arrival time; defaults to ``time.time()``
                at construction when omitted.
            priority: Request priority.
            parent_request_id: Parent request ID for n>1 sampling.
            sample_index: Sample index (0 to n-1) for n>1 sampling.

        Raises:
            AssertionError: If ``sampling_params`` is given but its
                ``max_tokens`` is None.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.parent_request_id = parent_request_id
        self.sample_index = sample_index
        self.sampling_params = sampling_params
        self.eos_token_id = eos_token_id
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = EngineRequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None
        self.kv_transfer_params: dict[str, Any] | None = None

        # NOTE(review): when sampling_params is None, `max_tokens` is never
        # assigned, so any later read of `self.max_tokens` raises
        # AttributeError. Confirm whether callers can actually pass None.
        if sampling_params is not None:
            # Internal invariant: max_tokens is expected to be resolved upstream.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get("kv_transfer_params")

        self.prompt_token_ids = prompt_token_ids
        self.num_prompt_tokens = len(self.prompt_token_ids)

        # Private mutable token buffers. `_all_token_ids` starts as a copy so
        # that appending generated tokens never mutates the caller's prompt list.
        self._output_token_ids: list[int] = []
        self._all_token_ids: list[int] = self.prompt_token_ids.copy()
        self.num_output_placeholders = 0
        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0

        # Public wrappers around the private buffers (ConstantList comes from
        # .utils; presumably a read-only view over the same list - confirm).
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # NOTE(review): -1 looks like a "not yet determined" sentinel - confirm.
        self.num_cached_tokens = -1
        self.num_nans_in_logits = 0

    @classmethod
    def from_engine_core_request(cls, request: EngineCoreRequest) -> "EngineRequest":
        """Build an EngineRequest from an EngineCoreRequest.

        Copies ``request_id``, ``client_index``, ``prompt_token_ids``,
        ``sampling_params``, ``eos_token_id``, ``arrival_time`` and
        ``priority``. ``parent_request_id`` and ``sample_index`` are not
        forwarded, so they keep their defaults (None and 0).

        Args:
            request: Core request to convert.

        Returns:
            A new instance of ``cls``.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            sampling_params=request.sampling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            priority=request.priority,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append one or more generated tokens to the request.

        Both the output-only buffer and the prompt+output buffer are updated
        so they stay in sync.

        Args:
            token_ids: A single token ID, or a list of token IDs appended in order.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

    @property
    def is_output_corrupted(self) -> bool:
        """Whether at least one NaN has been counted in this request's logits."""
        return self.num_nans_in_logits > 0

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus generated tokens."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """Number of prompt plus generated tokens, plus ``spec_token_ids``."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of generated tokens."""
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        """Return whether the current status is one of the finished states."""
        return EngineRequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Return the finish reason mapped to the current status, or None if there is none."""
        return EngineRequestStatus.get_finished_reason(self.status)

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Append a new event of ``event_type`` to ``events``.

        Args:
            event_type: Kind of event to record.
            timestamp: Passed through to ``EngineCoreEvent.new_event``
                (handling of None is decided there).
        """
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return the recorded events and reset the event list.

        Returns:
            The list of events recorded since the last call, or None when no
            events are pending. After a non-empty take, ``events`` is a fresh
            empty list.
        """
        if not self.events:
            return None
        events, self.events = self.events, []
        return events
class EngineRequestStatus(enum.IntEnum):
    """Status of a request.

    Declaration order is significant: values are assigned by ``enum.auto()``
    in increasing order, and ``is_finished`` treats every member declared
    after ``PREEMPTED`` as a finished state. New non-finished states must be
    added before ``PREEMPTED`` (or at it), new finished states after it.
    """

    # Non-finished states.
    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()

    # Finished states: everything greater than PREEMPTED.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    def __str__(self) -> str:
        """Return the member name (e.g. ``"RUNNING"``) rather than its integer value."""
        return self.name

    @staticmethod
    def is_finished(status: "EngineRequestStatus") -> bool:
        """Return True when ``status`` is ordered after ``PREEMPTED``.

        Args:
            status: Status to classify.

        Returns:
            True for the ``FINISHED_*`` members, False otherwise.
        """
        return status > EngineRequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "EngineRequestStatus") -> FinishReason | None:
        """Look up the finish reason associated with ``status``.

        Args:
            status: Status to translate.

        Returns:
            The mapped ``FinishReason``, or None when ``status`` has no entry
            in the module-level ``_FINISHED_REASON_MAP``.
        """
        return _FINISHED_REASON_MAP.get(status)
# Translation table from finished request statuses to the FinishReason
# returned by EngineRequestStatus.get_finished_reason. Statuses without an
# entry here (the non-finished ones) yield None from that lookup.
_FINISHED_REASON_MAP: dict[EngineRequestStatus, FinishReason] = {
    EngineRequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    EngineRequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    EngineRequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    # NOTE(review): IGNORED deliberately shares LENGTH with LENGTH_CAPPED -
    # presumably because ignored requests are over-length prompts; confirm.
    EngineRequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}