# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence buffer for managing token sequences during generation.
Provides efficient storage and management of token sequences, page tables,
and generation metadata for batch processing.
Classes:
SequenceBuffer: Main buffer for managing sequences
DecodeRowInfo: Information about sequences in decode phase
ModelRunBatch: Batch data for model execution
Example:
>>> buffer = SequenceBuffer(
... max_num_reqs=32,
... max_model_len=2048,
... max_num_batched_tokens=4096,
... vocab_size=50000,
... page_sizes=[16, 32]
... )
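>>> buffer.add_request(request)  # request: an EngineRequest with prompt tokens and sampling params
>>> freed_index = buffer.remove_request(request.req_id)
>>> buffer.condense([freed_index])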
"""
from __future__ import annotations
from typing import Any, cast
import jax
import numpy as np
from eformer.loggings import get_logger
from eformer.pytree import auto_pytree, field
from jax import numpy as jnp
from easydel.inference.esurge.request import EngineRequest
from easydel.utils.compiling_utils import ejit
from ...sampling_params import SamplingParams, SamplingType
from ..outputs import LogprobsTensors, swap_dict_values
from ..page_table import MultiGroupPageTable
logger = get_logger(__name__)
@ejit(static_argnames=("padded_num_reqs", "padded_prompt_len"))
def pack_prompts(token_ids, num_prompt_tokens, padded_num_reqs, padded_prompt_len, pad_id):
    """Pack prompt tokens into a fixed-shape, padded tensor.

    Takes the leading ``[padded_num_reqs, padded_prompt_len]`` window of
    ``token_ids`` and overwrites every position at or beyond a row's prompt
    length with ``pad_id``.

    Args:
        token_ids: Token IDs array; documented shape is
            ``[max_num_reqs, max_model_len]``.
        num_prompt_tokens: Number of prompt tokens per request.
        padded_num_reqs: Number of rows in the output (static under JIT).
        padded_prompt_len: Number of columns in the output (static under JIT).
        pad_id: Token ID written into padding positions.

    Returns:
        Array of shape ``[padded_num_reqs, padded_prompt_len]`` holding the
        prompt tokens followed by ``pad_id``.

    Note:
        Both padded dimensions are static arguments, so each distinct pair of
        values produces its own compiled variant.
    """
    window = token_ids[:padded_num_reqs, :padded_prompt_len]
    prompt_lens = num_prompt_tokens[:padded_num_reqs, None]  # [B, 1]
    positions = jnp.arange(padded_prompt_len, dtype=prompt_lens.dtype)[None, :]  # [1, T]
    # A position is part of the prompt iff it is strictly below the row length.
    inside_prompt = positions < prompt_lens  # [B, T]
    return jnp.where(inside_prompt, window, jnp.full_like(window, pad_id))
@ejit(static_argnames=("padded_num_reqs",))
def build_sampling_arrays(temperature, min_p, top_p, top_k, num_reqs, padded_num_reqs):
    """Build padded sampling parameter arrays.

    Produces arrays of length ``padded_num_reqs`` in which the first
    ``num_reqs`` entries come from the inputs and the remaining entries hold
    a neutral default.

    Args:
        temperature: Temperature values for sampling.
        min_p: Minimum probability threshold values.
        top_p: Top-p (nucleus) sampling values.
        top_k: Top-k sampling values.
        num_reqs: Actual number of requests. Not a static argument, so it may
            be a traced value.
        padded_num_reqs: Target padded number of requests (static).

    Returns:
        A tuple of padded arrays:
            - temperature: Padded with -1.0 (float32)
            - min_p: Padded with 0.0 (float32)
            - top_p: Padded with 1.0 (float32)
            - top_k: Padded with 0 (int32)

    Note:
        ``num_reqs`` is applied through a mask rather than as a slice bound:
        slice bounds must be static under JIT, and ``num_reqs`` is not
        declared static here.
    """
    # Row i is live iff i < num_reqs; valid for traced and concrete num_reqs.
    live = jnp.arange(padded_num_reqs) < num_reqs

    def fill(arr, fill_val):
        head = arr[:padded_num_reqs]  # static bound
        missing = padded_num_reqs - head.shape[0]
        if missing > 0:
            # Input shorter than the padded size: extend with the default.
            head = jnp.pad(head, (0, missing), constant_values=fill_val)
        return jnp.where(live, head, jnp.asarray(fill_val, dtype=arr.dtype))

    return (
        fill(temperature, -1.0).astype(jnp.float32),
        fill(min_p, 0.0).astype(jnp.float32),
        fill(top_p, 1.0).astype(jnp.float32),
        fill(top_k, 0).astype(jnp.int32),
    )
def swap_rows(arr, i1, i2):
    """Return a copy of ``arr`` with rows ``i1`` and ``i2`` exchanged.

    Args:
        arr: Input array (NumPy ndarray, an object exposing ``.at`` such as a
            JAX array, or anything ``np.asarray`` accepts).
        i1: Index of first row.
        i2: Index of second row.

    Returns:
        A new array with rows ``i1`` and ``i2`` swapped; the input is left
        untouched.

    Note:
        Non-ndarray inputs that expose ``.at`` use functional updates; all
        other inputs go through a NumPy copy.
    """
    if not isinstance(arr, np.ndarray) and hasattr(arr, "at"):
        first, second = arr[i1], arr[i2]
        return arr.at[i1].set(second).at[i2].set(first)

    # NumPy ndarrays and plain Python sequences share the same copy-based path.
    result = np.asarray(arr).copy()
    saved = result[i1].copy()
    result[i1] = result[i2]
    result[i2] = saved
    return result
def swap_rows_pytree(arrs, i1, i2):
    """Swap rows across all arrays in a pytree.

    Args:
        arrs: PyTree containing arrays.
        i1: Index of first row to swap.
        i2: Index of second row to swap.

    Returns:
        PyTree with same structure but rows swapped in all arrays.
    """
    # Use jax.tree_util.tree_map: the top-level jax.tree_map alias was
    # deprecated and is absent from recent JAX releases.
    return jax.tree_util.tree_map(lambda leaf: swap_rows(leaf, i1, i2), arrs)
def move_row(arr, from_idx, to_idx):
    """Return a copy of ``arr`` whose row ``to_idx`` equals row ``from_idx``.

    The source row is left as it was; only the destination row is overwritten.

    Args:
        arr: Input array (NumPy ndarray, an object exposing ``.at`` such as a
            JAX array, or anything ``np.asarray`` accepts).
        from_idx: Source row index.
        to_idx: Destination row index.

    Returns:
        A new array with the source row copied over the destination row.
    """
    if not isinstance(arr, np.ndarray) and hasattr(arr, "at"):
        return arr.at[to_idx].set(arr[from_idx])

    # NumPy ndarrays and plain Python sequences share the same copy-based path.
    result = np.asarray(arr).copy()
    result[to_idx] = result[from_idx]
    return result
@ejit(static_argnames=("vocab_size", "max_allowed"))
def build_allowed_mask(allowed_ids_padded, allowed_lens, vocab_size, max_allowed):
    """Build a mask for allowed token IDs.

    Creates a boolean mask indicating which tokens are allowed for each request.
    The mask uses inverted logic where True means disallowed and False means allowed.

    Args:
        allowed_ids_padded: Padded array of allowed token IDs [B, max_allowed].
        allowed_lens: Number of valid allowed IDs per request [B].
        vocab_size: Total vocabulary size.
        max_allowed: Maximum number of allowed tokens per request.

    Returns:
        Boolean mask of shape [B, vocab_size] where True indicates
        the token is disallowed and False indicates it's allowed.

    Note:
        Padding entries are not filtered out with a boolean index, because
        that would give a data-dependent shape, which cannot be traced under
        JIT. They are redirected to the out-of-range column ``vocab_size``
        and the scatter runs with ``mode="drop"`` so those updates are
        discarded.
    """
    num_rows = allowed_ids_padded.shape[0]
    slot = jnp.arange(max_allowed)[None, :]  # [1, max_allowed]
    is_real = slot < allowed_lens[:, None]  # [B, max_allowed]
    # Out-of-range column for padding slots -> dropped by the scatter below.
    cols = jnp.where(is_real, allowed_ids_padded, vocab_size)
    rows = jnp.arange(num_rows)[:, None]  # broadcasts against cols
    disallowed = jnp.ones((num_rows, vocab_size), dtype=bool)
    return disallowed.at[rows, cols].set(False, mode="drop")
class SequenceBuffer:
    """Buffer for managing token sequences during generation.

    Mutable, class-based design:
    - Per-request data lives in NumPy arrays indexed by request slot.
    - Methods mutate the buffer's state and return ``None`` unless noted.
      Array updates are copy-then-rebind: a method builds an updated copy and
      assigns it back to the attribute, so array objects obtained earlier
      are not modified afterwards.
    - ``page_table`` (a ``MultiGroupPageTable``) tracks page allocations.
    - Sparse/optional per-request settings live in dicts and sets.

    Use SequenceBuffer() constructor to create instances.
    """

    # Per-slot NumPy arrays that must travel together whenever a request
    # changes slot (see swap_states / _move_request).
    _ROW_ARRAY_NAMES = (
        "token_ids",
        "num_tokens",
        "num_tokens_no_spec",
        "num_prompt_tokens",
        "num_computed_tokens",
        "temperature",
        "top_p",
        "top_k",
        "min_p",
        "frequency_penalties",
        "presence_penalties",
        "repetition_penalties",
    )

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        page_sizes: list[int],
        sharding: jax.sharding.Sharding | None = None,
    ):
        """Initialize a SequenceBuffer with all arrays and page table.

        Args:
            max_num_reqs: Maximum number of concurrent requests.
            max_model_len: Maximum sequence length per request.
            max_num_batched_tokens: Maximum tokens in a batch.
            vocab_size: Size of the model vocabulary.
            page_sizes: List of page sizes for the page table.
            sharding: Optional JAX sharding for page table arrays.
        """
        # Static configuration
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size

        # Per-slot NumPy arrays, initialised to neutral sampling defaults.
        self.token_ids = np.zeros((max_num_reqs, max_model_len), dtype=np.int32)
        self.num_tokens = np.zeros((max_num_reqs,), dtype=np.int32)
        self.num_tokens_no_spec = np.zeros((max_num_reqs,), dtype=np.int32)
        self.num_prompt_tokens = np.zeros((max_num_reqs,), dtype=np.int32)
        self.num_computed_tokens = np.zeros((max_num_reqs,), dtype=np.int32)
        self.temperature = np.full((max_num_reqs,), -1.0, dtype=np.float32)
        self.top_p = np.ones((max_num_reqs,), dtype=np.float32)
        self.top_k = np.full((max_num_reqs,), vocab_size, dtype=np.int32)
        self.min_p = np.zeros((max_num_reqs,), dtype=np.float32)
        self.frequency_penalties = np.zeros((max_num_reqs,), dtype=np.float32)
        self.presence_penalties = np.zeros((max_num_reqs,), dtype=np.float32)
        self.repetition_penalties = np.ones((max_num_reqs,), dtype=np.float32)

        # Page table
        self.page_table = MultiGroupPageTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            page_sizes=page_sizes,
        )
        # Apply sharding only to array-like leaves of the page table.
        if sharding is not None:
            self.page_table = jax.tree_util.tree_map(
                lambda leaf: jax.device_put(leaf, sharding) if hasattr(leaf, "dtype") else leaf,
                self.page_table,
            )

        # Python bookkeeping. Some structures are keyed by request ID (str),
        # others by slot index (int); the latter must follow slot moves.
        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}
        self.req_output_token_ids: list[list[int] | None] = []
        self.request_distribution: list[int] = [0, 0, 0]
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()
        self.top_p_reqs: set[str] = set()
        self.top_k_reqs: set[str] = set()
        self.min_p_reqs: set[str] = set()
        self.frequency_penalties_reqs: set[str] = set()
        self.presence_penalties_reqs: set[str] = set()
        self.repetition_penalties_reqs: set[str] = set()
        self.has_allowed_token_ids: set[str] = set()
        self.min_tokens: dict[int, tuple[int, set[int]]] = {}
        self.generator_seeds: dict[int, int] = {}
        self.num_logprobs: dict[str, int] = {}
        self.num_prompt_logprobs: dict[str, int] = {}
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
        self.logit_bias: list[dict[int, float] | None] = [None] * max_num_reqs
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}
        self.allowed_token_ids_mask: jax.Array | None = None

    def _update_request_distribution(self) -> None:
        """Update the request distribution triple [decode_only, chunked_prefill, total].

        Only the total is computed here; the first two entries are reset to 0.
        """
        total = len(self.req_id_to_index)
        self.request_distribution = [0, 0, total]

    def _request_sets(self) -> list[set[str]]:
        """Return every set that tracks request IDs by sampling feature."""
        return [
            self.greedy_reqs,
            self.random_reqs,
            self.top_p_reqs,
            self.top_k_reqs,
            self.min_p_reqs,
            self.frequency_penalties_reqs,
            self.presence_penalties_reqs,
            self.repetition_penalties_reqs,
            self.has_allowed_token_ids,
        ]

    @property
    def req_ids(self) -> list[str]:
        """Slot-ordered request IDs (the underlying list may contain ``None`` gaps)."""
        return cast(list[str], self._req_ids)

    @property
    def num_reqs(self) -> int:
        """Number of requests currently registered."""
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        """True when no request uses random sampling."""
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        """True when no request uses greedy sampling."""
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        """True when no request restricts top-p."""
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        """True when no request restricts top-k."""
        return len(self.top_k_reqs) == 0

    @property
    def no_min_p(self) -> bool:
        """True when no request uses min-p."""
        return len(self.min_p_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when no request uses presence, frequency or repetition penalties."""
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )

    @property
    def max_num_logprobs(self) -> int | None:
        """Largest requested logprob count, or ``None`` if no request asked for logprobs."""
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_prompt_logprob(self) -> bool:
        """True when no request asked for prompt logprobs."""
        return not self.num_prompt_logprobs

    @property
    def no_allowed_token_ids(self) -> bool:
        """True when no request restricts its allowed token IDs."""
        return len(self.has_allowed_token_ids) == 0

    def _ensure_logit_bias_capacity(self, upto_idx: int) -> None:
        """Ensure logit_bias list has sufficient capacity.

        Args:
            upto_idx: Index that needs to be accessible.

        Note:
            Extends the list with None values if needed.
        """
        if len(self.logit_bias) <= upto_idx:
            self.logit_bias.extend([None] * (upto_idx + 1 - len(self.logit_bias)))

    def add_request(self, request: EngineRequest, req_index: int | None = None) -> None:
        """Add a new request to the buffer.

        Stores the request's tokens, counters, page IDs and sampling
        parameters in a slot. Prompt and output tokens are truncated so the
        row never exceeds ``max_model_len``.

        Args:
            request: The engine request to add. Read attributes:
                ``req_id``, ``prompt_token_ids``, ``output_token_ids``,
                ``num_tokens``, ``num_computed_tokens``, ``page_ids``,
                ``sampling_params``.
            req_index: Optional specific index to place the request.
                If None, finds the next available slot.

        Raises:
            ValueError: If the request ID already exists in the buffer, or
                the requested slot is occupied.
            IndexError: If req_index is out of bounds.
            RuntimeError: If the buffer is full.
            AssertionError: If the request has no sampling params.
        """
        req_id = request.req_id
        if req_id in self.req_id_to_index:
            raise ValueError(f"Request ID {req_id} is already present at index {self.req_id_to_index[req_id]}.")
        req_index = self._allocate_index(req_index)

        # Checked before registration so a rejected request leaves no trace
        # in the ID bookkeeping.
        sampling_params = request.sampling_params
        assert sampling_params is not None, "pooling requests not supported yet"

        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
        self.req_id_to_index[req_id] = req_index

        # Token row: prompt first, then any already-generated output tokens.
        prompt_len = min(len(request.prompt_token_ids), self.max_model_len)
        token_ids = self.token_ids.copy()
        token_ids[req_index, :prompt_len] = np.array(request.prompt_token_ids[:prompt_len], dtype=np.int32)
        if request.output_token_ids:
            room = self.max_model_len - prompt_len
            output_tokens = request.output_token_ids[:room]
            if output_tokens:
                stop = min(prompt_len + len(output_tokens), self.max_model_len)
                token_ids[req_index, prompt_len:stop] = np.array(output_tokens, dtype=np.int32)

        total_tokens = min(int(request.num_tokens), self.max_model_len)
        computed_tokens = min(int(request.num_computed_tokens), self.max_model_len)

        def with_value(arr: np.ndarray, value: int) -> np.ndarray:
            # Copy-then-rebind: never write into the currently published array.
            updated = arr.copy()
            updated[req_index] = value
            return updated

        self.token_ids = token_ids
        self.num_prompt_tokens = with_value(self.num_prompt_tokens, prompt_len)
        self.num_tokens = with_value(self.num_tokens, total_tokens)
        self.num_tokens_no_spec = with_value(self.num_tokens_no_spec, total_tokens)
        self.num_computed_tokens = with_value(self.num_computed_tokens, computed_tokens)

        # Page table
        self.page_table.add_row(request.page_ids, req_index)
        self.page_table.commit(self.num_reqs)

        # Sampling params
        self._process_sampling_params(sampling_params, req_id, req_index)
        self._process_optional_params(request, sampling_params, req_id, req_index)
        self._update_request_distribution()

    def remove_request(self, req_id: str) -> int | None:
        """Remove a request from the buffer.

        Drops the request from all ID- and slot-keyed bookkeeping. The
        per-slot NumPy arrays are not reset here; they are overwritten when
        the slot is reused.

        Args:
            req_id: The request ID to remove.

        Returns:
            The index where the request was removed, or None if not found.

        Note:
            Should typically be followed by condense() to remove gaps
            in the buffer.
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        for req_set in self._request_sets():
            req_set.discard(req_id)

        self.min_tokens.pop(req_index, None)
        self.generator_seeds.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
        self.bad_words_token_ids.pop(req_index, None)

        # Guarded indexing
        self._ensure_logit_bias_capacity(req_index)
        self.logit_bias[req_index] = None

        if self.allowed_token_ids_mask is not None:
            self.allowed_token_ids_mask = self.allowed_token_ids_mask.at[req_index].set(False)

        self._update_request_distribution()
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        """Swap the states of two requests at given indices.

        Exchanges all data (tokens, parameters, metadata) between two
        request positions in the buffer.

        Args:
            i1: Index of the first request.
            i2: Index of the second request.

        Raises:
            AssertionError: If either index doesn't contain a valid request.
        """
        id_at_i1, id_at_i2 = self._req_ids[i1], self._req_ids[i2]
        # Validate before touching any state so a failed swap changes nothing.
        assert id_at_i1 is not None and id_at_i2 is not None

        self._req_ids[i1], self._req_ids[i2] = id_at_i2, id_at_i1
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.req_id_to_index[id_at_i1] = i2
        self.req_id_to_index[id_at_i2] = i1

        # Per-slot arrays (each is rebound to a swapped copy).
        for name in self._ROW_ARRAY_NAMES:
            setattr(self, name, swap_rows(getattr(self, name), i1, i2))
        if self.allowed_token_ids_mask is not None:
            self.allowed_token_ids_mask = swap_rows(self.allowed_token_ids_mask, i1, i2)

        # Slot-keyed sparse data.
        swap_dict_values(self.generator_seeds, i1, i2)
        swap_dict_values(self.min_tokens, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)
        self.logit_bias[i1], self.logit_bias[i2] = self.logit_bias[i2], self.logit_bias[i1]

        # Page table
        self.page_table.swap_row(i1, i2)
        self.page_table.commit(self.num_reqs)
        self._update_request_distribution()

    def condense(self, empty_req_indices: list[int]) -> None:
        """Condense the buffer by removing gaps.

        Moves requests from the end of the buffer to fill empty slots,
        maintaining a contiguous block of active requests at the beginning.

        Args:
            empty_req_indices: List of indices that are now empty and need filling.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        empty_slots = set(empty_req_indices)  # O(1) membership inside the loop
        last_req_index = num_reqs + len(empty_req_indices) - 1
        for empty_index in reversed(empty_req_indices):
            # Skip trailing slots that are themselves empty.
            while last_req_index in empty_slots and last_req_index > empty_index:
                last_req_index -= 1
            if empty_index >= last_req_index:
                # Nothing to the right of this gap to pull in.
                continue
            self._move_request(last_req_index, empty_index)
            last_req_index -= 1

        # Trim the now-unused tail of the slot-ordered lists.
        del self._req_ids[self.num_reqs :]
        del self.req_output_token_ids[self.num_reqs :]
        self._update_request_distribution()

    def _move_request(self, from_idx: int, to_idx: int) -> None:
        """Move a request from one index to another.

        Args:
            from_idx: Source index of the request.
            to_idx: Destination index for the request.

        Raises:
            AssertionError: If from_idx doesn't contain a valid request.

        Note:
            Internal helper used by condense(). The source slot's array rows
            are left as they were; only its bookkeeping is cleared.
        """
        req_id = self._req_ids[from_idx]
        assert req_id is not None

        # Static bookkeeping
        self._req_ids[to_idx] = req_id
        self._req_ids[from_idx] = None
        self.req_output_token_ids[to_idx] = self.req_output_token_ids[from_idx]
        self.req_output_token_ids[from_idx] = None
        self.req_id_to_index[req_id] = to_idx

        # Per-slot arrays (each is rebound to an updated copy).
        for name in self._ROW_ARRAY_NAMES:
            setattr(self, name, move_row(getattr(self, name), from_idx, to_idx))

        # Page table
        self.page_table.move_row(from_idx, to_idx)
        self.page_table.commit(self.num_reqs)

        # Sparse/optional data
        self._move_sparse_data(from_idx, to_idx)

    def _move_sparse_data(self, from_idx: int, to_idx: int) -> None:
        """Move sparse and optional data between indices.

        Handles slot-keyed data that may not exist for every request:
        generator seeds, min_tokens, bad words, logit bias and the
        allowed-token mask row.

        Args:
            from_idx: Source index.
            to_idx: Destination index.
        """
        for slot_dict in (self.generator_seeds, self.min_tokens, self.bad_words_token_ids):
            if from_idx in slot_dict:
                slot_dict[to_idx] = slot_dict.pop(from_idx)

        self.logit_bias[to_idx] = self.logit_bias[from_idx]
        self.logit_bias[from_idx] = None

        mask = self.allowed_token_ids_mask
        if mask is not None:
            mask = mask.at[to_idx].set(mask[from_idx])
            # Clear the vacated row so it no longer disallows any token.
            self.allowed_token_ids_mask = mask.at[from_idx].set(False)

    def _process_sampling_params(self, sampling_params: SamplingParams, req_id: str, req_index: int) -> None:
        """Process and store core sampling parameters.

        Writes temperature, top_p, top_k, min_p and the three penalties for
        the slot, and records the request ID in the sets that track which
        sampling features are in use.

        Args:
            sampling_params: Sampling configuration containing temperature, top_p, etc.
            req_id: Request identifier for bookkeeping.
            req_index: Index where parameters should be stored.

        Note:
            Every per-slot value is written unconditionally. Slots are reused
            without being reset on removal, so skipping a write for a neutral
            value would leave the previous occupant's setting in place.
        """
        # Copy-then-rebind, matching the rest of the class.
        temperature = self.temperature.copy()
        top_p = self.top_p.copy()
        top_k = self.top_k.copy()
        min_p = self.min_p.copy()
        frequency_penalties = self.frequency_penalties.copy()
        presence_penalties = self.presence_penalties.copy()
        repetition_penalties = self.repetition_penalties.copy()

        if sampling_params.sampling_type == SamplingType.GREEDY:
            # -1.0 is the sentinel this buffer uses for greedy rows.
            temperature[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            temperature[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        top_p[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)

        requested_top_k = sampling_params.top_k
        if 0 < requested_top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
            top_k[req_index] = requested_top_k
        else:
            # Out-of-range values mean "no top-k restriction".
            top_k[req_index] = self.vocab_size

        min_p[req_index] = sampling_params.min_p
        if sampling_params.min_p > 1e-5:
            self.min_p_reqs.add(req_id)

        frequency_penalties[req_index] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)

        presence_penalties[req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)

        repetition_penalties[req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.frequency_penalties = frequency_penalties
        self.presence_penalties = presence_penalties
        self.repetition_penalties = repetition_penalties

    def _process_optional_params(
        self,
        request: EngineRequest,
        sampling_params: SamplingParams,
        req_id: str,
        req_index: int,
    ) -> None:
        """Process optional and sparse sampling parameters.

        Handles parameters that may not be present for all requests:
        min_tokens, generator seed, logprob counts, logit bias, allowed
        token IDs and bad words.

        Args:
            request: The engine request containing optional metadata.
            sampling_params: Sampling parameters with optional fields.
            req_id: Request identifier.
            req_index: Index for parameter storage.
        """
        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens, sampling_params.all_stop_token_ids)

        if hasattr(request, "generator_seed") and request.generator_seed is not None:
            self.generator_seeds[req_index] = request.generator_seed

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs

        if sampling_params.logit_bias is not None:
            self._ensure_logit_bias_capacity(req_index)
            self.logit_bias[req_index] = sampling_params.logit_bias

        if sampling_params.allowed_token_ids:
            self._set_allowed_token_ids(req_id, req_index, sampling_params.allowed_token_ids)

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[req_index] = sampling_params.bad_words_token_ids

    def _set_allowed_token_ids(self, req_id: str, req_index: int, allowed_token_ids: list[int]) -> None:
        """Set the allowed token IDs for a request.

        Creates the mask on first use, then marks every token in the
        request's row as disallowed except the given IDs.

        Args:
            req_id: Request identifier.
            req_index: Index of the request.
            allowed_token_ids: List of token IDs that are allowed.

        Raises:
            ValueError: If any token ID is outside the valid vocabulary range.

        Note:
            The mask uses inverted logic: True means disallowed.
        """
        if any((t < 0 or t >= self.vocab_size) for t in allowed_token_ids):
            raise ValueError(f"allowed_token_ids must be within [0, {self.vocab_size})")

        self.has_allowed_token_ids.add(req_id)
        mask = self.allowed_token_ids_mask
        if mask is None:
            # All-False: rows without restrictions disallow nothing.
            mask = jnp.zeros((self.max_num_reqs, self.vocab_size), dtype=bool)

        # Disallow the whole row, then re-enable the allowed IDs.
        mask = mask.at[req_index].set(True)
        if allowed_token_ids:
            allowed = jnp.asarray(allowed_token_ids, dtype=jnp.int32)
            mask = mask.at[req_index, allowed].set(False)
        self.allowed_token_ids_mask = mask

    def _allocate_index(self, req_index: int | None) -> int:
        """Allocate an index for a new request.

        Args:
            req_index: Optional preferred index. If None, the first free slot
                is used, or a new slot at the end if there is room.

        Returns:
            The allocated index.

        Raises:
            IndexError: If req_index is negative or exceeds maximum capacity.
            ValueError: If req_index is already occupied.
            RuntimeError: If buffer is full and no index is available.

        Note:
            When an explicit index lies beyond the current end, the
            bookkeeping lists are padded with None up to (not including) it.
        """
        if req_index is not None:
            if req_index < 0:
                # A negative index would silently address a slot from the end.
                raise IndexError(f"req_index {req_index} must be non-negative")
            if req_index >= self.max_num_reqs:
                raise IndexError(f"req_index {req_index} >= max_num_reqs {self.max_num_reqs}")
            if req_index < len(self._req_ids) and self._req_ids[req_index] is not None:
                raise ValueError(f"req_index {req_index} is already occupied by {self._req_ids[req_index]}")
            while len(self._req_ids) < req_index:
                self._req_ids.append(None)
                self.req_output_token_ids.append(None)
            return req_index

        for slot, occupant in enumerate(self._req_ids):
            if occupant is None:
                return slot
        if len(self._req_ids) < self.max_num_reqs:
            return len(self._req_ids)
        raise RuntimeError("SequenceBuffer is full; cannot allocate a new request index.")

    def _make_prompt_token_ids_tensor(self) -> jax.Array:
        """Create a padded tensor of prompt token IDs.

        Returns:
            A padded array of prompt tokens with shape
            [num_reqs, max_prompt_len], or an empty (0, 0) array when the
            buffer holds no requests.

        Note:
            Padding positions are filled with ``vocab_size``, which is the
            value passed as ``pad_id`` to pack_prompts.
        """
        if self.num_reqs == 0:
            return jnp.empty((0, 0), dtype=jnp.int32)

        max_prompt_len = int(np.max(self.num_prompt_tokens[: self.num_reqs]))
        # Convert to JAX arrays for the compiled pack_prompts helper.
        return pack_prompts(
            jnp.asarray(self.token_ids),
            jnp.asarray(self.num_prompt_tokens),
            self.num_reqs,
            max_prompt_len,
            self.vocab_size,
        )

    def get_request_indices_with_penalty(self) -> jax.Array:
        """Get indices of requests with penalties.

        Returns:
            int32 array of slot indices, in ascending order, for requests
            that have a frequency, presence or repetition penalty.
        """
        penalty_req_ids = self.frequency_penalties_reqs | self.presence_penalties_reqs | self.repetition_penalties_reqs
        if not penalty_req_ids:
            return jnp.array([], dtype=jnp.int32)
        # Sorted so the result does not depend on set iteration order.
        indices = sorted(self.req_id_to_index[req_id] for req_id in penalty_req_ids)
        return jnp.array(indices, dtype=jnp.int32)

    def get_active_sampling_params(self, req_index: int) -> dict[str, Any]:
        """Get active sampling parameters for a request.

        Args:
            req_index: Index of the request.

        Returns:
            Dictionary with temperature, top_p and top_k, plus min_p and the
            penalties only when the request uses them. Empty dict if the
            slot holds no request.
        """
        req_id = self._req_ids[req_index]
        if req_id is None:
            return {}

        params = {
            "temperature": self.temperature[req_index],
            "top_p": self.top_p[req_index],
            "top_k": self.top_k[req_index],
        }
        optional_fields = (
            ("min_p", self.min_p_reqs, self.min_p),
            ("frequency_penalty", self.frequency_penalties_reqs, self.frequency_penalties),
            ("presence_penalty", self.presence_penalties_reqs, self.presence_penalties),
            ("repetition_penalty", self.repetition_penalties_reqs, self.repetition_penalties),
        )
        for key, users, values in optional_fields:
            if req_id in users:
                params[key] = values[req_index]
        return params

    def clear(self) -> None:
        """Clear all data in the buffer.

        Resets all arrays to their initial values and clears all bookkeeping,
        keeping the buffer's capacity.
        """
        self._req_ids.clear()
        self.req_id_to_index.clear()
        self.req_output_token_ids.clear()

        self.token_ids = np.zeros_like(self.token_ids)
        self.num_tokens = np.zeros_like(self.num_tokens)
        self.num_tokens_no_spec = np.zeros_like(self.num_tokens_no_spec)
        self.num_prompt_tokens = np.zeros_like(self.num_prompt_tokens)
        self.num_computed_tokens = np.zeros_like(self.num_computed_tokens)
        self.temperature = np.full_like(self.temperature, -1.0)
        self.top_p = np.ones_like(self.top_p)
        self.top_k = np.full_like(self.top_k, self.vocab_size)
        self.min_p = np.zeros_like(self.min_p)
        self.frequency_penalties = np.zeros_like(self.frequency_penalties)
        self.presence_penalties = np.zeros_like(self.presence_penalties)
        self.repetition_penalties = np.ones_like(self.repetition_penalties)

        for req_set in self._request_sets():
            req_set.clear()

        self.min_tokens.clear()
        self.generator_seeds.clear()
        self.num_logprobs.clear()
        self.num_prompt_logprobs.clear()
        self.in_progress_prompt_logprobs_cpu.clear()
        self.bad_words_token_ids.clear()
        self.logit_bias = [None] * self.max_num_reqs

        # Page table (no commit is issued for a clear)
        self.page_table.clear()

        # Clear allowed token IDs mask
        if self.allowed_token_ids_mask is not None:
            self.allowed_token_ids_mask = jnp.zeros_like(self.allowed_token_ids_mask, dtype=bool)

        self._update_request_distribution()
@ejit(static_argnums=(2, 3))
def fill_slice(arr, fill_val, num_reqs, padded_num_reqs):
    """Overwrite the padding region of ``arr`` and trim it to the padded size.

    Args:
        arr: Input array to pad.
        fill_val: Value written into the padding region.
        num_reqs: Number of valid leading entries (static under JIT).
        padded_num_reqs: Target padded size (static under JIT).

    Returns:
        The first ``padded_num_reqs`` entries of ``arr``, with entries from
        ``num_reqs`` up to ``padded_num_reqs`` replaced by ``fill_val``.
    """
    with_padding = arr.at[num_reqs:padded_num_reqs].set(fill_val)
    return with_padding[:padded_num_reqs]