easydel.inference.esurge.runners.execution_manager
Execution manager for high-performance model inference with fused step functions.
This module implements the ExecutionManager class, which handles compilation, caching, and execution of fused inference steps. The manager pre-compiles functions for multiple input configurations to eliminate runtime compilation overhead during serving.
- Architecture:
The manager uses a fused execution model where a single JIT-compiled function combines four sequential operations:
Input preparation: Token gathering and position calculation
Model forward pass: Transformer execution with paged attention
Token sampling: Stochastic sampling with temperature/top-k/top-p
State updates: Token buffer updates and sequence tracking
This fusion minimizes host-device communication (single dispatch per step) and maximizes kernel fusion opportunities within JAX/XLA.
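The token-sampling stage listed above takes per-request parameters. execute() receives them as temperature_cpu, top_p_cpu, top_k_cpu and min_p_cpu. The sketch below is a standard NumPy formulation of temperature, top-k, top-p and min-p filtering for one request. It is not EasyDeL's sampler. The conventions that temperature <= 0 means greedy decoding and top_k <= 0 disables top-k are assumptions.

```python
import numpy as np


def sample_token(logits, temperature, top_k, top_p, min_p, rng):
    """Sketch: temperature / top-k / top-p / min-p sampling for a single request."""
    if temperature <= 0.0:
        return int(np.argmax(logits))  # assumption: temperature <= 0 means greedy
    scaled = logits.astype(np.float64) / temperature
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()

    order = np.argsort(-probs)  # token ids, most likely first
    sorted_probs = probs[order]
    keep = np.ones_like(sorted_probs, dtype=bool)
    if top_k > 0:  # assumption: top_k <= 0 disables top-k
        keep[top_k:] = False  # keep only the k most likely tokens
    cumulative = np.cumsum(sorted_probs)
    # top-p: keep a token while the mass *before* it is still below top_p.
    keep &= (cumulative - sorted_probs) < top_p
    # min-p: drop tokens far less likely than the best token.
    keep &= sorted_probs >= min_p * sorted_probs[0]

    kept = np.where(keep, sorted_probs, 0.0)
    kept /= kept.sum()  # the best token is always kept, so this is never zero
    return int(order[rng.choice(len(order), p=kept)])
```

The filters compose by intersecting boolean masks over the sorted distribution, so their order does not matter. The most likely token always survives, which keeps the renormalization well defined.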
- Compilation Modes:
AOT (Ahead-of-Time): Pre-compiles all configurations using lower().compile() for predictable latency and minimal warmup. This is the default (use_aot_forward=True) and is recommended for production.
JIT (Just-in-Time): Defers compilation to first execution. Faster initial setup but unpredictable first-step latency.
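The two modes differ mainly in when compilation cost is paid. The toy model below is a plain-Python sketch. fake_compile is a hypothetical stand-in for lowering and compiling a real JAX function, and ToyExecutor is not EasyDeL's class.

```python
from typing import Callable, Dict, List, Tuple

Key = Tuple[int, int]  # (num_tokens, padded_num_reqs)
StepFn = Callable[[List[int]], List[int]]


def fake_compile(key: Key) -> StepFn:
    """Hypothetical stand-in for an expensive jax.jit(...).lower(...).compile() call."""
    # A real compiler would specialize on the shapes implied by `key`.
    return lambda tokens: [t + 1 for t in tokens]


class ToyExecutor:
    def __init__(self, keys: List[Key], aot: bool) -> None:
        self.aot = aot
        self.compiled: Dict[Key, StepFn] = {}
        self.compile_count = 0
        if aot:
            # AOT: pay every compilation up front, before serving starts.
            for key in keys:
                self._compile(key)

    def _compile(self, key: Key) -> None:
        self.compiled[key] = fake_compile(key)
        self.compile_count += 1

    def run(self, key: Key, tokens: List[int]) -> List[int]:
        if key not in self.compiled:
            if self.aot:
                # Mirrors the documented behavior: an uncompiled shape is an error.
                raise KeyError(f"no compiled variant for {key}")
            # JIT: compile on first use, so this call absorbs the latency.
            self._compile(key)
        return self.compiled[key](tokens)
```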
- Performance Characteristics:
Single host-device round-trip per inference step
Automatic kernel fusion via XLA compiler
Bucketed compilation: O(log N) unique compilations for N request sizes
LRU cache with capacity of 64 compiled variants
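The private attributes below describe this cache as an OrderedDict-based LRU. The following is a minimal sketch of that eviction policy, with a default capacity of 64. LRUCompileCache and get_or_compile are hypothetical names, and this is not the library's implementation.

```python
from collections import OrderedDict
from typing import Any, Callable, Hashable


class LRUCompileCache:
    """Minimal LRU cache for compiled variants (sketch, not EasyDeL's code)."""

    def __init__(self, capacity: int = 64) -> None:
        self.capacity = capacity
        self._entries: "OrderedDict[Hashable, Any]" = OrderedDict()

    def get_or_compile(self, key: Hashable, compile_fn: Callable[[], Any]) -> Any:
        if key in self._entries:
            self._entries.move_to_end(key)  # mark as most recently used
            return self._entries[key]
        value = compile_fn()  # cache miss: pay the compilation cost
        self._entries[key] = value
        if len(self._entries) > self.capacity:
            self._entries.popitem(last=False)  # evict the least recently used entry
        return value

    def __contains__(self, key: Hashable) -> bool:
        return key in self._entries

    def __len__(self) -> int:
        return len(self._entries)
```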
Example
>>> from easydel.inference.esurge.runners import ExecutionManager
>>> executor = ExecutionManager(
... model=model,
... use_aot_forward=True,
... max_num_reqs=32,
... metadata=cache_metadata,
... )
>>> executor.compile(
... num_tokens_paddings=[128, 256, 512, 1024],
... num_reqs_max_model_len=16,
... max_pages_per_req=64,
... max_num_reqs=32,
... metadata=cache_metadata,
... )
>>> result = executor.execute(
... num_tokens=256,
... scheduled_full_cpu=scheduled,
... req_num_tokens_full=req_tokens,
... active_mask_full_cpu=active_mask,
... input_ids_buf=input_buf,
... position_ids_buf=pos_buf,
... padded_num_reqs=16,
... token_ids_cpu=token_ids,
... num_computed_tokens_cpu=num_computed,
... temperature_cpu=temperature,
... top_p_cpu=top_p,
... top_k_cpu=top_k,
... min_p_cpu=min_p,
... page_table_cpu=page_table,
... )
- class easydel.inference.esurge.runners.execution_manager.ExecutionManager(model: EasyDeLBaseModule, use_aot_forward: bool = True, min_input_pad: int = 8, max_model_len: int = 8192, max_num_reqs: int = 16, max_num_tokens: int | None = None, metadata: RaggedPagesCacheView = None, verbose: bool = False)
Bases: object
Compilation and execution manager for fused inference step functions.
The ExecutionManager pre-compiles and caches fused step functions for multiple input configurations, enabling low-latency serving without runtime compilation. It uses bucketed compilation (powers of 2) to reduce the number of unique variants while maintaining good hardware utilization.
- Architecture:
The manager splits the model into (graphdef, graphstate, graphother) for efficient functional transformations. The graphstate (weights) can be updated without recompilation. Compiled functions are cached in an LRU structure with 64-entry capacity.
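Weights can change without recompilation because a compiled XLA executable is specialized on the shapes and dtypes (and shardings) of its inputs, not on their values. The sketch below models this with a cache keyed on a shape/dtype signature. abstract_signature and lookup_or_compile are hypothetical helpers, and shardings are ignored for brevity.

```python
import numpy as np

compiled_variants: dict = {}


def abstract_signature(arrays: dict) -> tuple:
    """Shape/dtype signature of a set of named arrays: what a tracer sees, not the values."""
    return tuple(sorted((name, a.shape, str(a.dtype)) for name, a in arrays.items()))


def lookup_or_compile(weights: dict) -> str:
    sig = abstract_signature(weights)
    if sig not in compiled_variants:
        # Stand-in for a real compiled function; only new shapes/dtypes get here.
        compiled_variants[sig] = f"executable#{len(compiled_variants)}"
    return compiled_variants[sig]
```

Swapping in new weights of the same shapes and dtypes therefore reuses the existing executable. Changing a shape forces a new compilation.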
- Compilation Strategy:
Request counts are bucketed as follows. Counts up to min_input_pad are padded to min_input_pad, and larger counts are rounded up to the next power of 2. Token counts use the explicit padding values passed to compile(). This produces O(M log N) compilations for N possible request sizes and M token configurations.
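The following is a sketch of the request-count bucketing rule described above. The cap at max_num_reqs is an assumption made here so that the largest bucket never exceeds the buffer size. bucket_num_reqs and request_buckets are hypothetical names.

```python
def bucket_num_reqs(num_reqs: int, min_input_pad: int = 8, max_num_reqs: int = 32) -> int:
    """Pad a live request count to its compilation bucket."""
    if num_reqs <= min_input_pad:
        return min_input_pad
    # Smallest power of 2 that is >= num_reqs.
    bucket = 1 << (num_reqs - 1).bit_length()
    return min(bucket, max_num_reqs)  # assumption: never exceed max_num_reqs


def request_buckets(min_input_pad: int = 8, max_num_reqs: int = 32) -> list:
    """All distinct buckets reachable for 1..max_num_reqs live requests."""
    return sorted(
        {bucket_num_reqs(n, min_input_pad, max_num_reqs) for n in range(1, max_num_reqs + 1)}
    )
```

With min_input_pad=8 and max_num_reqs=32, this rule yields the buckets [8, 16, 32]. Combined with five token paddings, that gives 5 * 3 = 15 compiled variants.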
- model
EasyDeL model instance (EasyDeLBaseModule).
- mesh
JAX sharding mesh for distributed execution across devices.
- kv_pages
Paged KV cache storage (RaggedPagesCache).
- use_aot_forward
If True, use AOT compilation via lower().compile(). If False, use JIT compilation on first call. Default: True.
- min_input_pad
Minimum request count padding for bucketing. Default: 8.
- max_model_len
Maximum sequence length supported by the model.
- max_num_reqs
Maximum concurrent requests.
- max_num_tokens
Maximum tokens per batch (defaults to max_model_len).
- metadata
KV cache metadata (RaggedPagesCacheView).
- graphdef
Model graph definition (static structure).
- graphstate
Model graph state (weights, device-resident).
- graphother
Auxiliary model state (buffers, etc.).
- rng_key
JAX random key for sampling, threaded through steps.
- Private Attributes:
_model_step_fn: Model-only forward function (ejit-decorated).
_sampling_fn: Sampler/update function (ejit-decorated).
_model_lowerd_history: OrderedDict LRU cache of compiled model functions.
_sampler_lowerd_history: OrderedDict cache for the compiled sampler function.
_cache_capacity: Maximum cache entries (64).
_debug_baselines: Hash baselines for debugging recompilations.
_empty_sharding: Default sharding (replicated across mesh).
Example
>>> # Initialize manager
>>> executor = ExecutionManager(
... model=model,
... use_aot_forward=True,
... min_input_pad=8,
... max_model_len=8192,
... max_num_reqs=32,
... metadata=cache.metadata,
... )
>>>
>>> # Pre-compile for expected configurations
>>> executor.compile(
... num_tokens_paddings=[128, 256, 512, 1024, 2048],
... num_reqs_max_model_len=16,
... max_pages_per_req=128,
... max_num_reqs=32,
... metadata=cache.metadata,
... )
>>>
>>> # Execute steps during serving
>>> results = executor.execute(
... num_tokens=512,
... scheduled_full_cpu=scheduled,
... req_num_tokens_full=req_tokens,
... active_mask_full_cpu=active,
... input_ids_buf=input_buf,
... position_ids_buf=pos_buf,
... padded_num_reqs=16,
... token_ids_cpu=token_ids,
... num_computed_tokens_cpu=num_computed,
... temperature_cpu=temperature,
... top_p_cpu=top_p,
... top_k_cpu=top_k,
... min_p_cpu=min_p,
... page_table_cpu=page_table,
... )
- compile(num_tokens_paddings: list[int], num_reqs_max_model_len: int, max_pages_per_req: int, max_num_reqs: int, metadata: RaggedPagesCacheView, num_reqs_paddings: list[int] | None = None) None
Compile model execution functions for various input configurations.
Pre-compiles functions for different combinations of token counts and request counts to avoid runtime compilation overhead. This enables seamless switching between different batch sizes during inference.
- Parameters
num_tokens_paddings – List of token count configurations to compile.
num_reqs_max_model_len – Maximum number of requests at max model length.
max_pages_per_req – Maximum number of KV cache pages per request.
max_num_reqs – Maximum number of concurrent requests.
metadata – Pages cache metadata containing configuration details.
num_reqs_paddings – Optional explicit list of padded request counts to compile. Default: None.
Note
Compilation progress is logged using a progress bar. The total number of compilations is len(num_tokens_paddings) * number of unique padded request counts.
Example
>>> executor.compile(
... num_tokens_paddings=[128, 256, 512, 1024],
... num_reqs_max_model_len=16,
... max_pages_per_req=64,
... max_num_reqs=32,
... metadata=cache_metadata,
... )
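The compilation count in the note above is the size of a Cartesian product. The following is a sketch of enumerating the (num_tokens, padded_num_reqs) pairs that a compile() call has to cover. plan_compilations is a hypothetical helper, and here the request paddings are passed explicitly rather than derived.

```python
import itertools
from typing import List, Tuple


def plan_compilations(
    num_tokens_paddings: List[int], num_reqs_paddings: List[int]
) -> List[Tuple[int, int]]:
    """Every (num_tokens, padded_num_reqs) variant that needs a compiled function."""
    tokens = sorted(set(num_tokens_paddings))
    reqs = sorted(set(num_reqs_paddings))  # duplicates would only waste compile time
    return list(itertools.product(tokens, reqs))
```

Assume request buckets [8, 16, 32] for the example above. Then plan_compilations([128, 256, 512, 1024], [8, 16, 32]) returns 4 * 3 = 12 pairs, starting at (128, 8) and ending at (1024, 32).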
- execute(num_tokens: int, scheduled_full_cpu: ndarray, req_num_tokens_full: Array, active_mask_full_cpu: ndarray, input_ids_buf: Array, position_ids_buf: Array, padded_num_reqs: int, token_ids_cpu: ndarray, num_computed_tokens_cpu: ndarray, temperature_cpu: ndarray, top_p_cpu: ndarray, top_k_cpu: ndarray, min_p_cpu: ndarray, page_table_cpu: ndarray) tuple[easydel.inference.esurge.runners.execution_types.MinimalDeviceState, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]
Execute a single fused inference step.
Runs a pre-compiled fused function that combines input preparation, model forward pass, token sampling, and state updates in a single device dispatch.
- Parameters
num_tokens – Total tokens to process across all requests in this step. Must match a value from num_tokens_paddings used during compile().
scheduled_full_cpu – NumPy array with the number of tokens scheduled per request [max_num_reqs]. Determines how many tokens from each request enter this step.
req_num_tokens_full – Target token count per request [max_num_reqs]. Used to determine when requests have generated enough tokens.
active_mask_full_cpu – NumPy boolean mask for active requests [max_num_reqs]. Inactive requests are skipped during processing.
input_ids_buf – Contiguous token ID buffer [max_num_tokens]. Flattened across requests for efficient batch processing.
position_ids_buf – Contiguous position ID buffer [max_num_tokens]. Parallel to input_ids_buf with position indices.
padded_num_reqs – Bucketed request count for compilation lookup. Must be a power of 2 (or min_input_pad) matching a compiled variant.
token_ids_cpu – NumPy array of token IDs [max_num_reqs, max_model_len].
num_computed_tokens_cpu – NumPy array of already-computed tokens per request [max_num_reqs].
temperature_cpu, top_p_cpu, top_k_cpu, min_p_cpu – NumPy arrays of per-request sampling parameters.
page_table_cpu – NumPy array holding the KV cache page table.
- Returns
device_state: Updated sequence state with new tokens written.
out_tokens_full: Generated tokens [max_num_reqs], -1 for invalid.
valid_mask_full: Boolean mask for valid generations [max_num_reqs].
input_ids_buf: Updated input buffer (may contain new tokens).
position_ids_buf: Updated position buffer.
query_start_loc_buf: Query start locations [max_num_reqs+1].
seq_lens_buf: Sequence lengths [max_num_reqs].
pages_tables_buf: Page tables [num_reqs, max_pages].
hidden_states: Last layer hidden states [num_tokens, hidden_dim].
logits: Output logits [padded_num_reqs, vocab_size].
- Return type
Tuple of 10 elements
- Raises
KeyError – If no compiled function exists for (num_tokens, padded_num_reqs). This indicates the configuration wasn’t included in compile() call.
Note
The KV cache (self.kv_pages) and random key (self.rng_key) are updated in-place on self after execution completes.
Example
>>> results = executor.execute(
... num_tokens=256,
... scheduled_full_cpu=np.array([4, 8, 2, ...]),
... req_num_tokens_full=jnp.array([512, 256, 128, ...]),
... active_mask_full_cpu=np.array([True, True, False, ...]),
... input_ids_buf=input_buf,
... position_ids_buf=pos_buf,
... padded_num_reqs=16,
... token_ids_cpu=token_ids,
... num_computed_tokens_cpu=num_computed,
... temperature_cpu=temperature,
... top_p_cpu=top_p,
... top_k_cpu=top_k,
... min_p_cpu=min_p,
... page_table_cpu=page_table,
... )
>>> new_state, tokens, valid, *rest = results
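execute() requires num_tokens to be one of the compiled paddings, so the caller must first round the live token count up. The following is a sketch of that selection plus the KeyError check. choose_compiled_key is a hypothetical helper. Rounding up to the smallest padding that fits is an assumption about how a scheduler would pick the bucket.

```python
import bisect
from typing import Iterable, List, Tuple


def choose_compiled_key(
    actual_tokens: int,
    num_tokens_paddings: List[int],
    padded_num_reqs: int,
    compiled_keys: Iterable[Tuple[int, int]],
) -> Tuple[int, int]:
    paddings = sorted(num_tokens_paddings)
    # Smallest padding that can hold all scheduled tokens.
    i = bisect.bisect_left(paddings, actual_tokens)
    if i == len(paddings):
        raise ValueError(f"{actual_tokens} tokens exceed the largest padding {paddings[-1]}")
    key = (paddings[i], padded_num_reqs)
    if key not in set(compiled_keys):
        # Same failure mode as execute(): the variant was never compiled.
        raise KeyError(f"no compiled function for {key}; add it to compile()")
    return key
```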
- get_async_prep_result() tuple[easydel.inference.esurge.runners.execution_types.BatchMetadata, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array] | None
Get the result of a previously started async prep.
Returns None if no async prep is pending. Otherwise waits for the transfer to complete and returns the same tuple as prepare_batch_metadata().
- get_compile_configurations(kv_pages: RaggedPagesCache, rng_key: PRNGKey, num_tokens: int, num_reqs_max_model_len: int, max_pages_per_req: int, max_num_reqs: int, padded_num_reqs: int, metadata: RaggedPagesCacheView)
Generate compilation arguments for step function.
Creates dummy input structures with correct shapes, dtypes, and shardings for tracing the step function during AOT/JIT compilation. All arrays are device-resident with appropriate sharding annotations to prevent XLA from generating multiple compilation variants.
- Parameters
kv_pages – KV cache pages (used as-is in compilation args).
rng_key – Random key for sampling (device-placed with empty sharding).
num_tokens – Token count (unused, for API compatibility).
num_reqs_max_model_len – Max requests at model length (unused).
max_pages_per_req – Max pages per request (unused).
max_num_reqs – Maximum concurrent requests for buffer sizing.
padded_num_reqs – Target padded request count for this compilation variant.
metadata – KV cache metadata for buffer initialization.
- Returns
[graphdef, graphstate, graphother, inputs] where inputs is a StepFunctionInputs PyTree with dummy values.
- Return type
List of compilation arguments
Note
Dummy values use simple patterns (ones, zeros) since compilation only traces shapes/dtypes. The returned structures must match runtime shardings exactly to avoid recompilation.
- get_compiled_key(num_tokens: int, padded_num_reqs: int)
Retrieve the pre-compiled model step function for the given input dimensions.
- Parameters
num_tokens – Number of tokens in the input batch.
padded_num_reqs – Padded number of requests for batching.
- Returns
Compiled model step function for the specified (num_tokens, padded_num_reqs) pair.
- get_model_step_fn() Callable
Create the model-only ejit-compiled function that consumes precomputed batch metadata.
- init_fns() None
Initialize the step execution functions.
Creates the model-only execution function. Sampling and state updates are handled by a separate ejit-compiled function created during initialization.
Note
Called automatically during initialization. Should not be called directly by users.
- property maybe_implicit
- prepare_batch_metadata(num_tokens_static: int, scheduled_full_cpu: ndarray, active_mask_full_cpu: ndarray, input_ids_buf: Array, position_ids_buf: Array, token_ids_cpu: ndarray, num_computed_tokens_cpu: ndarray, temperature_cpu: ndarray, top_p_cpu: ndarray, top_k_cpu: ndarray, min_p_cpu: ndarray, page_table_cpu: ndarray, padded_num_reqs_in: int) tuple[easydel.inference.esurge.runners.execution_types.BatchMetadata, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]
Precompute batch metadata using CPU-first approach.
Performs all metadata computation on CPU using NumPy for speed, then transfers to device in a single operation.
- Parameters
num_tokens_static – Number of tokens to process
scheduled_full_cpu – NumPy array of tokens scheduled per request [max_num_reqs]
active_mask_full_cpu – NumPy active request mask [max_num_reqs]
input_ids_buf – Input buffer (will be replaced)
position_ids_buf – Position buffer (will be replaced)
token_ids_cpu – NumPy array of token IDs [max_num_reqs, max_model_len]
num_computed_tokens_cpu – NumPy array of computed tokens [max_num_reqs]
temperature_cpu, top_p_cpu, top_k_cpu, min_p_cpu – NumPy arrays of per-request sampling parameters
page_table_cpu – NumPy array holding the KV cache page table
padded_num_reqs_in – Caller-selected padded request bucket
- Returns
Tuple of 7 elements: BatchMetadata, the replaced input_ids_buf and position_ids_buf, followed by four further device arrays (see the return type annotation).
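The following sketch shows what CPU-first preparation can look like for a ragged batch. Per-request slices are located with a prefix sum. Tokens and positions are gathered into flat, zero-padded buffers with NumPy, so the result could be sent to the device in one transfer. The helper name, its returned fields and the zero padding are assumptions. The documented method also handles page tables and sampling parameters, which are omitted here.

```python
import numpy as np


def prepare_metadata_cpu(num_tokens_static, scheduled, active, token_ids, num_computed, padded_num_reqs):
    """Hypothetical CPU-side prep: pure NumPy, ready for a single host-to-device copy."""
    sched = np.where(active, scheduled, 0)[:padded_num_reqs].astype(np.int32)
    computed = num_computed[:padded_num_reqs].astype(np.int32)

    # Ragged layout: request i owns flat tokens [query_start_loc[i], query_start_loc[i + 1]).
    query_start_loc = np.zeros(padded_num_reqs + 1, dtype=np.int32)
    query_start_loc[1:] = np.cumsum(sched)
    seq_lens = computed + sched  # context length after this step

    # Gather the newly scheduled tokens and their positions into padded flat buffers.
    input_ids = np.zeros(num_tokens_static, dtype=np.int32)
    position_ids = np.zeros(num_tokens_static, dtype=np.int32)
    for i in range(padded_num_reqs):
        start, end = query_start_loc[i], query_start_loc[i + 1]
        positions = np.arange(computed[i], computed[i] + sched[i], dtype=np.int32)
        input_ids[start:end] = token_ids[i, positions]
        position_ids[start:end] = positions
    return query_start_loc, seq_lens, input_ids, position_ids
```

Inactive requests contribute zero scheduled tokens, so their slice in the flat buffers is empty and the prefix sum simply repeats.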
- start_async_prep(num_tokens_static: int, scheduled_full_cpu: ndarray, active_mask_full_cpu: ndarray, input_ids_buf: Array, position_ids_buf: Array, token_ids_cpu: ndarray, num_computed_tokens_cpu: ndarray, temperature_cpu: ndarray, top_p_cpu: ndarray, top_k_cpu: ndarray, min_p_cpu: ndarray, page_table_cpu: ndarray, padded_num_reqs_in: int) None
Start async device transfer for the next batch (double buffering).
This method performs CPU-side preparation and initiates an async device transfer without blocking. Call get_async_prep_result() to retrieve the transferred arrays when needed.
The device transfer runs in parallel with model execution, hiding the ~2ms transfer latency.
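The double-buffering contract of start_async_prep() and get_async_prep_result() can be sketched with a single background worker. start() submits the preparation for the next step and returns immediately. result() returns None when nothing is pending and otherwise waits for the pending work. AsyncPrep is a hypothetical class built on concurrent.futures. In the documented method, the overlap comes from the asynchronous device transfer rather than from a Python thread.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional


class AsyncPrep:
    """Double buffering: prepare step N+1 while step N runs."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Optional[Future] = None

    def start(self, prep_fn: Callable[..., Any], *args: Any) -> None:
        # Non-blocking: the work runs on the background thread.
        self._pending = self._pool.submit(prep_fn, *args)

    def result(self) -> Optional[Any]:
        if self._pending is None:
            return None  # no prep in flight
        future, self._pending = self._pending, None
        return future.result()  # waits only if the prep is still running

    def close(self) -> None:
        self._pool.shutdown(wait=True)
```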
- update_graphs(model: EasyDeLBaseModule | None = None, *, graphdef=None, graphstate=None, graphother=None) None
Update the graph components (weights) used by the fused executor.
- Parameters
model – Optional EasyDeL module to source new graph parts from. When provided, graphdef/graphstate/graphother are pulled from this model unless explicitly overridden via the keyword arguments.
graphdef – Optional graph definition replacement.
graphstate – Optional graph state replacement (typically the weights).
graphother – Optional auxiliary graph data replacement.
- Raises
ValueError – If neither a model nor explicit graph components are provided.
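The precedence rule described above can be sketched as a pure function. Explicit keyword arguments override the parts taken from model, and at least one source is required. Graphs and resolve_graphs are hypothetical. model_parts stands in for the (graphdef, graphstate, graphother) split of a model. Keeping the current value for any component that is not supplied is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Graphs:
    graphdef: Any = None
    graphstate: Any = None
    graphother: Any = None


def resolve_graphs(
    current: Graphs,
    model_parts: Optional[Graphs] = None,
    *,
    graphdef: Any = None,
    graphstate: Any = None,
    graphother: Any = None,
) -> Graphs:
    if model_parts is None and graphdef is None and graphstate is None and graphother is None:
        raise ValueError("provide a model or at least one graph component")
    # Parts from the model win over the current ones; explicit kwargs win over both.
    base = model_parts if model_parts is not None else current
    return Graphs(
        graphdef=graphdef if graphdef is not None else base.graphdef,
        graphstate=graphstate if graphstate is not None else base.graphstate,
        graphother=graphother if graphother is not None else base.graphother,
    )
```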