# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import dataclasses
import queue
import time
import typing as tp
import jax
from easydel.inference.utilities import SamplingParams
from easydel.layers.caching.paged_attention import (
HBMPageManager,
PagedAttentionCache,
)
from easydel.utils.helpers import get_logger
from .engines import (
oDriver,
oEngine,
vDriver,
vEngine,
)
from .utils import (
ActiveRequest,
ActiveRequestMetadata,
AsyncMultifuture,
ReturnSample,
is_byte_token,
text_tokens_to_string,
)
if tp.TYPE_CHECKING:
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.utils import ProcessingClassType
else:
ProcessingClassType = tp.Any
EasyDeLBaseModule = tp.Any
logger = get_logger("vSurge")
@dataclasses.dataclass
class vSurgeMetadata:
    """Timing information attached to a :class:`vSurgeRequest`.

    NOTE(review): ``vSurgeRequest.__post_init__`` instantiates this name, but
    the visible module neither defined nor imported it, so building a request
    without explicit metadata raised ``NameError``. This minimal definition
    carries the only attribute the module reads (``start_time``). If a
    canonical definition lives elsewhere in the package, import that instead.
    """

    # Wall-clock creation time (``time.time()``); ``vSurge.complete`` forwards
    # it as the request's start time.
    start_time: float = dataclasses.field(default_factory=time.time)


@dataclasses.dataclass
class vSurgeRequest:
    """Represents a request specifically for text completion.

    Attributes hold the prompt and the sampling configuration that
    ``vSurge.complete`` copies onto the ``ActiveRequest`` it enqueues.
    """

    prompt: str
    max_tokens: int
    top_p: float = 1.0
    top_k: int = 0
    min_p: float = 0.0
    temperature: float = 0.7
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    metadata: vSurgeMetadata | None = None
    is_client_side_tokenization: bool = False

    @classmethod
    def from_sampling_params(
        cls,
        prompt: str,
        sampling_params: SamplingParams,
    ) -> vSurgeRequest:
        """Build a request from a prompt and a ``SamplingParams``-like object.

        Args:
            prompt: The prompt text.
            sampling_params: Object exposing ``max_tokens``, ``top_p``,
                ``top_k``, ``min_p``, ``temperature``, ``presence_penalty``,
                ``frequency_penalty`` and ``repetition_penalty``.

        Returns:
            A new instance of ``cls`` (so subclasses get their own type).
        """
        return cls(
            prompt=prompt,
            max_tokens=sampling_params.max_tokens,
            top_p=sampling_params.top_p,
            top_k=sampling_params.top_k,
            min_p=sampling_params.min_p,
            temperature=sampling_params.temperature,
            presence_penalty=sampling_params.presence_penalty,
            frequency_penalty=sampling_params.frequency_penalty,
            repetition_penalty=sampling_params.repetition_penalty,
        )

    def __post_init__(self):
        """Ensures metadata is initialized and validates the prompt type."""
        if self.metadata is None:
            self.metadata = vSurgeMetadata()
        # NOTE(review): this unconditionally overrides whatever the caller
        # passed for ``is_client_side_tokenization`` -- kept as in the original.
        self.is_client_side_tokenization = False
        assert isinstance(self.prompt, str), "prompt should be a single string"
