# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""eSurge Model Runner - High-performance inference execution engine.
This module implements the core execution logic for the eSurge inference engine,
providing efficient model execution with advanced features like paged attention,
dynamic batching, and compilation caching.
Key Components:
ExecutionManager: Manages compiled execution functions for different batch/token configurations
eSurgeRunner: Main runner class that orchestrates model execution
Architecture:
The module uses a two-stage compilation strategy:
1. Pre-compilation of functions for different token/batch size combinations
2. Runtime selection of appropriate compiled function based on input shape
Performance Features:
- Paged attention for efficient KV cache management
- Vectorized operations for batch processing
- Pre-allocated buffers to minimize memory allocation
- Compilation caching to avoid recompilation
- Progress logging for long compilation processes
Example:
>>> from easydel.infra import EasyDeLBaseModule
>>> from easydel.inference.esurge.runners import eSurgeRunner
>>>
>>> # Initialize model
>>> model = EasyDeLBaseModule.from_pretrained("model-name")
>>>
>>> # Create runner
>>> runner = eSurgeRunner(
... model=model,
... max_model_len=2048,
... max_num_seqs=8,
... hbm_utilization=0.9
... )
>>>
>>> # Compile for different configurations
>>> runner.compile()
>>>
>>> # Execute model
>>> output = runner.execute_model(scheduler_output)
"""
from __future__ import annotations
import time
import typing
from bisect import bisect_left
from concurrent.futures import Future
import flax
import jax
import numpy as np
from eformer.loggings import get_logger
from jax import numpy as jnp
from ..metrics import get_metrics_collector
from ..outputs import ModelRunnerOutput
from ..page_table import PAGE_TABLE_PADDING_VAL
from ..scheduler import SchedulerOutput
from .async_types import AsyncPreResults
from .execution_manager import ExecutionManager
from .sequence_buffer import (
SequenceBuffer,
build_allowed_mask,
build_sampling_arrays,
fill_slice,
move_row,
pack_prompts,
swap_rows,
)
from .states import CachedRequestState
if typing.TYPE_CHECKING:
from easydel.infra import EasyDeLBaseModule
logger = get_logger("eSurge")
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int, min_input_pad: int = 8) -> int:
    """Calculate padded request count for compilation efficiency.

    Request counts up to ``min_input_pad`` are padded to ``min_input_pad``;
    larger counts are rounded up to the next power of two. The result is then
    capped at ``upper_limit``. Bucketing like this bounds the number of distinct
    shapes that have to be compiled.

    Args:
        x: Actual number of requests.
        upper_limit: Maximum allowed (padded) request count.
        min_input_pad: Smallest padded size; every ``x <= min_input_pad`` maps
            to this value. Defaults to 8.

    Returns:
        int: Padded request count, capped at ``upper_limit``.

    Example:
        >>> _get_padded_num_reqs_with_upper_limit(3, 32)
        8
        >>> _get_padded_num_reqs_with_upper_limit(10, 32)
        16
        >>> _get_padded_num_reqs_with_upper_limit(20, 16)
        16
    """
    # (x - 1).bit_length() is the exponent of the smallest power of two >= x (for x >= 1).
    res = min_input_pad if x <= min_input_pad else 1 << (x - 1).bit_length()
    return min(res, upper_limit)
[docs]class eSurgeRunner:
"""High-performance model runner for efficient batched inference.
The eSurgeRunner orchestrates model execution with advanced features:
- Paged attention for memory-efficient KV cache management
- Dynamic batching with request scheduling
- Pre-allocated buffers for zero-copy operations
- Vectorized token processing
- Compilation caching for different batch/sequence configurations
The runner maintains an internal state of active requests and manages
their lifecycle from prompt processing through token generation.
Architecture:
Request Flow:
1. Scheduler provides requests to execute
2. Runner updates internal state (add/remove requests)
3. Prepares inputs with proper padding and batching
4. Executes model using pre-compiled functions
5. Processes sampled tokens and updates buffers
6. Returns results to scheduler
Memory Management:
- Pre-allocated buffers for common operations
- Paged KV cache with configurable page size
- Efficient slot mapping for attention
- Buffer reuse across batches
Attributes:
model: The EasyDeL model to run
metadata: Paged attention metadata
max_num_seqs: Maximum concurrent sequences
max_model_len: Maximum sequence length
executor_manager: Manages compiled functions
sequence_buffer: Manages active sequences
requests: Active request states
Example:
>>> runner = eSurgeRunner(
... model=model,
... max_model_len=2048,
... max_num_seqs=8,
... hbm_utilization=0.9,
... page_size=128
... )
>>>
>>> # Compile for all configurations
>>> runner.compile()
>>>
>>> # Execute requests from scheduler
>>> output = runner.execute_model(scheduler_output)
>>>
>>> # Process results
>>> for req_id, tokens in zip(output.req_ids, output.sampled_token_ids):
... print(f"Request {req_id}: {tokens}")
"""
def __init__(
    self,
    model: EasyDeLBaseModule,
    hbm_utilization: float = 0.5,
    page_size: int = 128,
    max_model_len: int = 2**13,
    min_input_pad: int = 256,
    max_num_seqs: int = 16,
    max_num_seq_buckets: list[int] | None = None,
    use_aot_forward: bool = True,
    verbose: bool = False,
    enable_overlap_execution: bool = False,
    enable_sampler_metrics: bool = False,
):
    """Set up paged metadata, request/token buckets, the executor and buffers.

    Args:
        model: EasyDeL model. The runner keeps ``model.esurge_compatible_model``
            and uses ``model`` itself to build paged metadata and the executor.
        hbm_utilization: Passed to ``model.create_paged_metadata``.
        page_size: Requested KV page size, passed to ``create_paged_metadata``.
        max_model_len: Maximum sequence length; also the largest token padding.
        min_input_pad: Seed for auto-generated request buckets and the smallest
            token padding (raised to the first request bucket if that is larger).
            The resulting value must be a power of two.
        max_num_seqs: Largest request bucket.
        max_num_seq_buckets: Optional explicit request buckets; entries outside
            ``(0, max_num_seqs]`` are dropped.
        use_aot_forward: Forwarded to ``ExecutionManager``.
        verbose: Forwarded to ``ExecutionManager``; also selects INFO instead of
            DEBUG for ``self.log_it``.
        enable_overlap_execution: Stored on the runner.
        enable_sampler_metrics: Stored on the runner; enables logits collection.
    """
    logger.debug(f"Initializing eSurgeRunner with {max_model_len=}, {max_num_seqs=}")
    logger.debug(f"Configuration: {hbm_utilization=}, {page_size=}")

    self.model = model.esurge_compatible_model
    self.metadata = model.create_paged_metadata(
        hbm_utilization=hbm_utilization,
        page_size=page_size,
        max_model_length=max_model_len,
    )

    # Request-count buckets; the last one is the hard cap on concurrent requests.
    seq_buckets = self._init_seq_buckets(max_num_seq_buckets, max_num_seqs, min_input_pad)
    self.max_num_seq_buckets = seq_buckets
    self.max_num_seqs = max_num_seqs
    self.max_num_reqs = seq_buckets[-1]
    self.max_model_len = max_model_len
    self.min_input_pad = max(min_input_pad, seq_buckets[0])

    page_meta = self.metadata
    self.page_size = int(page_meta.page_size)
    self.max_pages_per_req = int(page_meta.max_num_pages_per_req)

    # Token-count buckets (exponential growth, capped at max_model_len).
    token_paddings = self._get_token_paddings(
        min_token_size=self.min_input_pad,
        max_token_size=max_model_len,
        padding_gap=0,
    )
    self.num_tokens_paddings = token_paddings
    self.max_num_tokens = token_paddings[-1]

    logger.debug("Creating ExecutionManager and initializing pages cache")
    self.executor_manager = ExecutionManager(
        model=model,
        use_aot_forward=use_aot_forward,
        min_input_pad=self.min_input_pad,
        max_model_len=max_model_len,
        max_num_reqs=self.max_num_reqs,
        max_num_tokens=self.max_num_tokens,
        metadata=page_meta,
        verbose=verbose,
    )

    if verbose:
        self.log_it = logger.info
    else:
        self.log_it = logger.debug

    self._setup_variables()

    self.enable_overlap_execution = enable_overlap_execution
    self.enable_sampler_metrics = enable_sampler_metrics

    # Async scheduling state: results carried over from the previous step.
    self._pre_async_results: AsyncPreResults | None = None
    self._executor: typing.Any = None  # ThreadPoolExecutor, typed as Any to avoid circular import
    logger.debug("eSurgeRunner initialization complete")
@property
def mesh(self):
    """Device mesh of the wrapped eSurge-compatible model (``self.model.mesh``)."""
    wrapped_model = self.model
    return wrapped_model.mesh
@property
def _empty_sharding(self):
return jax.NamedSharding(self.mesh, jax.sharding.PartitionSpec())
@staticmethod
def _get_token_paddings(min_token_size: int, max_token_size: int, padding_gap: int) -> list[int]:
"""Generate padding sizes for efficient compilation.
Args:
min_token_size: Minimum token size (must be power of 2)
max_token_size: Maximum token size to cover
padding_gap: Gap between padding sizes (0 for exponential growth)
Returns:
List of padding sizes
"""
if not ((min_token_size & (min_token_size - 1) == 0) and min_token_size > 0):
logger.error(f"Invalid min_token_size={min_token_size}, must be power of 2")
raise ValueError(f"min_token_size must be a power of 2, got {min_token_size}")
paddings = []
num = min_token_size
if padding_gap == 0:
while num <= max_token_size:
paddings.append(num)
num *= 2
else:
while num <= padding_gap:
paddings.append(num)
num *= 2
num //= 2
while num < max_token_size:
num += padding_gap
paddings.append(num)
if paddings[-1] != max_token_size:
paddings.append(max_token_size)
return paddings
@staticmethod
def _get_request_paddings(min_bucket: int, max_bucket: int) -> list[int]:
min_bucket = max(1, min(min_bucket, max_bucket))
buckets: list[int] = []
current = min_bucket
while current < max_bucket:
buckets.append(current)
current *= 2
if not buckets or buckets[-1] != max_bucket:
buckets.append(max_bucket)
return buckets
def _init_seq_buckets(
self,
user_buckets: list[int] | None,
max_num_seqs: int,
min_input_pad: int,
) -> list[int]:
if user_buckets:
buckets = sorted({int(b) for b in user_buckets if 0 < int(b) <= max_num_seqs})
else:
buckets = self._get_request_paddings(min_input_pad, max_num_seqs)
if not buckets or buckets[-1] != max_num_seqs:
buckets.append(max_num_seqs)
return buckets
def _get_current_bucket(self, num_reqs: int) -> int:
"""Select the smallest bucket that can accommodate num_reqs.
Args:
num_reqs: Number of active requests
Returns:
Smallest sufficient bucket size from self.max_num_seq_buckets
"""
if num_reqs <= 0:
return self.max_num_seq_buckets[0]
for bucket in self.max_num_seq_buckets:
if num_reqs <= bucket:
return bucket
return self.max_num_seq_buckets[-1]
def _setup_variables(self):
    """(Re)initialize request tracking, the sequence buffer and reusable device buffers.

    Resets ``self.requests`` to an empty dict, creates a fresh ``SequenceBuffer``,
    and allocates int32/bool JAX buffers placed with ``self._empty_sharding``.
    Called from ``__init__`` and from ``update_model_weights`` when
    ``reset_state`` is true.
    """
    self.num_reqs_max_model_len = min(self.metadata.get_max_num_seqs(), self.max_num_reqs)
    self.num_reqs_most_model_len = self.num_reqs_max_model_len
    self.requests: dict[str, CachedRequestState] = {}
    logger.debug(f"Token padding sizes: {len(self.num_tokens_paddings)} levels, max={self.max_num_tokens}")
    logger.debug(
        f"Creating sequence buffer for max_num_reqs={self.max_num_reqs}, max_model_len={self.max_model_len}"
    )

    n_tokens = self.max_num_tokens
    n_reqs = self.max_num_reqs
    replicated = self._empty_sharding

    self.sequence_buffer = SequenceBuffer(
        max_num_reqs=n_reqs,
        max_model_len=self.max_model_len,
        max_num_batched_tokens=n_tokens,
        vocab_size=self.model.config.get_text_config().vocab_size,
        page_sizes=[self.metadata.page_size],
        sharding=replicated,
    )

    def _zeros(length, dtype=jnp.int32):
        # 1-D zero-filled buffer placed with the runner's default sharding.
        return jnp.zeros((length,), dtype=dtype, device=replicated)

    self.arange = jnp.arange(n_tokens, dtype=jnp.int32)
    # NOTE: despite the ``_np`` suffix this is a JAX array, not a NumPy array.
    self.arange_np = jnp.arange(n_reqs, dtype=jnp.int32)
    self.input_ids_buf = _zeros(n_tokens)
    self.position_ids_buf = _zeros(n_tokens)
    self.query_start_loc_buf = _zeros(n_reqs + 1)
    self.seq_lens_buf = _zeros(n_reqs)
    self.pages_tables_buf = jnp.full(
        (self.num_reqs_max_model_len, self.max_pages_per_req),
        fill_value=PAGE_TABLE_PADDING_VAL,
        dtype=jnp.int32,
        device=replicated,
    )
    self.num_tokens_paddings_arr = jnp.array(self.num_tokens_paddings, dtype=jnp.int32, device=replicated)
    self.scheduled_full_buf = _zeros(n_reqs)
    self.req_num_tokens_full_buf = _zeros(n_reqs)
    self.active_mask_full_buf = _zeros(n_reqs, dtype=bool)
def _precompile_jitted_helpers(
    self,
    reqs_padds: list[int],
    prompt_len_buckets: list[int],
    precompile_allowed_mask: bool = False,
    allowed_max: int = 512,
) -> None:
    """Lower and compile the small jitted helper kernels ahead of time.

    Each helper (``pack_prompts``, ``build_sampling_arrays``, ``fill_slice``,
    ``swap_rows``, ``move_row`` and optionally ``build_allowed_mask``) is lowered
    with placeholder inputs. Those inputs have ``B = max_num_reqs`` rows and, for
    token ids, ``T = max_model_len`` columns. Compilation runs for every bucket
    combination requested. Any failure is logged at debug level and skipped, so
    this step is best-effort and never raises.

    Args:
        reqs_padds: Padded request counts to compile for.
        prompt_len_buckets: Padded prompt lengths for ``pack_prompts``; each is
            clamped to ``max_model_len``.
        precompile_allowed_mask: Whether to also compile ``build_allowed_mask``.
        allowed_max: Width of the allowed-token-id table for
            ``build_allowed_mask``; further capped at the vocabulary size.
    """
    logger.info("Precompiling eSurgeRunner helper kernels")
    B = self.max_num_reqs
    T = self.max_model_len
    V = int(self.model.config.get_text_config().vocab_size)
    # Placeholder inputs shaped like the runner's per-request buffers; the
    # values themselves are irrelevant here.
    token_ids = jnp.zeros((B, T), dtype=jnp.int32)
    num_prompt_tokens = jnp.zeros((B,), dtype=jnp.int32)
    temperature = jnp.zeros((B,), dtype=jnp.float32)
    min_p = jnp.zeros((B,), dtype=jnp.float32)
    top_p = jnp.ones((B,), dtype=jnp.float32)
    top_k = jnp.zeros((B,), dtype=jnp.int32)
    # pack_prompts: one compilation per (prompt-length bucket, request bucket) pair.
    for pr_len in prompt_len_buckets:
        pr_len = min(pr_len, self.max_model_len)
        for pr_reqs in reqs_padds:
            try:
                # pad_id=V is the vocab size, i.e. an id outside [0, V).
                lowered = pack_prompts.lower(
                    token_ids,
                    num_prompt_tokens,
                    padded_num_reqs=pr_reqs,
                    padded_prompt_len=pr_len,
                    pad_id=V,
                )
                _ = lowered.compile()
                logger.debug(f"pack_prompts compiled for (padded_num_reqs={pr_reqs}, padded_prompt_len={pr_len})")
            except Exception as e:
                # Best-effort: a combination that fails to lower/compile is skipped.
                logger.debug(f"pack_prompts skip ({pr_reqs}, {pr_len}): {e}")
    # build_sampling_arrays: one compilation per request bucket.
    for pr_reqs in reqs_padds:
        try:
            lowered = build_sampling_arrays.lower(
                temperature,
                min_p,
                top_p,
                top_k,
                jnp.int32(min(pr_reqs, B)),  # num_reqs <= padded_num_reqs
                padded_num_reqs=pr_reqs,
            )
            _ = lowered.compile()
            logger.debug(f"build_sampling_arrays compiled for (padded_num_reqs={pr_reqs})")
        except Exception as e:
            logger.debug(f"build_sampling_arrays skip ({pr_reqs}): {e}")
    # fill_slice: compiled with num_reqs == padded_num_reqs for each bucket.
    for pr_reqs in reqs_padds:
        try:
            lowered = fill_slice.lower(
                temperature,
                jnp.float32(0.0),
                int(pr_reqs),
                int(pr_reqs),
            )
            _ = lowered.compile()
            logger.debug(f"fill_slice compiled for (num_reqs={pr_reqs}, padded_num_reqs={pr_reqs})")
        except Exception as e:
            logger.debug(f"fill_slice skip ({pr_reqs}): {e}")
    # Row permutation helpers: row indices are passed as traced int32 scalars,
    # so a single compilation per shape is attempted.
    try:
        _ = swap_rows.lower(token_ids, jnp.int32(0), jnp.int32(1)).compile()
        _ = move_row.lower(token_ids, jnp.int32(0), jnp.int32(1)).compile()
        logger.debug("swap_rows and move_row compiled")
    except Exception as e:
        logger.debug(f"swap_rows/move_row skip: {e}")
    if precompile_allowed_mask:
        max_allowed = int(min(allowed_max, V))
        allowed_ids_padded = jnp.zeros((B, max_allowed), dtype=jnp.int32)
        allowed_lens = jnp.zeros((B,), dtype=jnp.int32)
        try:
            lowered = build_allowed_mask.lower(
                allowed_ids_padded,
                allowed_lens,
                vocab_size=int(V),
                max_allowed=max_allowed,
            )
            _ = lowered.compile()
            logger.debug(f"build_allowed_mask compiled for (B={B}, V={V}, max_allowed={max_allowed})")
        except Exception as e:
            logger.debug(f"build_allowed_mask skip (V={V}, max_allowed={max_allowed}): {e}")
    logger.info("Helper kernel precompilation finished")
def compile(self):
    """Precompile model execution and helper kernels for the configured buckets.

    First asks ``self.executor_manager.compile`` to build executables for the
    runner's token paddings and request buckets. Then precompiles the small
    jitted helpers through ``_precompile_jitted_helpers``; the allowed-token
    mask kernel is skipped there.
    """
    logger.info("Starting eSurgeRunner compilation")
    logger.debug(
        f"Compiling for {len(self.num_tokens_paddings)} token padding sizes: {self.num_tokens_paddings[:5]}..."
        if len(self.num_tokens_paddings) > 5
        else f"Compiling for token padding sizes: {self.num_tokens_paddings}"
    )
    self.executor_manager.compile(
        num_tokens_paddings=self.num_tokens_paddings,
        num_reqs_max_model_len=self.num_reqs_max_model_len,
        max_pages_per_req=self.max_pages_per_req,
        max_num_reqs=self.max_num_reqs,
        metadata=self.metadata,
        num_reqs_paddings=self.max_num_seq_buckets,
    )
    # precompile_allowed_mask=False means allowed_max is not used by this call.
    self._precompile_jitted_helpers(
        reqs_padds=self.max_num_seq_buckets,
        prompt_len_buckets=[min(n, self.max_model_len) for n in self.num_tokens_paddings],
        precompile_allowed_mask=False,
        allowed_max=4096,
    )
def update_model_weights(
    self,
    model: EasyDeLBaseModule | None = None,
    *,
    graphdef=None,
    graphstate=None,
    graphother=None,
    reset_state: bool = True,
) -> None:
    """Update the runner's model weights/graphs and optionally reset state.

    The runner always stores ``model.esurge_compatible_model`` and forwards that
    model's own ``graphdef`` to the executor. ``graphstate`` and ``graphother``
    are forwarded unchanged (possibly ``None``).

    Args:
        model: Optional EasyDeL model instance providing new weights. If
            omitted, ``graphdef``, ``graphstate`` and ``graphother`` must all be
            supplied; they are merged with ``flax.nnx.merge``.
        graphdef: Optional graphdef override (only used to build ``model``).
        graphstate: Optional graphstate override.
        graphother: Optional graphother override.
        reset_state: When True (default) reinitializes internal buffers and
            cached requests, and drops pending async results, so the new weights
            are applied cleanly.

    Raises:
        RuntimeError: If active requests exist while reset_state is True.
        ValueError: If ``model`` is None and any graph component is missing.
    """
    if reset_state and self.requests:
        raise RuntimeError("Cannot update model weights while requests are active")
    if model is None:
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        missing = [
            name
            for name, value in (("graphdef", graphdef), ("graphstate", graphstate), ("graphother", graphother))
            if value is None
        ]
        if missing:
            raise ValueError(
                "update_model_weights requires graphdef, graphstate and graphother when model is None; "
                f"missing: {', '.join(missing)}"
            )
        model = flax.nnx.merge(graphdef, graphstate, graphother)
    model = model.esurge_compatible_model
    graphdef = model.graphdef
    self.model = model
    self.executor_manager.update_graphs(
        model=model,
        graphdef=graphdef,
        graphstate=graphstate,
        graphother=graphother,
    )
    if reset_state:
        self._setup_variables()
        # Tracked requests were just cleared, so tokens pending from a previous
        # async step have nothing left to apply to.
        self._pre_async_results = None
def destroy_kv_cache(self) -> None:
    """Drop the executor's reference to the ragged KV-cache pages.

    Sets ``self.executor_manager.kv_pages`` to ``None``. The memory is presumably
    reclaimed once no other references to the pages remain. Call
    ``initialize_kv_cache`` to allocate them again.
    """
    logger.info("Destroying eSurgeRunner ragged KV cache pages")
    self.executor_manager.kv_pages = None
def initialize_kv_cache(self) -> None:
    """Allocate the ragged KV-cache pages if they are currently absent.

    Does nothing except log at debug level when ``executor_manager.kv_pages`` is
    already set. Otherwise it builds new pages with
    ``self.model.init_ragged_pages(self.metadata)``.
    """
    if self.executor_manager.kv_pages is not None:
        logger.debug("KV cache already initialized; skipping reallocation")
        return
    logger.info("Reinitializing eSurgeRunner ragged KV cache pages")
    self.executor_manager.kv_pages = self.model.init_ragged_pages(self.metadata)
def _update_states(self, scheduler_output: SchedulerOutput) -> bool:
    """Update internal states based on scheduler output.

    Synchronizes the runner's request tracking (``self.requests``) and the
    sequence buffer with the scheduler's decisions for this step.

    State Updates:
        1. Remove finished requests from ``self.requests``.
        2. Remove finished requests from the sequence buffer.
        3. Remove requests that are in the buffer but not scheduled this step.
        4. Register new requests in ``self.requests``.
        5. Update cached requests (computed-token counts, page ids) and append
           their new pages to the buffer's page table.
        6. Insert new / re-added requests into the buffer, reusing freed rows.
        7. Condense the buffer if freed rows remain unused.

    Args:
        scheduler_output: Contains request scheduling decisions including:
            - finished_req_ids: Requests that completed
            - scheduled_new_reqs: New requests to add
            - scheduled_cached_reqs: Existing requests to update
            - num_scheduled_tokens: Tokens to generate per request

    Returns:
        True if any buffered request was unscheduled this step or any request
        was (re)added to the buffer; False otherwise. Removing finished
        requests alone does not count as a change.

    Side Effects:
        - Updates self.requests dictionary
        - Modifies sequence buffer contents
        - May trigger buffer condensation
    """
    # 1) Forget finished requests.
    for req_id in scheduler_output.finished_req_ids:
        self.requests.pop(req_id, None)
    # 2) Remove finished requests from the sequence buffer, remembering the
    #    freed row indices so they can be reused below.
    removed_req_indices: list[int] = []
    for req_id in scheduler_output.finished_req_ids:
        req_index = self.sequence_buffer.remove_request(req_id)
        if req_index is not None:
            removed_req_indices.append(req_index)
    # 3) Remove buffered requests that were not scheduled this step.
    scheduled_req_ids = set(scheduler_output.num_scheduled_tokens.keys())
    cached_req_ids = set(self.sequence_buffer.req_id_to_index.keys())
    unscheduled_req_ids = cached_req_ids - scheduled_req_ids
    for req_id in unscheduled_req_ids:
        req_index = self.sequence_buffer.remove_request(req_id)
        if req_index is not None:
            removed_req_indices.append(req_index)
    # 4) Track new requests (only requests with sampling params are accepted).
    req_ids_to_add: list[str] = []
    for new_req_data in scheduler_output.scheduled_new_reqs:
        assert new_req_data.sampling_params is not None, "Pooling not supported in TPU"
        req_id = new_req_data.req_id
        self.requests[req_id] = CachedRequestState(
            req_id=req_id,
            prompt_token_ids=new_req_data.prompt_token_ids,
            sampling_params=new_req_data.sampling_params,
            generator=None,
            page_ids=new_req_data.page_ids,
            num_computed_tokens=new_req_data.num_computed_tokens,
            output_token_ids=[],
        )
        req_ids_to_add.append(req_id)
    # 5) Update cached requests and collect batched page-table / counter updates.
    req_data = scheduler_output.scheduled_cached_reqs
    upd_req_indices: list[int] = []
    upd_num_computed_vals: list[int] = []
    batched_page_rows: list[tuple[int, tuple[list[int], ...]]] = []
    for i, req_id in enumerate(req_data.req_ids):
        req_state = self.requests.get(req_id)
        if req_state is None:
            continue
        nct = req_data.num_computed_tokens[i]
        new_page_ids = req_data.new_page_ids[i]
        resumed_from_preemption = req_data.resumed_from_preemption[i]
        req_state.num_computed_tokens = nct
        if not resumed_from_preemption:
            # Extend each existing page list with its newly allocated pages
            # (page_ids appears to hold one list per page-table group).
            for page_ids, new_ids in zip(req_state.page_ids, new_page_ids, strict=False):
                page_ids.extend(new_ids)
        else:
            # A resumed (preempted) request gets a fresh page assignment.
            req_state.page_ids = new_page_ids
        req_index = self.sequence_buffer.req_id_to_index.get(req_id)
        if req_index is None:
            # Not in the buffer (e.g. removed as unscheduled earlier): re-add it
            # in step 6 from its updated req_state.
            req_ids_to_add.append(req_id)
            continue
        upd_req_indices.append(req_index)
        upd_num_computed_vals.append(int(nct))
        batched_page_rows.append((req_index, new_page_ids))
    if upd_req_indices:
        # num_computed_tokens is now a NumPy array, use standard indexing
        idx_arr = np.array(upd_req_indices, dtype=np.int32)
        val_arr = np.array(upd_num_computed_vals, dtype=np.int32)
        # Copy-then-assign rather than mutating the existing array in place.
        new_num_computed = self.sequence_buffer.num_computed_tokens.copy()
        new_num_computed[idx_arr] = val_arr
        self.sequence_buffer.num_computed_tokens = new_num_computed
    if batched_page_rows:
        indices = [ix for ix, _ in batched_page_rows]
        pages_per_req = [ids for _, ids in batched_page_rows]
        self.sequence_buffer.page_table.append_rows_batch(pages_per_req, indices)
        self.sequence_buffer.page_table.commit(self.sequence_buffer.num_reqs)
    # 6) Add new / reinserted requests.
    # Freed rows are sorted in DESCENDING order and consumed with pop(), which
    # takes from the end of the list, so the LOWEST free index is reused first.
    # This keeps active rows packed toward the front. Indices left over are
    # passed to condense() below, still in descending order.
    # NOTE(review): if the intent was to hand out the HIGHEST free indices first
    # (e.g. to avoid re-touching low rows), this ordering does the opposite;
    # confirm intent before changing it, since condense() receives whatever
    # remains.
    removed_req_indices = sorted(removed_req_indices, reverse=True)
    for req_id in req_ids_to_add:
        req_state = self.requests[req_id]
        # None means "append a new row" when no freed row is available.
        reuse_index = removed_req_indices.pop() if removed_req_indices else None
        self.sequence_buffer.add_request(req_state, reuse_index)
    # 7) Condense to remove holes left by freed rows that were not reused.
    if removed_req_indices:
        self.sequence_buffer.condense(removed_req_indices)
    has_changes = len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
    return has_changes
def _modify_prev_results(self) -> None:
"""Apply previous iteration's tokens to sequence buffer.
This method is called at the beginning of each iteration when async
scheduling is enabled. It retrieves the tokens that were sampled
asynchronously in the previous iteration and applies them to the
sequence buffer.
The method blocks until the async token transfer is complete, then
updates the token_ids array and request output_token_ids lists.
Note:
This method should only be called when self._pre_async_results is not None.
"""
if self._pre_async_results is None:
return
pre_req_ids = self._pre_async_results.req_ids
pre_next_tokens = self._pre_async_results.next_tokens
pre_request_seq_lens = self._pre_async_results.request_seq_lens
pre_discard_indices = self._pre_async_results.discard_sampled_tokens_req_indices
# Block until tokens are ready (async copy to host completes)
next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens))
selected_token_ids = np.expand_dims(next_tokens_cpu[: len(pre_req_ids)], 1)
# Mask out discarded tokens
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
for i in pre_discard_indices:
valid_sampled_token_ids[i] = np.array([])
# Apply tokens to sequence buffer
for pre_req_idx, req_state, _ in pre_request_seq_lens:
sampled_ids = valid_sampled_token_ids[pre_req_idx]
if len(sampled_ids) == 0:
continue
# Check if request is still active
req_id = pre_req_ids[pre_req_idx]
if req_id not in self.sequence_buffer.req_id_to_index:
continue
req_idx = self.sequence_buffer.req_id_to_index[req_id]
assert req_state is self.requests[req_id], "Request state mismatch"
# Update token_ids array (replace placeholder)
end_idx = self.sequence_buffer.num_tokens_no_spec[req_idx]
start_idx = end_idx - 1
assert end_idx <= self.max_model_len, f"Token count {end_idx} exceeds max_model_len {self.max_model_len}"
self.sequence_buffer.token_ids[req_idx, start_idx:end_idx] = sampled_ids
# Replace placeholder in output_token_ids
req_state.output_token_ids[-1] = int(sampled_ids[-1])
def _update_placeholder(
self,
discard_sampled_tokens_req_indices: list[int],
request_seq_lens: list[tuple[int, CachedRequestState, int]],
) -> dict[str, int]:
"""Set placeholders for tokens not yet generated.
When async scheduling is enabled, this method is called after the
forward pass to set placeholder tokens (0) for requests that will
generate tokens. The actual tokens will be filled in during the
next iteration via _modify_prev_results().
Args:
discard_sampled_tokens_req_indices: Indices of requests whose
tokens should be discarded (e.g., partial prefill).
request_seq_lens: List of (req_idx, req_state, seq_len) tuples
for requests that generated tokens.
Returns:
Mapping from request ID to index for placeholder replacement.
Note:
This method updates num_tokens_no_spec and num_tokens in the
sequence buffer, and appends placeholder (0) to output_token_ids.
"""
placeholder_req_id_to_index: dict[str, int] = {}
discard_set = set(discard_sampled_tokens_req_indices)
for req_idx, req_state, _ in request_seq_lens:
if req_idx in discard_set:
continue
start_idx = self.sequence_buffer.num_tokens_no_spec[req_idx]
end_idx = start_idx + 1 # Assume 1 token (no spec decode yet)
assert end_idx <= self.max_model_len, (
f"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: {self.max_model_len}"
)
# Update buffer state
self.sequence_buffer.num_tokens_no_spec[req_idx] = end_idx
self.sequence_buffer.num_tokens[req_idx] = end_idx
# Add placeholder (0) to output
req_state.output_token_ids.extend([0])
placeholder_req_id_to_index[req_state.req_id] = req_idx
return placeholder_req_id_to_index
def _reorder_decode_first(self, scheduler_output: SchedulerOutput) -> None:
"""Reorder active requests so decode tokens are placed first."""
i, j = 0, self.sequence_buffer.num_reqs - 1
while i < j:
i_req_id = self.sequence_buffer.req_ids[i]
j_req_id = self.sequence_buffer.req_ids[j]
if i_req_id is None or j_req_id is None:
break
i_is_decode = (
scheduler_output.num_scheduled_tokens.get(i_req_id, 0) == 1
and self.sequence_buffer.num_computed_tokens[i] > 0
)
j_is_decode = (
scheduler_output.num_scheduled_tokens.get(j_req_id, 0) == 1
and self.sequence_buffer.num_computed_tokens[j] > 0
)
if i_is_decode:
i += 1
elif not j_is_decode:
j -= 1
else:
# Swap to move a decode request forward.
self.sequence_buffer.swap_states(i, j)
i += 1
j -= 1
def _execute_model_impl(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
"""Execute the model on scheduled requests.
Main entry point for model execution. Processes all scheduled requests
in batches, handling state updates, input preparation, model execution,
and token processing.
The method handles:
1. State synchronization with scheduler
2. Batch-wise processing of requests
3. Token generation and sampling
4. Buffer updates and metrics logging
Args:
scheduler_output: Output from the scheduler containing:
- Requests to process
- Tokens to generate per request
- Finished/new/cached request information
Returns:
ModelRunnerOutput: Contains:
- req_ids: List of processed request IDs
- sampled_token_ids: Generated tokens per request
- logprobs: Log probabilities (if requested)
- Timing and debugging information
Note:
The method processes requests in batches when they exceed
the maximum model length, ensuring all requests are handled
efficiently without exceeding memory constraints.
"""
execution_start_time = time.time()
updating_states_start = time.time()
self._update_states(scheduler_output)
updating_states_time = time.time() - updating_states_start
# Apply previous async results if available
if self._pre_async_results is not None:
self._modify_prev_results()
self._pre_async_results = None # Clear after applying
# Align ordering with TPU runner: decode requests first.
if self.sequence_buffer.num_reqs > 1:
self._reorder_decode_first(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
return ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
finished_sending=None,
finished_recving=None,
num_nans_in_logits=None,
)
start_index = 0
total_step_time = 0.0
total_sync_time = 0.0
total_post_proc_time = 0.0
req_ids_all: list[str] = []
sampled_token_ids_all: list[list[int]] = []
token_logprobs: dict[str, float] = {}
while start_index < self.sequence_buffer.num_reqs:
num_reqs_total = self.sequence_buffer.num_reqs
scheduled_list: list[int] = []
req_ids_window = []
for i in range(start_index, min(num_reqs_total, start_index + self.num_reqs_max_model_len)):
rid = self.sequence_buffer.req_ids[i]
req_ids_window.append(rid)
scheduled_list.append(int(scheduler_output.num_scheduled_tokens.get(rid, 0)) if rid is not None else 0)
while scheduled_list and scheduled_list[-1] == 0:
scheduled_list.pop()
req_ids_window.pop()
num_reqs = len(scheduled_list)
if num_reqs == 0:
break
end_index = start_index + num_reqs
total_scheduled = sum(scheduled_list)
idx = bisect_left(self.num_tokens_paddings, total_scheduled)
if idx >= len(self.num_tokens_paddings):
idx = len(self.num_tokens_paddings) - 1
num_tokens_static = int(self.num_tokens_paddings[idx])
# Select optimal bucket for current batch size
# This determines which compiled function to use
current_bucket = self._get_current_bucket(num_reqs)
padded_num_reqs = current_bucket # Use bucket size for compilation lookup
if num_reqs > 0:
# Keep scheduled and active_mask as CPU arrays
scheduled_full_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
scheduled_full_cpu[: len(scheduled_list)] = scheduled_list
req_num_tokens_np = np.zeros(self.max_num_reqs, dtype=np.int32)
active_mask_full_cpu = np.zeros(self.max_num_reqs, dtype=bool)
for i, rid in enumerate(req_ids_window):
if rid is not None:
rs = self.requests.get(rid)
if rs:
req_num_tokens_np[i] = rs.num_tokens
active_mask_full_cpu[i] = True
self.req_num_tokens_full_buf = jax.device_put(
jnp.asarray(req_num_tokens_np, dtype=jnp.int32),
self._empty_sharding,
)
# Get page table as CPU array (already on CPU, no transfer needed)
page_table_cpu = self.sequence_buffer.page_table[0].get_cpu_tensor()
step_start = time.time()
(
minimal_device_state,
out_tokens_win,
valid_mask_win,
self.input_ids_buf,
self.position_ids_buf,
self.query_start_loc_buf,
self.seq_lens_buf,
self.pages_tables_buf,
_hidden_states,
_logits,
metrics,
) = self.executor_manager.execute(
num_tokens=num_tokens_static,
scheduled_full_cpu=scheduled_full_cpu,
req_num_tokens_full=self.req_num_tokens_full_buf,
active_mask_full_cpu=active_mask_full_cpu,
input_ids_buf=self.input_ids_buf,
position_ids_buf=self.position_ids_buf,
padded_num_reqs=padded_num_reqs,
# Pass NumPy arrays for CPU-first batch metadata preparation
token_ids_cpu=self.sequence_buffer.token_ids,
num_computed_tokens_cpu=self.sequence_buffer.num_computed_tokens,
temperature_cpu=self.sequence_buffer.temperature,
top_p_cpu=self.sequence_buffer.top_p,
top_k_cpu=self.sequence_buffer.top_k,
min_p_cpu=self.sequence_buffer.min_p,
page_table_cpu=page_table_cpu,
)
# Start async copy to host (non-blocking) - overlaps with post-processing
token_ids_async = jax.copy_to_host_async(minimal_device_state.token_ids)
num_tokens_async = jax.copy_to_host_async(minimal_device_state.num_tokens)
# account for device time (blocking already happened inside execute())
total_step_time += time.time() - step_start
# host copies once
tokens_np = np.asarray(out_tokens_win)
valid_np = np.asarray(valid_mask_win)
logits_np = np.asarray(_logits) if self.enable_sampler_metrics and _logits is not None else None
# Track for async scheduling
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices: list[int] = []
up_wtime = time.time()
for i, rid in enumerate(req_ids_window):
if rid is None:
continue
req_ids_all.append(rid)
if valid_np[i]:
tid = int(tokens_np[i])
# Get request state and sequence length
if rid in self.requests:
req_state = self.requests[rid]
seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens.get(rid, 0)
# Check if async scheduling is enabled
if scheduler_output.async_scheduling:
# Async mode: don't append yet, will be done in next iteration
request_seq_lens.append((i, req_state, seq_len))
else:
# Sync mode: append immediately
sampled_token_ids_all.append([tid])
req_state.output_token_ids.append(tid)
else:
# No request state, append in sync mode
sampled_token_ids_all.append([tid])
if self.enable_sampler_metrics and logits_np is not None and i < logits_np.shape[0]:
try:
token_logprobs[rid] = logits_np[i]
except Exception:
pass
else:
sampled_token_ids_all.append([])
discard_sampled_tokens_req_indices.append(i)
up_wtime_took = time.time() - up_wtime
total_post_proc_time += up_wtime_took
# Complete async sync-back (should be done by now, overlapped with post-processing)
sq_utime = time.time()
token_ids_updated = np.asarray(token_ids_async)
num_tokens_updated = np.asarray(num_tokens_async)
if not token_ids_updated.flags.writeable:
token_ids_updated = token_ids_updated.copy()
if not num_tokens_updated.flags.writeable:
num_tokens_updated = num_tokens_updated.copy()
self.sequence_buffer.token_ids = token_ids_updated
self.sequence_buffer.num_computed_tokens = num_tokens_updated
sq_utime_took = time.time() - sq_utime
total_sync_time += sq_utime_took
start_index = end_index
metrics_collector = get_metrics_collector()
if metrics_collector:
metrics_collector.record_runner_metrics(
execution_time=time.time() - execution_start_time,
batch_size=len(req_ids_all),
num_tokens=scheduler_output.total_num_scheduled_tokens,
)
total_time = time.time() - execution_start_time
exec_took = metrics["exec_time"]
sample_took = metrics["sample_time"]
prep_took = metrics["prep_time"]
buckets_processed = metrics["buckets_processed"]
self.log_it(
f"[execute] "
f"step={total_step_time:.3f}s "
f"fwd={exec_took:.3f}s "
f"sample={sample_took:.3f}s "
f"prep={prep_took:.3f}s "
f"p-bs={buckets_processed}s "
f"update={updating_states_time:.3f}s "
f"total={total_time:.3f}s"
)
# Handle async scheduling return
if scheduler_output.async_scheduling:
# Set placeholders for current batch
placeholder_req_id_to_index = self._update_placeholder(
discard_sampled_tokens_req_indices,
request_seq_lens,
)
# Async copy to host (non-blocking)
next_tokens_jax = jnp.array(tokens_np, dtype=jnp.int32)
next_tokens = jax.copy_to_host_async(next_tokens_jax)
# Store async results for next iteration
self._pre_async_results = AsyncPreResults(
req_ids=req_ids_all,
next_tokens=next_tokens,
request_seq_lens=request_seq_lens,
discard_sampled_tokens_req_indices=discard_sampled_tokens_req_indices,
placeholder_req_id_to_index=placeholder_req_id_to_index,
)
# Return immediately (non-blocking)
req_id_to_out_index = {rid: i for i, rid in enumerate(req_ids_all)}
return ModelRunnerOutput(
req_ids=req_ids_all,
req_id_to_index=req_id_to_out_index,
sampled_token_ids=[], # Empty, will be filled in next iteration
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={rid: None for rid in req_ids_all},
finished_sending=None,
finished_recving=None,
token_logprobs=token_logprobs or None,
)
# Stable mapping for scheduler indexing
req_id_to_out_index = {rid: i for i, rid in enumerate(req_ids_all)}
return ModelRunnerOutput(
req_ids=req_ids_all,
req_id_to_index=req_id_to_out_index,
sampled_token_ids=sampled_token_ids_all,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={rid: None for rid in req_ids_all},
finished_sending=None,
finished_recving=None,
token_logprobs=token_logprobs or None,
)
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
    """Run one model step synchronously on the calling thread.

    This is the blocking entry point; all of the work is delegated to
    ``_execute_model_impl``, which is the same routine that
    ``execute_model_async`` submits to the background executor.

    Args:
        scheduler_output: Scheduling decisions for this iteration.

    Returns:
        ModelRunnerOutput: Whatever ``_execute_model_impl`` produces for
        ``scheduler_output``.
    """
    run_step = self._execute_model_impl
    return run_step(scheduler_output)
def execute_model_async(self, scheduler_output: SchedulerOutput) -> Future[ModelRunnerOutput]:
    """Submit one model step to the background executor and return at once.

    The step runs ``_execute_model_impl(scheduler_output)`` on
    ``self._executor``. The caller gets a Future back immediately, so it can
    keep scheduling the next batch while this one runs.

    Typical workflow:
        1. ``initialize_async_executor()`` once, to create the executor.
        2. ``execute_model_async(...)`` to submit a step and obtain a Future.
        3. Do other work (for example, schedule the next batch).
        4. ``wait_for_execution(future)`` to collect the result.

    Args:
        scheduler_output: Scheduling decisions for this iteration.

    Returns:
        Future[ModelRunnerOutput]: Resolves to the step's output once the
        executor has finished running it.

    Raises:
        RuntimeError: If ``self._executor`` is ``None`` (no executor has
            been initialized, or it was shut down).

    Example:
        >>> runner.initialize_async_executor()
        >>> future = runner.execute_model_async(scheduler_output)
        >>> next_schedule = scheduler.schedule()  # overlaps with execution
        >>> output = runner.wait_for_execution(future)
    """
    executor = self._executor
    if executor is None:
        raise RuntimeError(
            "Async execution not enabled. Call initialize_async_executor() first "
            "or check that async_scheduling is enabled in scheduler config."
        )
    return executor.submit(self._execute_model_impl, scheduler_output)
def initialize_async_executor(self) -> None:
    """Create (or recreate) the single-worker executor used for async steps.

    Any executor already stored on ``self._executor`` is shut down first,
    waiting for its queued work to finish. A fresh ``ThreadPoolExecutor``
    with exactly one worker then replaces it. One worker means submitted
    steps run one at a time, in submission order.

    Side Effects:
        - Shuts down the previous ``self._executor``, if one exists.
        - Sets ``self._executor`` to a new ``ThreadPoolExecutor`` whose
          threads are named with the ``"eSurgeAsync"`` prefix.

    Note:
        Call this before ``execute_model_async()``.
    """
    from concurrent.futures import ThreadPoolExecutor

    previous = self._executor
    if previous is not None:
        logger.debug("Shutting down existing executor before reinitializing")
        previous.shutdown(wait=True)
    self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="eSurgeAsync")
    logger.debug("Initialized async executor for model execution")
def reset_state(self) -> None:
    """Drop all per-request bookkeeping held by the runner.

    Steps, in order:
        1. Clears the ``requests`` mapping.
        2. Clears the sequence buffer.
        3. Discards any pending async results from a previous step.

    Use this when pausing or resetting the runner, so that stale request
    metadata is not carried into the next session. Whether this also
    releases KV-cache pages depends on ``sequence_buffer.clear()``, which is
    not shown here. The async executor (``self._executor``) is left
    untouched.
    """
    for container in (self.requests, self.sequence_buffer):
        container.clear()
    self._pre_async_results = None
def wait_for_execution(self, future: Future, timeout: float | None = None) -> ModelRunnerOutput:
    """Block until an async model step finishes and return its output.

    Args:
        future: The Future returned by ``execute_model_async()``.
        timeout: Maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely, which matches the previous behavior. A finite
            value lets callers bound how long a stuck step can block them.

    Returns:
        ModelRunnerOutput: The completed model execution output.

    Raises:
        concurrent.futures.TimeoutError: If ``timeout`` is given and the
            step has not finished within that many seconds.
        concurrent.futures.CancelledError: If the future was cancelled.
        Exception: Any exception raised inside the step is re-raised here.
    """
    # Future.result(timeout=None) waits forever, so the default preserves
    # the original semantics exactly.
    return future.result(timeout=timeout)
def shutdown(self) -> None:
    """Release runner resources, stopping the async executor if one exists.

    When ``self._executor`` is set, it is shut down, waiting for queued steps
    to finish, and the attribute is reset to ``None``. Calling this when no
    executor exists is a no-op, so repeated calls are safe.
    """
    executor = self._executor
    if executor is None:
        return
    logger.debug("Shutting down async executor")
    executor.shutdown(wait=True)
    self._executor = None