easydel.inference.esurge.runners.sequence_buffer
Sequence buffer for managing token sequences during generation.
Provides efficient storage and management of token sequences, page tables, and generation metadata for batch processing.
- Classes:
SequenceBuffer: Main buffer for managing sequences
DecodeRowInfo: Information about sequences in decode phase
ModelRunBatch: Batch data for model execution
Example
>>> buffer = SequenceBuffer(
... max_num_reqs=32,
... max_model_len=2048,
... max_num_batched_tokens=4096,
... vocab_size=50000,
... page_sizes=[16, 32]
... )
>>> buffer.begin_sequence("req_1", [1, 2, 3])
>>> batch = buffer.prepare_batch()
- class easydel.inference.esurge.runners.sequence_buffer.ModelRunnerSamplingMetadata(temperature: jax.Array, min_p: jax.Array, top_k: jax.Array, top_p: jax.Array, all_greedy: bool = True, logprobs: bool = False, no_penalties: bool = True, prompt_token_ids: typing.Any = None, frequency_penalties: typing.Any = None, presence_penalties: typing.Any = None, repetition_penalties: typing.Any = None, output_token_ids: list[list[int]] = <factory>, min_tokens: typing.Any = None, logit_bias: list[dict[int, float]] = <factory>, allowed_token_ids_mask: typing.Any = None, bad_words_token_ids: typing.Any = None)
Bases: object
Metadata for sampling operations during model execution.
Contains sampling parameters and optional penalty/constraint data for batch processing during inference.
- all_greedy (bool) – Whether all requests use greedy sampling.
- logprobs (bool) – Whether to compute log probabilities.
- no_penalties (bool) – Whether penalties are disabled.
- prompt_token_ids (Any) – Optional prompt tokens for context.
- frequency_penalties (Any) – Optional frequency penalties.
- presence_penalties (Any) – Optional presence penalties.
- repetition_penalties (Any) – Optional repetition penalties.
- output_token_ids (list[list[int]]) – Generated output tokens.
- min_tokens (Any) – Minimum tokens to generate.
- logit_bias (list[dict[int, float]]) – Per-token logit adjustments.
- allowed_token_ids_mask (Any) – Mask for allowed tokens.
- bad_words_token_ids (Any) – Tokens to avoid generating.
- all_greedy: bool = True
- allowed_token_ids_mask: Any = None
- bad_words_token_ids: Any = None
- frequency_penalties: Any = None
- classmethod from_dict(data: dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- classmethod from_sequence_buffer(sequence_buffer: SequenceBuffer, padded_num_reqs: int, generate_params_if_all_greedy: bool = False)
Create sampling metadata from a sequence buffer.
- Parameters
sequence_buffer – Source buffer containing sampling parameters.
padded_num_reqs – Target padded number of requests.
generate_params_if_all_greedy – Whether to generate parameters even when all requests use greedy sampling.
- Returns
ModelRunnerSamplingMetadata with padded sampling arrays.
Note
If all requests use greedy sampling and generate_params_if_all_greedy is False, returns zero-filled arrays for efficiency.
- logit_bias: list[dict[int, float]]
- logprobs: bool = False
- min_tokens: Any = None
- no_penalties: bool = True
- output_token_ids: list[list[int]]
- presence_penalties: Any = None
- prompt_token_ids: Any = None
- repetition_penalties: Any = None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() -> dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
- class easydel.inference.esurge.runners.sequence_buffer.SequenceBuffer(max_num_reqs: int, max_model_len: int, max_num_batched_tokens: int, vocab_size: int, page_sizes: list[int], sharding: jaxlib._jax.Sharding | None = None)
Bases: object
Buffer for managing token sequences during generation.
- Mutable, class-based design:
All arrays are mutable instance attributes
Methods modify state in-place (return None)
NumPy arrays stay on CPU for fast metadata operations
PageTable manages device-side KV cache allocations
Simplified mental model: direct state mutations
Use SequenceBuffer() constructor to create instances.
- add_request(request: EngineRequest, req_index: int | None = None) -> None
Add a new request to the buffer.
Adds a request with its tokens, sampling parameters, and metadata. Truncates the prompt if it exceeds the maximum length.
- Parameters
request – The engine request to add, containing:
  - req_id: Unique request identifier
  - prompt_token_ids: Input prompt tokens
  - sampling_params: Sampling configuration
  - page_ids: Page allocation for KV cache
req_index – Optional index at which to place the request. If None, the next available slot is used.
- Raises
ValueError – If the request ID already exists in the buffer.
IndexError – If req_index is out of bounds.
RuntimeError – If the buffer is full.
Note
This method modifies the buffer in-place.
- property all_greedy: bool
- property all_random: bool
- clear() -> None
Clear all data in the buffer.
Resets all arrays to their initial values and clears all bookkeeping.
Note
This maintains the buffer structure and capacity but removes all request data. Modifies the buffer in-place.
- condense(empty_req_indices: list[int]) -> None
Condense the buffer by removing gaps.
Moves requests from the end of the buffer to fill empty slots, maintaining a contiguous block of active requests at the beginning.
- Parameters
empty_req_indices – List of indices that are now empty and need filling.
Note
This operation is important for maintaining buffer efficiency after removing requests. It ensures active requests are packed at the beginning of the buffer. Modifies the buffer in-place.
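The compaction described above can be illustrated with a toy model. This is only a sketch: it assumes the buffer is a plain Python list in which None marks an empty slot, whereas the real buffer keeps several parallel arrays. The helper name condense_slots is hypothetical and not part of EasyDeL.

```python
def condense_slots(slots, empty_indices):
    """Toy compaction: fill empty slots with requests taken from the end.

    `slots` is a list where None marks an empty slot (an assumption made
    for illustration only).
    """
    slots = list(slots)  # work on a copy so the caller's list is untouched
    last = len(slots) - 1
    # Fill the lowest empty indices first.
    for empty in sorted(empty_indices):
        # Skip trailing empty slots to find the last active request.
        while last >= 0 and slots[last] is None:
            last -= 1
        # Stop once every remaining active request already sits before the gap.
        if last <= empty:
            break
        slots[empty] = slots[last]
        slots[last] = None
        last -= 1
    return slots


packed = condense_slots(["a", None, "c", None, "e"], [1, 3])
print(packed)
```

In this sketch the request "e" moves from the end into the first gap, after which all active entries sit at the front of the list.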
- get_active_sampling_params(req_index: int) -> dict[str, Any]
Get active sampling parameters for a request.
- Parameters
req_index – Index of the request.
- Returns
Dictionary containing active sampling parameters for the request. Only includes parameters that are actually in use.
Note
Returns empty dict if the index doesn’t contain a valid request.
- get_request_indices_with_penalty() -> Array
Get indices of requests with penalties.
- Returns
Array of indices for requests that have frequency, presence, or repetition penalties applied.
Note
Used to optimize penalty application by only processing requests that actually need it.
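One way to picture this selection is a NumPy filter over per-request penalty arrays. The sketch below assumes neutral values of 0.0 for frequency and presence penalties and 1.0 for repetition penalties. Those defaults and the function name indices_with_penalty are illustrative assumptions, not taken from the library.

```python
import numpy as np


def indices_with_penalty(frequency, presence, repetition):
    """Return indices of rows whose penalties differ from the neutral values.

    Neutral values (0.0, 0.0, 1.0) are assumed here for illustration.
    """
    active = (frequency != 0.0) | (presence != 0.0) | (repetition != 1.0)
    return np.flatnonzero(active)


idx = indices_with_penalty(
    np.array([0.0, 0.5, 0.0, 0.0], dtype=np.float32),
    np.array([0.0, 0.0, 0.0, 0.2], dtype=np.float32),
    np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
)
print(idx)
```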
- property no_allowed_token_ids: bool
- property no_min_p: bool
- property no_penalties: bool
- property no_prompt_logprob: bool
- property no_top_k: bool
- property no_top_p: bool
- property num_reqs: int
- remove_request(req_id: str) -> int | None
Remove a request from the buffer.
Removes all data associated with a request ID and cleans up related bookkeeping structures.
- Parameters
req_id – The request ID to remove.
- Returns
The index where the request was removed, or None if not found.
Note
This method modifies the buffer in-place. Should typically be followed by condense() to remove gaps in the buffer and maintain efficiency.
- property req_ids: list[str]
- swap_states(i1: int, i2: int) -> None
Swap the states of two requests at given indices.
Exchanges all data (tokens, parameters, metadata) between two request positions in the buffer.
- Parameters
i1 – Index of the first request.
i2 – Index of the second request.
- Raises
AssertionError – If either index doesn’t contain a valid request.
Note
This method modifies the buffer in-place. Useful for buffer reorganization and optimization.
- easydel.inference.esurge.runners.sequence_buffer.build_allowed_mask(allowed_ids_padded, allowed_lens, vocab_size, max_allowed)
Build a mask for allowed token IDs.
Creates a boolean mask indicating which tokens are allowed for each request. The mask uses inverted logic where True means disallowed and False means allowed.
- Parameters
allowed_ids_padded – Padded array of allowed token IDs [B, max_allowed].
allowed_lens – Number of valid allowed IDs per request [B].
vocab_size – Total vocabulary size.
max_allowed – Maximum number of allowed tokens per request.
- Returns
Boolean mask of shape [B, vocab_size] where True indicates the token is disallowed and False indicates it’s allowed.
Note
The inverted logic (True=disallowed) is used for compatibility with masking operations that zero out disallowed values.
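The inverted mask can be sketched in NumPy as follows. This is an illustration of the documented shapes and semantics, not the library's implementation. The name allowed_mask_sketch is hypothetical, and max_allowed is read from the input shape rather than passed in.

```python
import numpy as np


def allowed_mask_sketch(allowed_ids_padded, allowed_lens, vocab_size):
    """Build a [B, vocab_size] mask where True means 'disallowed'."""
    batch, max_allowed = allowed_ids_padded.shape
    # Start with every token disallowed (True).
    mask = np.ones((batch, vocab_size), dtype=bool)
    # Only the first allowed_lens[b] entries of row b are real IDs.
    valid = np.arange(max_allowed)[None, :] < allowed_lens[:, None]
    rows, cols = np.nonzero(valid)
    # Clear the mask (False = allowed) at each valid token ID.
    mask[rows, allowed_ids_padded[rows, cols]] = False
    return mask


ids = np.array([[1, 3, 0], [2, 0, 0]])  # trailing zeros are padding
lens = np.array([2, 1])
mask = allowed_mask_sketch(ids, lens, vocab_size=5)
print(mask)
```

A request with no allowed IDs ends up fully masked in this sketch. How the library treats that case is not stated here.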
- easydel.inference.esurge.runners.sequence_buffer.build_sampling_arrays(temperature, min_p, top_p, top_k, num_reqs, padded_num_reqs)
Build padded sampling parameter arrays.
Pads sampling parameters to a consistent size for batch processing, filling unused slots with default values.
- Parameters
temperature – Temperature values for sampling.
min_p – Minimum probability threshold values.
top_p – Top-p (nucleus) sampling values.
top_k – Top-k sampling values.
num_reqs – Actual number of requests.
padded_num_reqs – Target padded number of requests.
- Returns
A tuple of padded arrays:
temperature: Padded with -1.0 (float32)
min_p: Padded with 0.0 (float32)
top_p: Padded with 1.0 (float32)
top_k: Padded with 0 (int32)
Note
Default padding values are chosen to be neutral for sampling operations.
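A minimal NumPy sketch of this padding, using the fill values and dtypes listed in the Returns section. The function name pad_sampling_arrays is hypothetical, and the real function may operate on JAX arrays instead.

```python
import numpy as np


def pad_sampling_arrays(temperature, min_p, top_p, top_k, num_reqs, padded_num_reqs):
    """Pad each sampling array to padded_num_reqs with its documented default."""

    def pad(values, fill, dtype):
        # Allocate the padded array pre-filled with the default value,
        # then copy the real per-request values into the front.
        out = np.full((padded_num_reqs,), fill, dtype=dtype)
        out[:num_reqs] = np.asarray(values, dtype=dtype)[:num_reqs]
        return out

    return (
        pad(temperature, -1.0, np.float32),
        pad(min_p, 0.0, np.float32),
        pad(top_p, 1.0, np.float32),
        pad(top_k, 0, np.int32),
    )


temp, min_p, top_p, top_k = pad_sampling_arrays(
    [0.7, 1.0], [0.0, 0.05], [0.9, 1.0], [40, 0], num_reqs=2, padded_num_reqs=4
)
print(temp, min_p, top_p, top_k)
```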
- easydel.inference.esurge.runners.sequence_buffer.fill_slice(arr, fill_val, num_reqs, padded_num_reqs)
Fill array slice with padding value.
- Parameters
arr – Input array to pad.
fill_val – Value to use for padding.
num_reqs – Number of valid requests.
padded_num_reqs – Target padded size.
- Returns
Array with padding applied from num_reqs to padded_num_reqs.
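The described behavior amounts to overwriting one slice. The sketch below assumes arr already has at least padded_num_reqs entries along its first axis. The name fill_slice_sketch is hypothetical.

```python
import numpy as np


def fill_slice_sketch(arr, fill_val, num_reqs, padded_num_reqs):
    """Return a copy of arr with rows [num_reqs, padded_num_reqs) set to fill_val."""
    out = np.array(arr, copy=True)
    out[num_reqs:padded_num_reqs] = fill_val
    return out


padded = fill_slice_sketch(np.array([5, 6, 7, 8, 9]), 0, num_reqs=2, padded_num_reqs=4)
print(padded)
```

Entries before num_reqs and from padded_num_reqs onward are left as they were.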
- easydel.inference.esurge.runners.sequence_buffer.move_row(arr, from_idx, to_idx)
Move a row from one index to another.
- Parameters
arr – Input array.
from_idx – Source row index.
to_idx – Destination row index.
- Returns
Array with row moved from from_idx to to_idx.
Note
Works for both NumPy ndarrays and JAX arrays. Always returns a new array.
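A possible reading of "move" is that the source row overwrites the destination row while the source itself stays in place. That interpretation is an assumption, as is the name move_row_sketch. The NumPy version below follows it and returns a new array.

```python
import numpy as np


def move_row_sketch(arr, from_idx, to_idx):
    """Return a copy of arr whose row to_idx holds the contents of row from_idx."""
    out = arr.copy()
    out[to_idx] = arr[from_idx]
    # For a JAX array the equivalent functional update would be
    # arr.at[to_idx].set(arr[from_idx]).
    return out


a = np.arange(6).reshape(3, 2)
moved = move_row_sketch(a, from_idx=2, to_idx=0)
print(moved)
```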
- easydel.inference.esurge.runners.sequence_buffer.pack_prompts(token_ids, num_prompt_tokens, padded_num_reqs, padded_prompt_len, pad_id)
Pack prompt tokens into a padded tensor.
Creates a padded tensor of prompt tokens with consistent shape for batch processing. Tokens beyond the prompt length are replaced with the padding ID.
- Parameters
token_ids – Token IDs array of shape [max_num_reqs, max_model_len].
num_prompt_tokens – Number of prompt tokens per request [max_num_reqs].
padded_num_reqs – Target number of requests after padding.
padded_prompt_len – Target prompt length after padding.
pad_id – Token ID to use for padding.
- Returns
Packed prompts array of shape [padded_num_reqs, padded_prompt_len] with valid tokens and padding.
Note
This function is JIT-compiled with static arguments for padded dimensions to enable efficient compilation caching.
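The packing step can be approximated in plain NumPy as a slice followed by a length mask. This is a stand-in for the JIT-compiled version, written under the assumption that token_ids has at least padded_num_reqs rows and padded_prompt_len columns. The name pack_prompts_sketch is hypothetical.

```python
import numpy as np


def pack_prompts_sketch(token_ids, num_prompt_tokens, padded_num_reqs, padded_prompt_len, pad_id):
    """Return a [padded_num_reqs, padded_prompt_len] array of prompt tokens."""
    # Take the leading block with the requested shape.
    block = token_ids[:padded_num_reqs, :padded_prompt_len]
    lens = num_prompt_tokens[:padded_num_reqs]
    # Position j of row i is a real prompt token only if j < lens[i].
    keep = np.arange(padded_prompt_len)[None, :] < lens[:, None]
    return np.where(keep, block, pad_id)


token_ids = np.array([[11, 12, 13, 14], [21, 22, 99, 99], [0, 0, 0, 0]])
lens = np.array([3, 2, 0])
packed = pack_prompts_sketch(token_ids, lens, padded_num_reqs=2, padded_prompt_len=3, pad_id=-1)
print(packed)
```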
- easydel.inference.esurge.runners.sequence_buffer.swap_rows(arr, i1, i2)
Swap two rows in an array.
- Parameters
arr – Input array to swap rows in.
i1 – Index of first row.
i2 – Index of second row.
- Returns
Array with rows i1 and i2 swapped.
Note
Works for both NumPy ndarrays and JAX arrays. Always returns a new array.
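For the NumPy case, a non-mutating swap can be sketched with a copy and one indexed assignment. The name swap_rows_sketch is hypothetical, and the JAX branch is not shown.

```python
import numpy as np


def swap_rows_sketch(arr, i1, i2):
    """Return a copy of arr with rows i1 and i2 exchanged."""
    out = arr.copy()
    # The right-hand side reads from the original, so both rows are
    # captured before either is overwritten.
    out[[i1, i2]] = arr[[i2, i1]]
    return out


a = np.arange(6).reshape(3, 2)
swapped = swap_rows_sketch(a, 0, 2)
print(swapped)
```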
- easydel.inference.esurge.runners.sequence_buffer.swap_rows_pytree(arrs, i1, i2)
Swap rows across all arrays in a pytree.
- Parameters
arrs – PyTree containing arrays.
i1 – Index of first row to swap.
i2 – Index of second row to swap.
- Returns
PyTree with same structure but rows swapped in all arrays.
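The idea of applying the same row swap to every leaf can be sketched without JAX. The small recursive helper below stands in for a pytree traversal such as jax.tree_util.tree_map, so that the example needs only NumPy. It handles dicts, lists, and tuples only, and the name swap_rows_tree is hypothetical.

```python
import numpy as np


def swap_rows_tree(tree, i1, i2):
    """Swap rows i1 and i2 in every array leaf of a nested dict/list/tuple."""
    if isinstance(tree, dict):
        return {k: swap_rows_tree(v, i1, i2) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(swap_rows_tree(v, i1, i2) for v in tree)
    # Leaf: copy, then exchange the two rows.
    out = np.array(tree, copy=True)
    out[[i1, i2]] = out[[i2, i1]]
    return out


state = {
    "temperature": np.array([0.1, 0.2, 0.3]),
    "tokens": [np.array([[1, 1], [2, 2], [3, 3]])],
}
swapped = swap_rows_tree(state, 0, 2)
print(swapped)
```

The result has the same nesting as the input, and the input arrays are left unchanged.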