[docs]class vSurge:
"""Orchestrates the interaction between client requests and the vDriver."""
def __init__(
self,
driver: tp.Union[vDriver, oDriver],
vsurge_name: str | None = None,
):
"""Initializes the vSurge.
Args:
driver: The vDriver instance that manages the underlying inference
engines and processing threads.
"""
self._driver = driver
self._vsurge_name = vsurge_name or driver.driver_name
[docs] def compile(self):
self.driver.compile()
@property
def vsurge_name(self):
return self._vsurge_name
@property
def driver(self):
"""Provides access to the underlying vDriver instance."""
return self._driver
@property
def processor(self) -> ProcessingClassType:
"""Returns the processor/tokenizer associated with the underlying driver."""
return self.driver.processor
[docs] def start(self):
return self.driver.start()
[docs] def stop(self):
return self.driver.stop()
@classmethod
def create_odriver(
    cls,
    model: EasyDeLBaseModule,
    processor: ProcessingClassType,
    storage: tp.Optional[PagedAttentionCache] = None,
    manager: tp.Optional[HBMPageManager] = None,
    page_size: int = 128,
    hbm_utilization: float = 0.6,
    max_concurrent_prefill: int | None = None,
    max_concurrent_decodes: int | None = None,
    prefill_lengths: int | None = None,
    max_prefill_length: int | None = None,
    max_length: int | None = None,
    seed: int = 894,
    vsurge_name: str | None = None,
) -> vSurge:
    """Creates a vSurge backed by an ``oDriver`` wrapping a single ``oEngine``.

    Paged-attention metadata is built from the model; page storage and the
    page manager are created from that metadata unless supplied.

    Args:
        model: The model; must provide ``create_paged_metadata``,
            ``init_pages`` and ``dtype``.
        processor: The tokenizer/processor instance.
        storage: Pre-built page storage. When None, it is created with
            ``model.init_pages(metadata=...)``.
        manager: Pre-built page manager. When None, an ``HBMPageManager``
            is created from the metadata.
        page_size: Forwarded to ``model.create_paged_metadata``.
        hbm_utilization: Forwarded to ``model.create_paged_metadata``.
        max_concurrent_prefill: Defaults to ``jax.device_count()`` when
            None or 0.
        max_concurrent_decodes: Defaults to ``jax.device_count()`` when
            None or 0; also used as the metadata and engine batch size.
        prefill_lengths: Forwarded to the engine unchanged.
        max_prefill_length: Forwarded to the engine unchanged.
        max_length: Maximum sequence length; defaults to 8192 when None
            or 0.
        seed: Random seed forwarded to the engine.
        vsurge_name: Optional name for the resulting vSurge.

    Returns:
        A new instance of ``cls``.
    """
    max_length = max_length or 8192
    max_concurrent_prefill = max_concurrent_prefill or jax.device_count()
    max_concurrent_decodes = max_concurrent_decodes or jax.device_count()
    metadata = model.create_paged_metadata(
        page_size=page_size,
        batch_size=max_concurrent_decodes,
        max_sequences=max_length,
        dtype=model.dtype,
        hbm_utilization=hbm_utilization,
    )
    if storage is None:
        storage = model.init_pages(metadata=metadata)
    if manager is None:
        manager = HBMPageManager(metadata=metadata)
    engine = oEngine(
        model=model,
        processor=processor,
        storage=storage,
        manager=manager,
        max_concurrent_decodes=max_concurrent_decodes,
        max_concurrent_prefill=max_concurrent_prefill,
        prefill_lengths=prefill_lengths,
        max_prefill_length=max_prefill_length,
        max_length=max_length,
        batch_size=max_concurrent_decodes,
        seed=seed,
    )
    # Use ``cls`` rather than a hard-coded ``vSurge`` so subclasses calling
    # this factory receive an instance of their own type.
    return cls(driver=oDriver(engine), vsurge_name=vsurge_name)
@classmethod
def create_vdriver(
    cls,
    model: EasyDeLBaseModule,
    processor: ProcessingClassType,
    max_concurrent_decodes: int | None = None,
    prefill_lengths: int | None = None,
    max_prefill_length: int | None = None,
    max_length: int | None = None,
    seed: int = 894,
    vsurge_name: str | None = None,
) -> vSurge:
    """Creates a vSurge backed by a ``vDriver`` with two ``vEngine`` instances.

    One engine is configured for prefill (``max_concurrent_prefill=1``) and
    one for decode; both share the model, processor, ``max_length`` and
    ``seed``.

    Args:
        model: The EasyDeLBaseModule instance representing the model.
        processor: The tokenizer/processor instance.
        max_concurrent_decodes: Maximum number of concurrent decode requests,
            forwarded to the decode engine.
        prefill_lengths: Prefill lengths forwarded to the prefill engine.
            NOTE(review): annotated ``int | None`` but previously documented
            as a list -- confirm what ``vEngine`` expects.
        max_prefill_length: Maximum prefill length, forwarded to the prefill
            engine.
        max_length: Maximum total sequence length, forwarded to both engines.
        seed: The random seed, forwarded to both engines.
        vsurge_name: An optional name for the vsurge.

    Returns:
        A new instance of ``cls``.
    """
    prefill_engine = vEngine(
        model=model,
        processor=processor,
        max_concurrent_prefill=1,
        prefill_lengths=prefill_lengths,
        max_prefill_length=max_prefill_length,
        max_length=max_length,
        seed=seed,
    )
    decode_engine = vEngine(
        model=model,
        processor=processor,
        max_concurrent_decodes=max_concurrent_decodes,
        max_length=max_length,
        seed=seed,
    )
    # Use ``cls`` rather than a hard-coded ``vSurge`` so subclasses calling
    # this factory receive an instance of their own type.
    return cls(
        driver=vDriver(
            prefill_engines=prefill_engine,
            decode_engines=decode_engine,
        ),
        vsurge_name=vsurge_name,
    )
