# Source code for easydel.inference.vsurge.engines.vengine.driver

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import itertools
import queue
import time
import traceback
import typing as tp

import jax
import numpy as np
from jax import numpy as jnp

from easydel.inference.utilities import SamplingParams
from easydel.inference.vsurge.engines._abstract_driver import (
	AbstractDriver,
	ProcessingClassType,
)
from easydel.utils.helpers import get_logger

from ...utils import (
	ActiveRequest,
	ReturnSample,
	SafeThread,
	pad_tokens,
	process_result_tokens,
)
from .._utils import ResultTokens
from .engine import vEngine

# Static type checkers take ProcessingClassType from easydel.infra.utils.
# NOTE(review): at runtime the else-branch rebinds ProcessingClassType to
# tp.Any, shadowing the name imported from
# easydel.inference.vsurge.engines._abstract_driver above.
if tp.TYPE_CHECKING:
	from easydel.infra.utils import ProcessingClassType
else:
	ProcessingClassType = tp.Any


# Module-level logger used by the vDriver and all of its stage threads.
logger = get_logger("vSurge-vDriver")


class vDriver(AbstractDriver):
	"""Drives the engines.

	Requests move through a threaded pipeline:
	prefill backlog -> prefill thread -> transfer backlog -> transfer thread ->
	decode backlog -> decode thread -> detokenize backlog -> detokenize thread,
	which pushes samples back onto the request via ``enqueue_samples``. The
	prefill thread also sends the first sampled token straight to a detokenize
	backlog.
	"""

	_prefill_engines: list[vEngine]
	_decode_engines: list[vEngine]
	_prefill_backlog: queue.Queue[ActiveRequest | None]
	# NOTE: the class-level containers below are placeholders; ``__init__``
	# rebinds every one of them per instance except ``_active_requests``.
	_transfer_backlogs: list[queue.Queue[ActiveRequest]] = []
	_decode_backlogs: dict[int, queue.Queue[ActiveRequest]] = {}
	_detokenize_backlogs: list[queue.Queue[ResultTokens]] = []
	_decode_slots: list[queue.Queue[int]] = []
	_active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []
	_interleaved_mode: bool = False
	_detokenizing_blocks: int = 8

	def __init__(
		self,
		prefill_engines: tp.Optional[list[vEngine] | vEngine] = None,
		decode_engines: tp.Optional[list[vEngine] | vEngine] = None,
		interleaved_mode: bool = False,
		detokenizing_blocks: int = 8,
	):
		"""Initializes the vDriver.

		Sets up the prefill and decode engines, the backlogs (queues) that hand
		requests between stages, the free-slot queues used for concurrent
		decoding, and creates (but does not start) one background thread per
		stage and engine (prefill, transfer, decode, detokenize). Call
		``start()`` to launch the threads.

		Args:
			prefill_engines: A single vEngine or a list of vEngines to be used
				for the prefill stage. Defaults to an empty list.
			decode_engines: A single vEngine or a list of vEngines to be used
				for the decode stage. Defaults to an empty list.
			interleaved_mode: If True, transfer and decode backlogs are limited
				to a single entry and decode threads wait on pending prefill or
				transfer work before decoding. Defaults to False.
			detokenizing_blocks: Maximum number of pending items in each
				decode engine's detokenize backlog. Defaults to 8.
		"""
		if prefill_engines is None:
			prefill_engines = []
		if decode_engines is None:
			decode_engines = []
		if not isinstance(prefill_engines, list):
			prefill_engines = [prefill_engines]
		if not isinstance(decode_engines, list):
			decode_engines = [decode_engines]

		self._prefill_engines = prefill_engines
		self._decode_engines = decode_engines
		self._interleaved_mode = interleaved_mode
		self._detokenizing_blocks = detokenizing_blocks

		self._prefill_backlog = queue.Queue()
		self._transfer_backlogs = [
			queue.Queue(1 if self._interleaved_mode else 4)
			for _ in range(len(self._prefill_engines))
		]
		# ``queue.Queue(0)`` is *unbounded*, so clamp to at least 1 to keep
		# back-pressure for engines with fewer than 3 decode slots.
		self._decode_backlogs = {
			idx: queue.Queue(
				1 if self._interleaved_mode else max(1, engine.max_concurrent_decodes // 3)
			)
			for idx, engine in enumerate(self._decode_engines)
		}
		self._detokenize_backlogs = [
			queue.Queue(detokenizing_blocks) for _ in self._decode_engines
		]
		# Each slot queue holds the indices of the currently free decode slots;
		# initially every slot is free.
		self._decode_slots = []
		for engine in self._decode_engines:
			free_slots = queue.Queue(engine.max_concurrent_decodes)
			for slot in range(engine.max_concurrent_decodes):
				free_slots.put(slot)
			self._decode_slots.append(free_slots)

		self._prefill_threads = [
			SafeThread(
				target=functools.partial(self._prefill_thread, idx),
				name=f"prefill-{idx}",
				daemon=True,
			)
			for idx in range(len(self._prefill_engines))
		]
		self._transfer_threads = [
			SafeThread(
				target=functools.partial(self._transfer_thread, idx),
				name=f"transfer-{idx}",
				daemon=True,
			)
			for idx in range(len(self._prefill_engines))
		]
		self._decode_threads = [
			SafeThread(
				target=functools.partial(self._decode_thread, idx),
				name=f"decode-{idx}",
				daemon=True,
			)
			for idx in range(len(self._decode_engines))
		]
		self.detokenize_threads = [
			SafeThread(
				target=functools.partial(self._detokenize_thread, idx),
				name=f"detokenize-{idx}",
			)
			for idx in range(len(self._decode_engines))
		]
		self.live = False

	def start(self):
		"""Starts every stage thread; does nothing if the driver is already live."""
		if self.live:
			return
		self._all_threads = list(
			itertools.chain(
				self._prefill_threads,
				self._transfer_threads,
				self._decode_threads,
				self.detokenize_threads,
			)
		)
		self.live = True
		for t in self._all_threads:
			t.start()

	def submit_request(self, request: tp.Any):
		"""Submits a new request to the driver's processing queue.

		Raises:
			TypeError: If ``request`` is not an ``ActiveRequest``.
		"""
		if not isinstance(request, ActiveRequest):
			raise TypeError("Request must be of type ActiveRequest")
		self.place_request_on_prefill_queue(request)

	@property
	def driver_name(self):
		"""Name derived from the model of the last decode engine."""
		return self._get_model_name(self._decode_engines[-1].model)

	def compile(self):
		"""Compiles engines.

		For each (decode, prefill) engine pair, runs a dummy prefill for every
		prefill bucket length up to and including ``max_prefill_length``,
		inserts each prefill result into a decode state, and then runs one
		decode step. On any exception the traceback is printed, the driver is
		stopped and the process exits with status 1.
		"""
		try:
			for decode_engine, prefill_engine in zip(
				self._decode_engines,
				self._prefill_engines,
			):
				decode_state = decode_engine.init_decode_state()
				max_prefill_length = prefill_engine.max_prefill_length
				bucket_lengths = list(prefill_engine.prefill_lengths)
				if max_prefill_length in bucket_lengths:
					bucket_lengths = bucket_lengths[: bucket_lengths.index(max_prefill_length)]
				else:
					# ``list.index`` would raise ValueError; keep the smaller buckets.
					bucket_lengths = [
						length for length in bucket_lengths if length < max_prefill_length
					]
				for length in bucket_lengths + [max_prefill_length]:
					padded_tokens = padded_valids = jnp.ones((1, length), "i4")
					logger.info(f"Compiling prefill-engine seqlen={length}")
					state_new, _ = prefill_engine.prefill(
						graphstate=prefill_engine.graphstate,
						graphothers=prefill_engine.graphothers,
						tokens=padded_tokens,
						valids=padded_valids,
						true_length=0,
						temperature=jnp.array([1], "f4"),
						top_p=jnp.array([1], "f4"),
						rngs=prefill_engine.prng_key,
					)
					logger.info(f"Compiling decode-engine insert seqlen={length}")
					decode_state = decode_engine.insert(state_new, decode_state, 0)
				logger.info("Compiling decode-engine")
				decode_engine.decode(
					graphstate=decode_engine.graphstate,
					graphothers=decode_engine.graphothers,
					state=decode_state,
					rngs=decode_engine.prng_key,
				)
		except Exception:
			traceback.print_exc()
			self.stop()
			# The bare ``exit`` builtin is injected by ``site`` and is not
			# guaranteed to exist; raising SystemExit is the portable equivalent.
			raise SystemExit(1)

	@staticmethod
	def _detach_return_channel(item) -> None:
		"""Clears ``return_channel`` on every ActiveRequest found in a drained backlog item.

		Backlog items are ``None`` sentinels, bare ``ActiveRequest`` objects, or
		tuples of 2 or 3 elements (see the prefill and decode threads).
		"""
		if item is None:
			return
		candidates = (item,) if isinstance(item, ActiveRequest) else item
		if not isinstance(candidates, tuple):
			return
		for candidate in candidates:
			if isinstance(candidate, ActiveRequest):
				candidate.return_channel = None

	def stop(self):
		"""Stops the driver and all background threads.

		Repeatedly drains every backlog (detaching the return channel of any
		request still queued) and feeds ``None`` sentinels until every stage
		thread has exited, then joins them.
		"""
		if not self.live:
			return
		self.live = False
		all_backlogs = list(
			itertools.chain(
				[self._prefill_backlog],
				self._transfer_backlogs,
				self._decode_backlogs.values(),
				self._detokenize_backlogs,
			)
		)
		while any(t.is_alive() for t in self._all_threads):
			for q in all_backlogs:
				while True:
					try:
						item = q.get_nowait()
					except queue.Empty:
						break
					# Detokenize items may be 2- or 3-tuples; inspect every element
					# instead of assuming a pair.
					self._detach_return_channel(item)
			for q in all_backlogs:
				try:
					q.put_nowait(None)
				except queue.Full:
					pass
		for t in self._all_threads:
			t.join()

	def get_total_concurrent_requests(self) -> int:
		"""Gets the total number of concurrent requests the driver can handle."""
		return sum(e.max_concurrent_decodes for e in self._decode_engines)

	def place_request_on_prefill_queue(self, request: ActiveRequest):
		"""Used to place new requests for prefilling and generation."""
		self._prefill_backlog.put(request, block=False)

	@property
	def processor(self) -> ProcessingClassType:  # type:ignore
		"""Returns the processor/tokenizer associated with the engines.

		Assumes all engines (prefill and decode) use the same processor; the
		first prefill engine is preferred, then the first decode engine.

		Raises:
			ValueError: If no engines are configured.
		"""
		if self._prefill_engines:
			return self._prefill_engines[0].processor
		elif self._decode_engines:
			return self._decode_engines[0].processor
		else:
			raise ValueError(
				"No engines configured for the vDriver, cannot determine processor."
			)

	def _process_prefill_content(
		self,
		request: ActiveRequest,
		processor: ProcessingClassType,  # type:ignore
		max_prefill_length: int,
		prefill_lengths: list[int],
		pad_token_id: int,
	) -> tp.Tuple[tp.Tuple[jnp.ndarray, jnp.ndarray, int], SamplingParams]:
		"""Tokenizes, pads, and prepares sampling parameters for a prefill request.

		If ``request.prefill_content`` is a string it is tokenized with
		``processor``; otherwise it is expected to already be a
		``(tokens, valids)`` pair. The tokens are then left-padded via
		``pad_tokens``.

		Args:
			request: The ActiveRequest containing the prompt and sampling settings.
			processor: The tokenizer/processor instance.
			max_prefill_length: The maximum allowed length for the prefill sequence.
			prefill_lengths: Bucket lengths forwarded to ``pad_tokens``.
			pad_token_id: Token id used for padding.

		Returns:
			A tuple containing:
				- The result of ``pad_tokens`` (unpacked by the prefill thread as
				  ``(padded_tokens, padded_valids, true_length)``).
				- A SamplingParams built from the request's penalties, ``min_p``,
				  ``top_p`` and ``temperature`` (``max_tokens`` is left at 0; the
				  request's own ``max_tokens`` is used during detokenization).
		"""
		content = request.prefill_content
		if isinstance(content, str):
			content = processor(text=content, return_tensors="np", return_attention_mask=True)
			tokens = jnp.array(content["input_ids"])
			valids = jnp.array(content["attention_mask"])
		else:
			tokens, valids = content
		return (
			pad_tokens(
				tokens=tokens,
				valids=valids,
				pad_token_id=pad_token_id,
				max_prefill_length=max_prefill_length,
				prefill_lengths=prefill_lengths,
				right_padding=False,
			),
			SamplingParams(
				max_tokens=0,
				presence_penalty=request.presence_penalty,
				frequency_penalty=request.frequency_penalty,
				repetition_penalty=request.repetition_penalty,
				min_p=request.min_p,
				top_p=request.top_p,
				temperature=request.temperature,
			),
		)

	def _prefill_thread(self, idx: int):
		"""Thread which runs in the background performing prefills."""
		logger.info(f"Spinning up prefill thread {idx}.")
		prefill_engine = self._prefill_engines[idx]
		processor = prefill_engine.processor
		my_transfer_backlog = self._transfer_backlogs[idx]
		while self.live:
			request = self._prefill_backlog.get(block=True)
			if request is None:
				break
			request.metadata.prefill_dequeue_time = time.perf_counter()
			(
				(
					padded_tokens,
					padded_valids,
					true_length,
				),
				sampling_params,
			) = self._process_prefill_content(
				request,
				processor,
				prefill_engine.max_prefill_length,
				prefill_engine.prefill_lengths,
				prefill_engine.pad_token_id,
			)
			logger.info(
				f"Prefilling on prefill engine {idx} : "
				f"prefill queue size : {self._prefill_backlog.qsize()}, Token size {padded_valids.shape[-1]}",
			)
			prefill_result, first_token = prefill_engine.prefill(
				graphstate=prefill_engine.graphstate,
				graphothers=prefill_engine.graphothers,
				tokens=padded_tokens,
				valids=padded_valids,
				true_length=true_length,
				temperature=jnp.array([sampling_params.temperature], "f4"),
				top_p=jnp.array([sampling_params.top_p], "f4"),
				rngs=prefill_engine.prng_key,
			)
			request.prefill_result = prefill_result
			request.complete = np.zeros((prefill_engine.samples_per_slot,), "b1")
			# Detokenize backlogs exist per *decode* engine; wrap the prefill index
			# so configurations with more prefill than decode engines don't raise
			# IndexError. NOTE(review): assumes all decode engines share a processor.
			my_detokenize_backlog = self._detokenize_backlogs[
				idx % len(self._detokenize_backlogs)
			]
			request.metadata.transfer_enqueue_time = time.perf_counter()
			my_detokenize_backlog.put(
				(first_token, request, request.metadata.prefill_dequeue_time),
				block=True,
			)
			my_transfer_backlog.put(request, block=True)
			logger.info(
				f"Placed request on transfer queue {idx}, {my_transfer_backlog.qsize()} queued requests.",
			)
			del prefill_result
			del request

	def _jax_transfer_prefill_result(self, new_request: ActiveRequest, target_idx: int):
		"""Transfers the prefill result using ``jax.device_put``.

		Places ``new_request.prefill_result`` onto the sharding returned by the
		target decode engine's ``get_prefix_destination_sharding()`` and blocks
		until the transfer is complete.

		Args:
			new_request: The ActiveRequest containing the prefill_result.
			target_idx: The index of the target decode engine.
		"""
		new_request.prefill_result = jax.device_put(
			new_request.prefill_result,
			self._decode_engines[target_idx].get_prefix_destination_sharding(),
		)
		jax.block_until_ready(new_request.prefill_result)

	def _ray_transfer_prefill_result(self, new_request: ActiveRequest, target_idx: int):
		"""Placeholder transfer path for Ray-style engines (currently unused).

		Assumes the target decode engine exposes a ``transfer`` method.

		Args:
			new_request: The ActiveRequest containing the prefill_result.
			target_idx: The index of the target decode engine.
		"""
		self._decode_engines[target_idx].transfer(new_request.prefill_result)

	def _transfer_prefill_result(self, new_request: ActiveRequest, target_idx: int):
		"""Selects and executes the prefill-result transfer method.

		Currently always dispatches to the JAX ``device_put`` path.

		Args:
			new_request: The ActiveRequest containing the prefill_result.
			target_idx: The index of the target decode engine.
		"""
		self._jax_transfer_prefill_result(new_request, target_idx)

	def _transfer_thread(self, idx: int):
		"""Transfers the kv cache on an active request to the least full generate backlog."""
		transfer_backlog = self._transfer_backlogs[idx]
		while self.live:
			new_request = transfer_backlog.get(block=True)
			if new_request is None:
				break
			new_request.metadata.transfer_dequeue_time = time.perf_counter()
			target_idx = min(self._decode_backlogs.items(), key=lambda q: q[1].qsize())[0]
			if not self._interleaved_mode:
				logger.info(
					f"Transferring prefill from prefill engine {idx} to Decode engine {target_idx}."
				)
				self._transfer_prefill_result(new_request, target_idx)
			new_request.metadata.generate_enqueue_time = time.perf_counter()
			self._decode_backlogs[target_idx].put(new_request, block=True)
			logger.info(
				"Successfully transferred prefill "
				f"from prefill engine {idx} to Decode engine {target_idx} "
				f"({self._decode_backlogs[target_idx].qsize()} requests now in backlog).",
			)

	def _decode_thread(self, idx: int):
		"""Step token generation and insert prefills from backlog.

		Each iteration fills as many free slots as possible from the decode
		backlog, then runs one decode step and forwards the sampled tokens to the
		detokenize backlog. A ``None`` sentinel on the decode backlog (sent by
		``stop()``) ends the thread.
		"""
		logger.info(f"Spinning up decode thread {idx}.")
		decode_engine = self._decode_engines[idx]
		my_slots = self._decode_slots[idx]
		my_decode_backlog = self._decode_backlogs[idx]
		my_detokenize_backlog = self._detokenize_backlogs[idx]
		generate_timestep = 0
		decode_state = decode_engine.init_decode_state()
		time_of_last_print = time.time()
		while self.live:
			if (time.time() - time_of_last_print) > 5:
				logger.info(
					"Decode thread making a decision with:"
					f" prefill_backlog={self._prefill_backlog.qsize()}"
					f" generate_free_slots={my_slots.qsize()}",
				)
				time_of_last_print = time.time()

			max_concurrent_decodes = decode_engine.max_concurrent_decodes
			while True:
				my_slots_size = my_slots.qsize()
				try:
					slot = my_slots.get(block=False)
				except queue.Empty:
					break

				# All slots free means nothing to decode yet, so wait for work.
				block = my_slots_size == max_concurrent_decodes
				if self._interleaved_mode:
					block |= not self._prefill_backlog.empty()
					# Transfer backlogs are per *prefill* engine; guard the index.
					if idx < len(self._transfer_backlogs):
						block |= not self._transfer_backlogs[idx].empty()
				try:
					new_request = my_decode_backlog.get(block=block, timeout=1.0)
				except queue.Empty:
					my_slots.put(slot, block=False)
					if block:
						continue
					break
				if new_request is None:
					# Shutdown sentinel: hand the slot back and exit without running
					# another decode step.
					my_slots.put(slot, block=False)
					return
				new_request.metadata.generate_dequeue_time = time.perf_counter()
				logger.info(
					f"Decode slice {idx} filling slot {slot} at step {generate_timestep}."
				)
				decode_state = decode_engine.insert(
					prefix=new_request.prefill_result,
					decode_state=decode_state,
					slot=slot,
				)
				del new_request.prefill_result
				new_request.generate_timestep_added = generate_timestep
				new_request.complete = np.zeros((decode_engine.samples_per_slot,), "b1")
				my_detokenize_backlog.put((slot, new_request), block=True)

			assert my_slots.qsize() < max_concurrent_decodes, (
				"At this point we must have some requests inserted into the slots."
			)
			time_of_last_decode = time.time()
			decode_state, sampled_tokens = decode_engine.decode(
				graphstate=decode_engine.graphstate,
				graphothers=decode_engine.graphothers,
				state=decode_state,
				rngs=decode_engine.prng_key,
			)
			fn_call = time.time()
			sampled_tokens.copy_to_host_async()
			my_detokenize_backlog.put((generate_timestep, sampled_tokens), block=True)
			generate_timestep += 1
			# TODO:Debug
			_took = (time.time() - time_of_last_decode) * 10**3
			_exec = (fn_call - time_of_last_decode) * 10**3
			logger.info(
				f"Decode engine {idx} step {generate_timestep} - slots free : {my_slots_size} / {max_concurrent_decodes}, "
				f"took {_took:.2f}ms | execution took {_exec:.2f}ms "
			)

	def _detokenize_thread(self, idx: int):
		"""Detokenize sampled tokens and returns them to the user.

		Backlog items come in three shapes:

		* ``(first_token, request, prefill_dequeue_time)`` from the prefill thread;
		* ``(generate_timestep, sampled_tokens)`` after each decode step;
		* ``(slot, request)`` when the decode thread inserts a request into a slot.

		Completed requests have their return channel closed and their slot freed.
		"""
		my_detokenize_backlog = self._detokenize_backlogs[idx]
		my_decode_engine = self._decode_engines[idx]
		my_slots = self._decode_slots[idx]
		processor = my_decode_engine.processor
		my_live_requests = {i: None for i in range(my_decode_engine.max_concurrent_decodes)}
		while self.live:
			data = my_detokenize_backlog.get(block=True)
			if data is None:
				break
			if isinstance(data[0], ResultTokens):
				# The very first token, produced by prefill.
				request_first_token, request, _ = data
				request_first_token = request_first_token.convert_to_numpy()
				results_base, complete, num_valid_tokens_list = process_result_tokens(
					processor=processor,
					slot=0,  # Prefill result is always at slot 0 conceptually
					slot_max_length=request.max_tokens,
					result_tokens=request_first_token,
					eos_token_id=my_decode_engine.eos_token_ids,
					is_client_side_tokenization=request.is_client_side_tokenization,
					complete=request.complete,
				)
				request.complete = complete
				final_results = []
				for res_base, num_valid in zip(results_base, num_valid_tokens_list):
					request.total_generated_tokens += num_valid
					final_results.append(
						ReturnSample(
							text=res_base.text,
							token_ids=res_base.token_ids,
							tokens_per_second=0.0,  # No decode time elapsed yet.
							num_generated_tokens=request.total_generated_tokens,
						)
					)
				request.enqueue_samples(final_results)
				first_token_return_time = (
					time.perf_counter() - request.metadata.prefill_dequeue_time
				) * 1000
				logger.info(f"TTFT duration: {first_token_return_time}ms")
			elif isinstance(data[1], ResultTokens):
				# Tokens from one decode step, covering every slot of this engine.
				_, result_tokens = data
				result_tokens = result_tokens.convert_to_numpy()
				for slot, request in my_live_requests.items():
					if request is None:
						continue
					# Start the TPS timer on the first decode step for this request.
					if request.decode_start_time is None:
						request.decode_start_time = time.perf_counter()
					results_base, complete, num_valid_tokens_list = process_result_tokens(
						processor=processor,
						slot=slot,
						slot_max_length=request.max_tokens,
						result_tokens=result_tokens,
						eos_token_id=my_decode_engine.eos_token_ids,
						is_client_side_tokenization=request.is_client_side_tokenization,
						complete=request.complete,
					)
					request.complete = complete
					elapsed_time = time.perf_counter() - request.decode_start_time
					final_step_results = []
					for res_base, num_valid in zip(results_base, num_valid_tokens_list):
						request.total_generated_tokens += num_valid
						tps = (
							request.total_generated_tokens / elapsed_time
							if elapsed_time > 1e-6  # Avoid division by zero
							else 0.0
						)
						final_step_results.append(
							ReturnSample(
								text=res_base.text,
								token_ids=res_base.token_ids,
								tokens_per_second=tps,
								num_generated_tokens=request.total_generated_tokens,
							)
						)
					request.enqueue_samples(final_step_results)
					if request.complete.all():
						request.metadata.complete_time = time.perf_counter()
						# stop() may already have detached the channel.
						if request.return_channel is not None:
							request.return_channel.close()
						my_live_requests[slot] = None
						my_slots.put(slot, block=False)
						my_decode_engine.free_resource(slot)
			else:
				# A request was just inserted into ``slot`` by the decode thread.
				slot, active_request = data
				my_live_requests[slot] = active_request