easydel.inference.esurge.runners.sequence_buffer

Sequence buffer for managing token sequences during generation.

Provides efficient storage and management of token sequences, page tables, and generation metadata for batch processing.

Classes:

  • SequenceBuffer: Main buffer for managing sequences

  • DecodeRowInfo: Information about sequences in decode phase

  • ModelRunBatch: Batch data for model execution

Example

>>> buffer = SequenceBuffer(
...     max_num_reqs=32,
...     max_model_len=2048,
...     max_num_batched_tokens=4096,
...     vocab_size=50000,
...     page_sizes=[16, 32]
... )
>>> buffer.begin_sequence("req_1", [1, 2, 3])
>>> batch = buffer.prepare_batch()
class easydel.inference.esurge.runners.sequence_buffer.ModelRunnerSamplingMetadata(temperature: jax.Array, min_p: jax.Array, top_k: jax.Array, top_p: jax.Array, all_greedy: bool = True, logprobs: bool = False, no_penalties: bool = True, prompt_token_ids: Any = None, frequency_penalties: Any = None, presence_penalties: Any = None, repetition_penalties: Any = None, output_token_ids: list[list[int]] = <factory>, min_tokens: Any = None, logit_bias: list[dict[int, float]] = <factory>, allowed_token_ids_mask: Any = None, bad_words_token_ids: Any = None)

Bases: object

Metadata for sampling operations during model execution.

Contains sampling parameters and optional penalty/constraint data for batch processing during inference.

Attributes

  • temperature (jax.Array) – Per-request sampling temperatures.

  • min_p (jax.Array) – Per-request minimum probability (min-p) thresholds.

  • top_k (jax.Array) – Per-request top-k values.

  • top_p (jax.Array) – Per-request top-p (nucleus) values.

  • all_greedy (bool) – Whether every request in the batch uses greedy sampling.

  • logprobs (bool) – Whether to compute log probabilities.

  • no_penalties (bool) – Whether penalties are disabled.

  • prompt_token_ids (Any) – Optional prompt tokens for context.

  • frequency_penalties (Any) – Optional frequency penalties.

  • presence_penalties (Any) – Optional presence penalties.

  • repetition_penalties (Any) – Optional repetition penalties.

  • output_token_ids (list[list[int]]) – Generated output tokens, one list per request.

  • min_tokens (Any) – Minimum number of tokens to generate.

  • logit_bias (list[dict[int, float]]) – Per-request logit adjustments, each mapping a token ID to a bias value.

  • allowed_token_ids_mask (Any) – Mask of allowed tokens.

  • bad_words_token_ids (Any) – Token IDs to avoid generating.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod from_sequence_buffer(sequence_buffer: SequenceBuffer, padded_num_reqs: int, generate_params_if_all_greedy: bool = False)

Create sampling metadata from a sequence buffer.

Parameters
  • sequence_buffer – Source buffer containing sampling parameters.

  • padded_num_reqs – Target padded number of requests.

  • generate_params_if_all_greedy – Whether to generate parameters even when all requests use greedy sampling.

Returns

ModelRunnerSamplingMetadata with padded sampling arrays.

Note

If all requests use greedy sampling and generate_params_if_all_greedy is False, returns zero-filled arrays for efficiency.
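The greedy shortcut described in this note can be illustrated with a minimal NumPy sketch. The function name greedy_aware_sampling_arrays_sketch is hypothetical, the real classmethod reads its inputs from the SequenceBuffer instead of taking them as arguments, and the padding values are assumed to match those documented for build_sampling_arrays on this page.

```python
import numpy as np


def greedy_aware_sampling_arrays_sketch(
    temperature, min_p, top_p, top_k, num_reqs, padded_num_reqs,
    all_greedy, generate_params_if_all_greedy=False,
):
    """Hypothetical sketch: return (temperature, min_p, top_p, top_k), each of length padded_num_reqs."""
    if all_greedy and not generate_params_if_all_greedy:
        # Assumed: greedy sampling does not need these values, so skip the copy and return zeros.
        zeros_f32 = np.zeros(padded_num_reqs, dtype=np.float32)
        return (zeros_f32, zeros_f32.copy(), zeros_f32.copy(),
                np.zeros(padded_num_reqs, dtype=np.int32))

    def pad(values, fill_val, dtype):
        # Copy the active requests, then fill the padded tail with a neutral value.
        out = np.full(padded_num_reqs, fill_val, dtype=dtype)
        out[:num_reqs] = np.asarray(values[:num_reqs], dtype=dtype)
        return out

    return (pad(temperature, -1.0, np.float32), pad(min_p, 0.0, np.float32),
            pad(top_p, 1.0, np.float32), pad(top_k, 0, np.int32))
```

Setting generate_params_if_all_greedy=True forces the padded path even for an all-greedy batch.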

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.inference.esurge.runners.sequence_buffer.SequenceBuffer(max_num_reqs: int, max_model_len: int, max_num_batched_tokens: int, vocab_size: int, page_sizes: list[int], sharding: jaxlib._jax.Sharding | None = None)

Bases: object

Buffer for managing token sequences during generation.

Mutable, class-based design:
  • All arrays are mutable instance attributes

  • Mutating methods update state in place; most return None

  • NumPy arrays stay on the CPU for fast metadata operations

  • A PageTable manages device-side KV cache allocations

  • State is mutated directly, which keeps the mental model simple

Create instances with the SequenceBuffer() constructor.

add_request(request: EngineRequest, req_index: int | None = None) → None

Add a new request to the buffer.

Adds a request with its tokens, sampling parameters, and metadata. Truncates the prompt if it exceeds the maximum length.

Parameters
  • request – The engine request to add, containing:

      - req_id: Unique request identifier

      - prompt_token_ids: Input prompt tokens

      - sampling_params: Sampling configuration

      - page_ids: Page allocation for the KV cache

  • req_index – Optional specific index to place the request. If None, finds the next available slot.

Raises
  • ValueError – If the request ID already exists in the buffer.

  • IndexError – If req_index is out of bounds.

  • RuntimeError – If the buffer is full.

Note

This method modifies the buffer in-place.
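A minimal sketch of how a slot might be chosen and validated, mirroring the Raises list above. It models the buffer's request slots as a plain Python list with None marking an empty position and assumes that "next available slot" means the lowest empty index; choose_slot_sketch is a hypothetical helper, not part of EasyDeL.

```python
def choose_slot_sketch(slots, req_id, req_index=None):
    """Return the slot index a new request would occupy (hypothetical helper)."""
    if req_id in slots:
        raise ValueError(f"request {req_id!r} is already in the buffer")
    if req_index is None:
        # Assumed policy: reuse the lowest empty slot.
        for i, occupant in enumerate(slots):
            if occupant is None:
                return i
        raise RuntimeError("buffer is full")
    if not 0 <= req_index < len(slots):
        raise IndexError(f"req_index {req_index} is out of range for {len(slots)} slots")
    return req_index


slots = ["req_a", None, "req_c", None]
idx = choose_slot_sketch(slots, "req_d")
slots[idx] = "req_d"  # the real method would also copy tokens and sampling parameters
```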

property all_greedy: bool
property all_random: bool
clear() → None

Clear all data in the buffer.

Resets all arrays to their initial values and clears all bookkeeping.

Note

This maintains the buffer structure and capacity but removes all request data. Modifies the buffer in-place.

condense(empty_req_indices: list[int]) → None

Condense the buffer by removing gaps.

Moves requests from the end of the buffer to fill empty slots, maintaining a contiguous block of active requests at the beginning.

Parameters

empty_req_indices – List of indices that are now empty and need filling.

Note

This operation is important for maintaining buffer efficiency after removing requests. It ensures active requests are packed at the beginning of the buffer. Modifies the buffer in-place.
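The compaction described above can be sketched on a simplified buffer made of a request-ID list (None for an empty slot) and a NumPy token array. condense_sketch is a hypothetical stand-in: the real method updates many more arrays and the page table, and resetting the vacated row to 0 is an assumption.

```python
import numpy as np


def condense_sketch(req_ids, token_ids, empty_req_indices):
    """Fill holes in place by moving the highest-indexed active rows into the lowest empty slots."""
    last = len(req_ids) - 1
    for hole in sorted(empty_req_indices):
        # Skip trailing empty slots to find the last active request.
        while last >= 0 and req_ids[last] is None:
            last -= 1
        if last <= hole:
            break  # every slot after `hole` is already empty
        req_ids[hole], req_ids[last] = req_ids[last], None
        token_ids[hole] = token_ids[last]  # row assignment copies the values
        token_ids[last] = 0                # assumed: clear the vacated row
        last -= 1


ids = ["a", None, "c", None, "e"]
toks = np.array([[1, 1], [0, 0], [3, 3], [0, 0], [5, 5]])
condense_sketch(ids, toks, [1, 3])
```

After the call, the active requests occupy slots 0 to 2 and the tail of the buffer is empty, which is the state remove_request followed by condense() is meant to restore.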

get_active_sampling_params(req_index: int) → dict[str, Any]

Get active sampling parameters for a request.

Parameters

req_index – Index of the request.

Returns

Dictionary containing active sampling parameters for the request. Only includes parameters that are actually in use.

Note

Returns an empty dict if the index does not contain a valid request.

get_request_indices_with_penalty() → Array

Get indices of requests with penalties.

Returns

Array of indices for requests that have frequency, presence, or repetition penalties applied.

Note

Used to optimize penalty application by only processing requests that actually need it.
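A minimal NumPy sketch of this selection, assuming per-request penalty arrays whose neutral values are 0.0 for frequency and presence penalties and 1.0 for repetition penalties (a common convention; the buffer's actual defaults are not stated on this page). request_indices_with_penalty_sketch is a hypothetical name.

```python
import numpy as np


def request_indices_with_penalty_sketch(frequency, presence, repetition, num_reqs):
    """Hypothetical sketch: indices of active requests with any non-neutral penalty."""
    freq = np.asarray(frequency[:num_reqs])
    pres = np.asarray(presence[:num_reqs])
    rep = np.asarray(repetition[:num_reqs])
    # Assumed neutral values: 0.0 (frequency, presence) and 1.0 (repetition).
    has_penalty = (freq != 0.0) | (pres != 0.0) | (rep != 1.0)
    return np.nonzero(has_penalty)[0]


idx = request_indices_with_penalty_sketch(
    [0.0, 0.5, 0.0, 0.0, 0.3],  # frequency; slot 4 is beyond num_reqs and ignored
    [0.0, 0.0, 0.0, 0.2, 0.0],  # presence
    [1.0, 1.0, 1.1, 1.0, 1.0],  # repetition
    num_reqs=4,
)
```

Only the returned rows then need to go through the penalty computation.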

property max_num_logprobs: int | None
property no_allowed_token_ids: bool
property no_min_p: bool
property no_penalties: bool
property no_prompt_logprob: bool
property no_top_k: bool
property no_top_p: bool
property num_reqs: int
remove_request(req_id: str) → int | None

Remove a request from the buffer.

Removes all data associated with a request ID and cleans up related bookkeeping structures.

Parameters

req_id – The request ID to remove.

Returns

The index where the request was removed, or None if not found.

Note

This method modifies the buffer in place. It should typically be followed by condense() to close the resulting gap and keep active requests packed at the beginning of the buffer.

property req_ids: list[str]
swap_states(i1: int, i2: int) → None

Swap the states of two requests at given indices.

Exchanges all data (tokens, parameters, metadata) between two request positions in the buffer.

Parameters
  • i1 – Index of the first request.

  • i2 – Index of the second request.

Raises

AssertionError – If either index doesn’t contain a valid request.

Note

This method modifies the buffer in place. It is useful when reorganizing the buffer.

easydel.inference.esurge.runners.sequence_buffer.build_allowed_mask(allowed_ids_padded, allowed_lens, vocab_size, max_allowed)

Build a mask for allowed token IDs.

Creates a boolean mask indicating which tokens are allowed for each request. The mask uses inverted logic where True means disallowed and False means allowed.

Parameters
  • allowed_ids_padded – Padded array of allowed token IDs [B, max_allowed].

  • allowed_lens – Number of valid allowed IDs per request [B].

  • vocab_size – Total vocabulary size.

  • max_allowed – Maximum number of allowed tokens per request.

Returns

Boolean mask of shape [B, vocab_size] where True indicates the token is disallowed and False indicates it’s allowed.

Note

The inverted logic (True=disallowed) is used for compatibility with masking operations that zero out disallowed values.
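A minimal NumPy sketch of the inverted mask, assuming the shapes listed above. Every token starts out disallowed (True), and only the first allowed_lens[b] entries of row b are cleared to False. Note that a row with allowed_lens equal to 0 comes out fully disallowed here; how the real function handles requests without an allow-list is not documented on this page. build_allowed_mask_sketch is a hypothetical name.

```python
import numpy as np


def build_allowed_mask_sketch(allowed_ids_padded, allowed_lens, vocab_size, max_allowed):
    """Return a [B, vocab_size] bool mask where True means the token is disallowed."""
    allowed_ids_padded = np.asarray(allowed_ids_padded)
    allowed_lens = np.asarray(allowed_lens)
    batch = allowed_ids_padded.shape[0]
    # valid[b, j] is True only for the first allowed_lens[b] entries of row b.
    valid = np.arange(max_allowed)[None, :] < allowed_lens[:, None]
    mask = np.ones((batch, vocab_size), dtype=bool)  # start with everything disallowed
    rows = np.broadcast_to(np.arange(batch)[:, None], (batch, max_allowed))
    # Clear the mask at (row, allowed_id) for every valid entry; padding is ignored.
    mask[rows[valid], allowed_ids_padded[valid]] = False
    return mask


mask = build_allowed_mask_sketch([[2, 4, 0], [1, 0, 0]], [2, 1], vocab_size=6, max_allowed=3)
```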

easydel.inference.esurge.runners.sequence_buffer.build_sampling_arrays(temperature, min_p, top_p, top_k, num_reqs, padded_num_reqs)

Build padded sampling parameter arrays.

Pads sampling parameters to a consistent size for batch processing, filling unused slots with default values.

Parameters
  • temperature – Temperature values for sampling.

  • min_p – Minimum probability threshold values.

  • top_p – Top-p (nucleus) sampling values.

  • top_k – Top-k sampling values.

  • num_reqs – Actual number of requests.

  • padded_num_reqs – Target padded number of requests.

Returns

A tuple of padded arrays:

  • temperature: padded with -1.0 (float32)

  • min_p: padded with 0.0 (float32)

  • top_p: padded with 1.0 (float32)

  • top_k: padded with 0 (int32)

Note

Default padding values are chosen to be neutral for sampling operations.
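A minimal NumPy sketch of the padding, using the fill values and dtypes listed under Returns. It assumes each input holds at least num_reqs entries (for example, a buffer sized to max_num_reqs); build_sampling_arrays_sketch is a hypothetical name, not the EasyDeL implementation.

```python
import numpy as np


def build_sampling_arrays_sketch(temperature, min_p, top_p, top_k, num_reqs, padded_num_reqs):
    """Return (temperature, min_p, top_p, top_k) padded to length padded_num_reqs."""

    def pad(values, fill_val, dtype):
        # Start from the neutral value, then copy in the active requests.
        out = np.full((padded_num_reqs,), fill_val, dtype=dtype)
        out[:num_reqs] = np.asarray(values[:num_reqs], dtype=dtype)
        return out

    return (
        pad(temperature, -1.0, np.float32),
        pad(min_p, 0.0, np.float32),
        pad(top_p, 1.0, np.float32),
        pad(top_k, 0, np.int32),
    )


t, mp, tp, tk = build_sampling_arrays_sketch(
    [0.8, 1.0], [0.05, 0.0], [0.9, 0.95], [40, 0], num_reqs=2, padded_num_reqs=4
)
```

Padding to a fixed padded_num_reqs keeps array shapes stable across steps, which limits recompilation under JIT.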

easydel.inference.esurge.runners.sequence_buffer.fill_slice(arr, fill_val, num_reqs, padded_num_reqs)

Fill array slice with padding value.

Parameters
  • arr – Input array to pad.

  • fill_val – Value to use for padding.

  • num_reqs – Number of valid requests.

  • padded_num_reqs – Target padded size.

Returns

Array in which positions num_reqs through padded_num_reqs - 1 hold fill_val.
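A minimal NumPy sketch of this fill. It returns a copy rather than writing into arr, in keeping with the functional style of the JAX helpers on this page; whether the real fill_slice copies is not stated. fill_slice_sketch is a hypothetical name.

```python
import numpy as np


def fill_slice_sketch(arr, fill_val, num_reqs, padded_num_reqs):
    """Return a copy of arr with entries num_reqs..padded_num_reqs-1 set to fill_val."""
    out = np.array(arr, copy=True)
    out[num_reqs:padded_num_reqs] = fill_val
    return out


temps = np.array([0.7, 0.9, 0.0, 0.0], dtype=np.float32)
padded = fill_slice_sketch(temps, -1.0, num_reqs=2, padded_num_reqs=4)
```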

easydel.inference.esurge.runners.sequence_buffer.move_row(arr, from_idx, to_idx)

Move a row from one index to another.

Parameters
  • arr – Input array.

  • from_idx – Source row index.

  • to_idx – Destination row index.

Returns

Array with row moved from from_idx to to_idx.

Note

Works for both NumPy ndarrays and JAX arrays. Always returns a new array.
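A minimal NumPy sketch, assuming "move" means copying the source row over the destination row and returning a new array. Whether the real helper also clears the source row is not documented here. In JAX, the equivalent functional update would use the .at[...].set(...) syntax. move_row_sketch is a hypothetical name.

```python
import numpy as np


def move_row_sketch(arr, from_idx, to_idx):
    """Return a copy of arr whose row to_idx holds the contents of row from_idx."""
    out = np.array(arr, copy=True)
    out[to_idx] = arr[from_idx]  # the source row is left unchanged (assumption)
    return out


a = np.array([[1, 1], [2, 2], [3, 3]])
b = move_row_sketch(a, from_idx=2, to_idx=0)
```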

easydel.inference.esurge.runners.sequence_buffer.pack_prompts(token_ids, num_prompt_tokens, padded_num_reqs, padded_prompt_len, pad_id)

Pack prompt tokens into a padded tensor.

Creates a padded tensor of prompt tokens with consistent shape for batch processing. Tokens beyond the prompt length are replaced with the padding ID.

Parameters
  • token_ids – Token IDs array of shape [max_num_reqs, max_model_len].

  • num_prompt_tokens – Number of prompt tokens per request [max_num_reqs].

  • padded_num_reqs – Target number of requests after padding.

  • padded_prompt_len – Target prompt length after padding.

  • pad_id – Token ID to use for padding.

Returns

Packed prompts array of shape [padded_num_reqs, padded_prompt_len] with valid tokens and padding.

Note

This function is JIT-compiled with static arguments for padded dimensions to enable efficient compilation caching.
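A plain NumPy sketch of the packing step, without the JIT compilation mentioned in the note (in JAX, the padded sizes would typically be marked static through jax.jit's static_argnames). It assumes padded_num_reqs and padded_prompt_len do not exceed the buffer dimensions, and that inactive rows have a prompt length of 0, so they come out as pure padding. pack_prompts_sketch is a hypothetical name.

```python
import numpy as np


def pack_prompts_sketch(token_ids, num_prompt_tokens, padded_num_reqs, padded_prompt_len, pad_id):
    """Return a [padded_num_reqs, padded_prompt_len] array of prompt tokens and pad_id."""
    token_ids = np.asarray(token_ids)
    lens = np.asarray(num_prompt_tokens)[:padded_num_reqs]
    # Crop the buffer to the padded window.
    window = token_ids[:padded_num_reqs, :padded_prompt_len]
    # Column j of row i is a prompt token only if j < num_prompt_tokens[i].
    is_prompt = np.arange(window.shape[1])[None, :] < lens[:, None]
    return np.where(is_prompt, window, pad_id).astype(token_ids.dtype)


tok = np.array([[5, 6, 7, 8], [9, 10, 0, 0], [0, 0, 0, 0]])
packed = pack_prompts_sketch(tok, [3, 2, 0], padded_num_reqs=3, padded_prompt_len=3, pad_id=-1)
```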

easydel.inference.esurge.runners.sequence_buffer.swap_rows(arr, i1, i2)

Swap two rows in an array.

Parameters
  • arr – Input array to swap rows in.

  • i1 – Index of first row.

  • i2 – Index of second row.

Returns

Array with rows i1 and i2 swapped.

Note

Works for both NumPy ndarrays and JAX arrays. Always returns a new array.
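A minimal NumPy sketch of a copy-returning row swap. The right-hand side uses fancy indexing, which produces a copy, so both rows are read before either is overwritten. In JAX, the same update would be written with the .at[...].set(...) syntax. swap_rows_sketch is a hypothetical name.

```python
import numpy as np


def swap_rows_sketch(arr, i1, i2):
    """Return a copy of arr with rows i1 and i2 exchanged."""
    out = np.array(arr, copy=True)
    out[[i1, i2]] = out[[i2, i1]]  # RHS is a copy, so the swap is safe
    return out


a = np.arange(6).reshape(3, 2)
s = swap_rows_sketch(a, 0, 2)
```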

easydel.inference.esurge.runners.sequence_buffer.swap_rows_pytree(arrs, i1, i2)

Swap rows across all arrays in a pytree.

Parameters
  • arrs – PyTree containing arrays.

  • i1 – Index of first row to swap.

  • i2 – Index of second row to swap.

Returns

PyTree with the same structure, with the rows swapped in every array.
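In JAX, the usual tool for this is jax.tree_util.tree_map. The sketch below avoids a JAX dependency by walking plain dicts, lists, and tuples by hand and swapping rows in each NumPy leaf. It does not handle namedtuples or custom pytree nodes. swap_rows_pytree_sketch and _swap_rows_copy are hypothetical names.

```python
import numpy as np


def _swap_rows_copy(arr, i1, i2):
    # Fancy indexing on the right-hand side makes a copy, so the swap is safe.
    out = np.array(arr, copy=True)
    out[[i1, i2]] = out[[i2, i1]]
    return out


def swap_rows_pytree_sketch(tree, i1, i2):
    """Recursively swap rows i1 and i2 in every array leaf of a dict/list/tuple tree."""
    if tree is None:
        return None  # treat None as an empty subtree, as JAX does
    if isinstance(tree, dict):
        return {key: swap_rows_pytree_sketch(val, i1, i2) for key, val in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Plain lists and tuples only; namedtuples would need positional construction.
        return type(tree)(swap_rows_pytree_sketch(val, i1, i2) for val in tree)
    return _swap_rows_copy(tree, i1, i2)


tree = {"tokens": np.array([[1, 2], [3, 4]]), "meta": (np.array([5, 6]), None)}
swapped = swap_rows_pytree_sketch(tree, 0, 1)
```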