[docs] def count_tokens(self, text_or_conversation: tp.Union[str, list]) -> int:
"""Counts the number of tokens in a given string or conversation list.
Uses the underlying driver's processor to tokenize the input and returns
the count of tokens.
Args:
text_or_conversation: Either a single string or a list of
message dictionaries (like OpenAI chat format).
Returns:
The total number of tokens in the input.
Raises:
ValueError: If the input type is invalid or tokenization fails.
"""
try:
if isinstance(text_or_conversation, str):
# Tokenize a single string
return len(self.processor(text=text_or_conversation)["input_ids"])
elif isinstance(text_or_conversation, list):
if hasattr(self.processor, "apply_chat_template"):
tokenized = self.processor.apply_chat_template(
conversation=text_or_conversation,
tokenize=True,
add_generation_prompt=False,
)
return len(
tokenized["input_ids"] if isinstance(tokenized, dict) else tokenized
)
else:
full_text = " ".join(
[
msg.get("content", "")
for msg in text_or_conversation
if isinstance(msg.get("content"), str)
]
)
return len(self.processor(text=full_text)["input_ids"])
else:
raise ValueError(
f"Unsupported input type for token counting: {type(text_or_conversation)}"
)
except Exception as e:
logger.error(f"Error during token counting: {e}")
raise ValueError(f"Failed to count tokens: {e}") from e
[docs] def process_client_side_tokenization_response(self, response: list[ReturnSample]):
"""Processes responses when tokenization is handled client-side.
In this case, the response items (ReturnSample) are typically yielded
directly without further server-side processing like detokenization
or buffering.
Args:
response: A list of ReturnSample objects from a single generation step.
Returns:
The input list of ReturnSample objects, unchanged.
"""
samples = []
for sample in response:
samples.append(sample)
return samples
[docs] def should_buffer_response(self, response: list[ReturnSample]) -> bool:
"""Determines if a response needs buffering for server-side detokenization.
Buffering is needed if any sample in the response ends with a byte token
(e.g., "<0xAB>"), as this indicates an incomplete multi-byte character
that requires subsequent tokens for proper decoding.
Args:
response: A list of ReturnSample objects from a single generation step.
Returns:
True if buffering is required, False otherwise.
"""
for item in response:
# Check if the text list is not empty and the last item is a byte token
if item.text and is_byte_token(item.text[-1]):
return True
return False # Return False if no byte tokens found at the end
def process_server_side_tokenization_response(
    self,
    response: list[ReturnSample],
    buffered_response_list: list[list[ReturnSample]],
) -> list[ReturnSample]:
    """Processes responses when tokenization/detokenization is server-side.

    For every sample position, the text pieces and token ids of all buffered
    steps followed by the current step are concatenated, and the text pieces
    are decoded with ``text_tokens_to_string``. Throughput and generated-token
    count are taken from the *latest* step (the current response).

    Args:
        response: The list of ReturnSample objects from the current step.
        buffered_response_list: A list containing lists of ReturnSample objects
            from previous steps that were buffered.

    Returns:
        One merged ReturnSample per sample position, holding the decoded
        string, all token ids, and the latest tokens-per-second and
        generated-token count.
    """
    merged = []
    # zip(*steps) regroups "list of steps" into "one history per sample".
    for history in zip(*buffered_response_list, response):
        newest = history[-1]
        pieces = [piece for step in history for piece in step.text]
        ids = [token_id for step in history for token_id in step.token_ids]
        merged.append(
            ReturnSample(
                text=text_tokens_to_string(pieces),
                token_ids=ids,
                tokens_per_second=newest.tokens_per_second,
                num_generated_tokens=newest.num_generated_tokens,
            )
        )
    return merged
async def complete(
    self,
    request: vSurgeRequest,
) -> tp.AsyncGenerator[tp.List[ReturnSample]]:
    """Initiates and streams the results of a text completion request.

    Wraps the request in an ``ActiveRequest``, hands it to the driver's
    prefill queue, and then asynchronously iterates over the request's
    ``return_channel``.

    In client-side tokenization mode every step is yielded as-is. Otherwise
    steps ending in a byte token are held back and merged into the next
    step that does not; anything still held back when the channel ends is
    flushed as one final step.

    Args:
        request: The vSurgeRequest containing the prompt and
            generation parameters.

    Yields:
        Lists of ReturnSample, one list per emitted generation step.

    Raises:
        RuntimeError: If the prefill queue is full when trying to place the
            request.
    """
    timing = ActiveRequestMetadata(
        start_time=request.metadata.start_time if request.metadata else time.time(),
        prefill_enqueue_time=time.perf_counter(),
    )
    active_request = ActiveRequest(
        max_tokens=request.max_tokens,
        prefill_content=request.prompt,
        is_client_side_tokenization=request.is_client_side_tokenization,
        return_channel=AsyncMultifuture(),
        top_p=request.top_p,
        top_k=request.top_k,
        min_p=request.min_p,
        temperature=request.temperature,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        repetition_penalty=request.repetition_penalty,
        metadata=timing,
    )
    try:
        self._driver.place_request_on_prefill_queue(active_request)
    except queue.Full as e:
        raise RuntimeError("Prefill queue is full") from e

    held_back = []
    async for step in active_request.return_channel:
        if request.is_client_side_tokenization:
            yield self.process_client_side_tokenization_response(step)
            continue
        if self.should_buffer_response(step):
            # Incomplete multi-byte character: wait for the following step.
            held_back.append(step)
            continue
        yield self.process_server_side_tokenization_response(step, held_back)
        held_back = []

    if request.is_client_side_tokenization or not held_back:
        return

    # The stream ended while steps were still buffered: flush them using an
    # empty trailing step that only carries the metrics of the last real one.
    flush_step = [
        ReturnSample(
            text=[],
            token_ids=[],
            tokens_per_second=s.tokens_per_second,
            num_generated_tokens=s.num_generated_tokens,
        )
        for s in held_back[-1]
    ]
    yield self.process_server_side_tokenization_response(flush_step, held_back)
async def generate(
    self,
    prompts: tp.Union[str, tp.Sequence[str]],
    sampling_params: tp.Optional[
        tp.Union[SamplingParams, tp.Sequence[SamplingParams]]
    ] = None,
    stream: bool = False,
) -> tp.Union[tp.List[ReturnSample], tp.AsyncGenerator[tp.List[ReturnSample]]]:
    """Generates text completions concurrently for the given prompts.

    This is a coroutine: with ``stream=True`` the awaited result is an async
    generator to iterate over; with ``stream=False`` it is the final list.

    Args:
        prompts: A single prompt string or a sequence of prompt strings.
        sampling_params: A single SamplingParams object, a sequence of them
            (same length as ``prompts``), or None for default SamplingParams.
            A single object is applied to every prompt. With a single string
            prompt only a single object or None is accepted.
        stream: If True, results (List[ReturnSample]) from *any* request are
            yielded as they become available. If False, all requests are
            awaited and one aggregated ReturnSample per prompt is returned.

    Returns:
        If stream is True: an async generator yielding lists of ReturnSample.
        If stream is False: a list of aggregated ReturnSample objects, one
        per input prompt.

    Raises:
        ValueError: If ``prompts`` or ``sampling_params`` has an unsupported
            type, or the lengths of the two sequences mismatch.
        RuntimeError: If the underlying driver's queue is full.
    """
    if isinstance(prompts, str):
        if not (sampling_params is None or isinstance(sampling_params, SamplingParams)):
            raise ValueError(
                "If prompts is a single string, sampling_params must be a single SamplingParams object or None."
            )
        prompt_list = [prompts]
        params_list = [sampling_params if sampling_params else SamplingParams()]
    elif isinstance(prompts, tp.Sequence):
        prompt_list = prompts
        if sampling_params is None:
            params_list = [SamplingParams()] * len(prompts)
        elif isinstance(sampling_params, SamplingParams):
            params_list = [sampling_params] * len(prompts)
        elif isinstance(sampling_params, tp.Sequence):
            if len(prompts) != len(sampling_params):
                raise ValueError("Lengths of prompts and sampling_params lists must match.")
            params_list = sampling_params
        else:
            raise ValueError(
                "sampling_params must be a SamplingParams object, a list of SamplingParams objects, or None."
            )
    else:
        raise ValueError("prompts must be a string or a sequence of strings.")

    if not prompt_list:
        if not stream:
            return []

        async def empty_generator():
            # An async generator that finishes without yielding anything.
            return
            yield []

        return empty_generator()

    requests = [
        vSurgeRequest.from_sampling_params(prompt=text, sampling_params=params)
        for text, params in zip(prompt_list, params_list)
    ]
    if stream:
        return self._generate_stream(requests)
    return await self._generate_batch(requests)
async def _generate_stream(
self, requests: tp.List[vSurgeRequest]
) -> tp.AsyncGenerator[tp.List[ReturnSample]]:
"""Helper for concurrent streaming generation."""
q = asyncio.Queue()
tasks = set()
results_map: tp.Dict[asyncio.Task, tp.List[ReturnSample]] = {}
_SENTINEL = object()
async def _run_completion(request: vSurgeRequest):
"""Runs self.complete and puts results (or sentinel/exception) in queue."""
try:
async for result_step in self.complete(request):
await q.put(result_step)
except Exception as e:
await q.put(e)
finally:
await q.put(_SENTINEL)
for req in requests:
task = asyncio.create_task(_run_completion(req))
tasks.add(task)
results_map[task] = []
finished_tasks = 0
while finished_tasks < len(requests):
item = await q.get()
if item is _SENTINEL:
finished_tasks += 1
elif isinstance(item, Exception):
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise item
else:
yield item
q.task_done()
await asyncio.gather(*tasks, return_exceptions=True)
async def _generate_batch(
self,
requests: tp.List[vSurgeRequest],
) -> tp.List[ReturnSample]:
"""Helper for concurrent batch generation."""
async def _collect_all_steps(request: vSurgeRequest) -> list[list[ReturnSample]]:
"""Consumes the complete generator and returns all steps."""
all_steps = []
async for step_result in self.complete(request):
all_steps.append(step_result)
return all_steps
tasks = [asyncio.create_task(_collect_all_steps(req)) for req in requests]
try:
results_per_request: list[list[list[ReturnSample]]] = await asyncio.gather(*tasks)
except Exception as e:
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise e
final_results: tp.List[ReturnSample] = []
for request_steps in results_per_request:
if not request_steps:
final_results.append(
ReturnSample(
text="", token_ids=[], tokens_per_second=0.0, num_generated_tokens=0
)
)
continue
aggregated_text = ""
aggregated_token_ids = []
total_tps = 0.0
final_num_generated_tokens = 0
num_steps = 0
for step_result in request_steps:
if step_result:
sample = step_result[0]
aggregated_text += sample.text
aggregated_token_ids.extend(sample.token_ids)
total_tps += sample.tokens_per_second
final_num_generated_tokens = sample.num_generated_tokens
num_steps += 1
average_tps = (total_tps / num_steps) if num_steps > 0 else 0.0
final_results.append(
ReturnSample(
text=aggregated_text,
token_ids=aggregated_token_ids,
tokens_per_second=average_tps,
num_generated_tokens=final_num_generated_tokens,
)
)
return final_results