easydel.inference.__init__


class easydel.inference.__init__.SamplingParams(max_tokens: int = 16, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 0.7, top_p: float = 1.0, top_k: int = 0, min_p: float = 0.0, suppress_tokens: list[int] = <factory>)

Bases: object

Parameters controlling the sampling process during text generation.

max_tokens

The maximum number of tokens to generate (excluding the prompt). Defaults to 16.

Type

int

presence_penalty

Penalty applied to the logits of tokens already present in the generated sequence. Positive values discourage repetition. Defaults to 0.0.

Type

float

frequency_penalty

Penalty applied to the logits of tokens based on their frequency in the generated sequence so far. Positive values discourage verbatim repetition. Defaults to 0.0.

Type

float

repetition_penalty

Multiplicative penalty applied to the logits of previously seen tokens. Values > 1.0 discourage repetition, < 1.0 encourage it. Defaults to 1.0.

Type

float
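A minimal sketch of how the three penalties above are conventionally applied to one logits vector. It assumes the common OpenAI-style additive formulation for presence_penalty and frequency_penalty, and the Hugging Face-style formulation for repetition_penalty (divide positive logits, multiply negative ones); EasyDeL's own processors may differ in detail, and the helper name apply_penalties is hypothetical.

```python
from collections import Counter


def apply_penalties(logits, generated_ids, presence_penalty=0.0,
                    frequency_penalty=0.0, repetition_penalty=1.0):
    """Return a penalized copy of `logits` (a list of floats indexed by token ID).

    Hypothetical helper mirroring the conventional formulations, not EasyDeL's code.
    """
    out = list(logits)
    counts = Counter(generated_ids)
    for token_id, count in counts.items():
        # Multiplicative repetition penalty: move every seen token toward lower scores.
        if repetition_penalty != 1.0:
            if out[token_id] > 0:
                out[token_id] /= repetition_penalty
            else:
                out[token_id] *= repetition_penalty
        # Additive penalties: a flat presence term plus a count-scaled frequency term.
        out[token_id] -= presence_penalty + frequency_penalty * count
    return out


logits = [2.0, 1.0, -1.0, 0.5]
print(apply_penalties(logits, [0, 0, 2], presence_penalty=0.5,
                      frequency_penalty=0.25, repetition_penalty=2.0))
# [0.0, 1.0, -2.75, 0.5]
```

Token 0 was generated twice, so it pays the presence term once and the frequency term twice; tokens 1 and 3 were never generated and keep their logits.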

temperature

Controls the randomness of sampling. Higher values (e.g., > 1.0) flatten the distribution (more random); lower values (e.g., < 1.0) sharpen it (more deterministic). A value of 0.0 is effectively greedy sampling. Defaults to 0.7.

Type

float
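A small sketch of temperature scaling: logits are divided by the temperature before the softmax, so values below 1.0 concentrate probability on the top token and values above 1.0 spread it out. Treating 0.0 as a pure argmax is an assumption about how the greedy case is special-cased; the function name is hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Return sampling probabilities for `logits` after temperature scaling.

    temperature == 0.0 is treated as greedy decoding: all mass on the argmax
    (an assumption about how the 0.0 case is special-cased).
    """
    if temperature == 0.0:
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
for t in (0.0, 0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```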

top_p

Nucleus sampling threshold. If set to a value < 1.0, sampling is restricted to the smallest set of most-probable tokens whose cumulative probability reaches top_p. Defaults to 1.0 (no nucleus sampling).

Type

float

top_k

Top-k sampling threshold. If set to a value > 0, only the top_k most probable tokens are considered for sampling. Defaults to 0 (no top-k sampling).

Type

int
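A sketch of the top_p and top_k filters described above, operating on an already-normalized probability list and renormalizing the survivors. The helper names are hypothetical and the code illustrates the standard techniques, not EasyDeL's warper implementation.

```python
def top_k_filter(probs, top_k):
    """Keep only the `top_k` most probable tokens (0 disables), renormalized."""
    if top_k <= 0 or top_k >= len(probs):
        return list(probs)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order[:top_k])
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]


def top_p_filter(probs, top_p):
    """Keep the smallest set of most-probable tokens whose cumulative mass reaches `top_p`."""
    if top_p >= 1.0:
        return list(probs)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]


probs = [0.5, 0.3, 0.15, 0.05]
print(top_k_filter(probs, 2))     # only tokens 0 and 1 survive
print(top_p_filter(probs, 0.75))  # 0.5 + 0.3 >= 0.75, so tokens 0 and 1 survive
```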

min_p

Minimum probability threshold. Filters out tokens with probability less than min_p. Defaults to 0.0 (no minimum probability filtering).

Type

float
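A sketch of min-p filtering. With relative=False it applies min_p as the absolute threshold described above; with relative=True it scales the threshold by the top token's probability, which is how Hugging Face transformers' MinPLogitsWarper defines min-p. Which form EasyDeL uses is not stated in this reference, and the helper name is hypothetical.

```python
def min_p_filter(probs, min_p, relative=False):
    """Zero out tokens whose probability is below the min-p threshold, then renormalize."""
    if min_p <= 0.0:
        return list(probs)
    threshold = min_p * max(probs) if relative else min_p
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    if total == 0.0:
        # Never filter out every token: fall back to keeping the argmax.
        best = max(range(len(probs)), key=lambda i: probs[i])
        kept[best] = probs[best]
        total = probs[best]
    return [p / total for p in kept]


probs = [0.6, 0.25, 0.1, 0.05]
print(min_p_filter(probs, 0.3))                 # absolute: only 0.6 passes
print(min_p_filter(probs, 0.3, relative=True))  # threshold 0.18: 0.6 and 0.25 pass
```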

suppress_tokens

A list of token IDs that should be completely suppressed (their logits set to -inf) during generation. Defaults to an empty list.

Type

list[int]
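Setting a logit to -inf removes the token entirely, because its softmax weight exp(-inf) is 0. A minimal sketch (the helper name suppress is hypothetical):

```python
import math


def suppress(logits, suppress_tokens):
    """Return a copy of `logits` with every ID in `suppress_tokens` set to -inf."""
    out = list(logits)
    for token_id in suppress_tokens:
        out[token_id] = -math.inf  # exp(-inf) == 0, so the token can never be sampled
    return out


logits = [1.0, 3.0, 2.0]
masked = suppress(logits, [1])
print(masked)  # [1.0, -inf, 2.0]
print(max(range(len(masked)), key=masked.__getitem__))  # 2: the greedy choice moves to token 2
```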

frequency_penalty: float = 0.0
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

get_logits_processor()

Constructs a LogitsProcessorList containing the configured logits processors.

Logits processors modify the logits directly, often used for applying penalties (presence, frequency, repetition) or suppressing specific tokens.

Returns

A LogitsProcessorList containing the enabled logits processors based on the sampling parameters.

get_logits_warper()

Constructs a LogitsProcessorList containing the configured logits warpers.

Logits warpers modify the probability distribution derived from logits, typically used for techniques like temperature scaling, top-k, top-p, and min-p sampling.

Returns

A LogitsProcessorList containing the enabled logits warpers based on the sampling parameters.
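The split between get_logits_processor and get_logits_warper suggests a two-stage pipeline: processors adjust the raw logits (penalties, suppression), then warpers reshape the distribution (temperature, top-k, top-p, min-p). A self-contained sketch, assuming each stage is a list of callables applied in order; ProcessorList, make_suppressor and make_temperature are hypothetical stand-ins, not EasyDeL's LogitsProcessorList.

```python
import math


class ProcessorList(list):
    """Hypothetical stand-in: applies each callable to the logits in order."""

    def __call__(self, logits):
        for fn in self:
            logits = fn(logits)
        return logits


def make_suppressor(ids):
    return lambda logits: [-math.inf if i in ids else x for i, x in enumerate(logits)]


def make_temperature(t):
    return lambda logits: [x / t for x in logits]


processors = ProcessorList([make_suppressor({0})])  # penalties / suppression
warpers = ProcessorList([make_temperature(0.5)])    # temperature, top-k, top-p, min-p

logits = [5.0, 1.0, 2.0]
print(warpers(processors(logits)))  # [-inf, 2.0, 4.0]
```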

property logits_processor
property logits_warper
max_tokens: int = 16
min_p: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.0
replace(**kwargs)

Creates a new instance with specified fields replaced.

suppress_tokens: list[int]
temperature: float = 0.7
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
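The replace, to_dict and to_json members follow the familiar dataclass pattern: replace returns a modified copy and the serializers round-trip the fields. The sketch below uses a hypothetical stand-in dataclass with the documented fields and defaults plus the standard-library dataclasses and json modules; it does not call EasyDeL itself.

```python
import dataclasses
import json


@dataclasses.dataclass
class SamplingParamsSketch:
    """Hypothetical stand-in mirroring the documented SamplingParams fields and defaults."""
    max_tokens: int = 16
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 0.7
    top_p: float = 1.0
    top_k: int = 0
    min_p: float = 0.0
    suppress_tokens: list[int] = dataclasses.field(default_factory=list)


base = SamplingParamsSketch()
greedy = dataclasses.replace(base, temperature=0.0, max_tokens=64)  # new instance; base untouched

payload = json.dumps(dataclasses.asdict(greedy))        # analogous to to_json()
restored = SamplingParamsSketch(**json.loads(payload))  # analogous to from_json()
print(restored == greedy)  # True
```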

top_k: int = 0
top_p: float = 1.0
class easydel.inference.__init__.oDriver(engine: oEngine)

Bases: AbstractDriver

oDriver is responsible for managing the inference process for the oEngine. It handles request submission, input preparation, inference execution, and processing of model outputs. It utilizes background threads for concurrent processing of different stages of the inference pipeline.

compile()

Compiles the underlying engines.

This method is intended to perform any necessary compilation steps for the inference engines. Currently, it’s a placeholder.

property driver_name

Returns the name of the driver, derived from the engine’s model name.

get_total_concurrent_requests() → int

Gets the total number of concurrent requests the driver can handle.

This is determined by the maximum number of concurrent decode operations supported by the engine.

Returns

The maximum number of concurrent requests.

Return type

int

property num_used_slots
place_request_on_prefill_queue(request: ActiveRequest)

Places a new request onto the prefill queue for processing.

This method is used internally to add requests that require prefilling and subsequent generation.

Parameters

request (ActiveRequest) – The active request to place on the queue.

property processor: Any

Returns the processor/tokenizer associated with the engines.

Assumes all engines (prefill and decode) use the same processor. Raises an error if no engines are configured.

start()

Starts the driver and its background processing threads.

Threads for input preparation, inference, and summary processing are started if the driver is not already live.

stop()

Stops the driver and all background threads gracefully.

Signals the background threads to exit by putting None into their respective queues and then waits for them to join.

submit_request(request: Any)

Submits a new request to the driver’s processing pipeline.

The request is placed on the prefill queue for initial processing.

Parameters

request (tp.Any) – The request object to submit. Must be of type ActiveRequest.

Raises

TypeError – If the submitted request is not an instance of ActiveRequest.
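The start, stop and submit_request behavior above follows a standard queue-and-thread pattern. A two-stage sketch using only the standard library, where Request and MiniDriver are hypothetical stand-ins for ActiveRequest and oDriver: start spawns one thread per stage, submit_request type-checks and enqueues, and stop sends a None sentinel to each queue in pipeline order and joins each thread, so in-flight requests drain before shutdown.

```python
import queue
import threading
from dataclasses import dataclass, field


@dataclass
class Request:
    """Hypothetical stand-in for ActiveRequest."""
    prompt: str
    output: list = field(default_factory=list)


class MiniDriver:
    """Hypothetical two-stage pipeline: prefill thread -> decode thread."""

    def __init__(self):
        self._prefill_q = queue.Queue()
        self._decode_q = queue.Queue()
        self.finished = queue.Queue()
        self._threads = []
        self.live = False

    def _prefill_loop(self):
        while (req := self._prefill_q.get()) is not None:
            req.output.append(f"prefilled:{req.prompt}")
            self._decode_q.put(req)

    def _decode_loop(self):
        while (req := self._decode_q.get()) is not None:
            req.output.append("decoded")
            self.finished.put(req)

    def start(self):
        if self.live:  # only start the threads once
            return
        self._threads = [threading.Thread(target=self._prefill_loop, daemon=True),
                         threading.Thread(target=self._decode_loop, daemon=True)]
        for t in self._threads:
            t.start()
        self.live = True

    def submit_request(self, request):
        if not isinstance(request, Request):
            raise TypeError(f"expected Request, got {type(request).__name__}")
        self._prefill_q.put(request)

    def stop(self):
        # Shut down stage by stage: a None sentinel ends each loop, and joining the
        # upstream thread first lets in-flight requests drain into the next queue.
        for q, t in zip((self._prefill_q, self._decode_q), self._threads):
            q.put(None)
            t.join()
        self.live = False


driver = MiniDriver()
driver.start()
driver.submit_request(Request("hello"))
driver.stop()
print(driver.finished.get_nowait().output)  # ['prefilled:hello', 'decoded']
```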

class easydel.inference.__init__.oEngine(model: Any, processor: Any, storage: PagedAttentionCache, manager: HBMPageManager, max_concurrent_decodes: int | None = None, max_concurrent_prefill: int | None = None, prefill_lengths: int | None = None, max_prefill_length: int | None = None, max_length: int | None = None, batch_size: int | None = None, seed: int = 894)

Bases: AbstractInferenceEngine

Optimized inference engine for EasyDeL models using Paged Attention.

This engine manages the inference process, including KV cache management with Paged Attention, scheduling, and execution of model forward passes for both prefill and decode steps.

property batch_size: int | None

Returns the configured batch size for the engine, if specified.

bulk_insert(prefix: Any, decode_state: Any, slots: list[int]) → Any

Efficiently inserts multiple prefill results into the decode state.

This method is optimized for integrating the results of a batch prefill operation into the decode state for multiple sequence slots.

Parameters
  • prefix – The generation state from the bulk prefill step.

  • decode_state – The current decode state.

  • slots – A list of slot indices to insert into.

Returns

The updated decode state.

property colocated_cpus: Optional[list[jaxlib.xla_extension.Device]]

Returns CPU devices colocated with the engine’s accelerator devices.

This information can be useful for optimizing data transfers between host (CPU) and accelerator (GPU/TPU) memory. Currently returns None as the implementation is pending.

Returns

A list of colocated JAX CPU devices, or None if not implemented or available.

decode(graphstate: State[Key, VariableState[Any]], graphothers: State[Key, VariableState[Any]], state: Any, rngs: PRNGKey) → tuple[Any, Any]

Performs a single decode step for active sequences.

This involves generating the next token for each sequence based on the current state and KV cache.

Parameters
  • graphstate – The graph state of the model.

  • graphothers – Other graph components of the model.

  • state – The current generation state.

  • rngs – The PRNG key for sampling.

Returns

A tuple containing the updated generation state and the generated result tokens.

property eos_token_ids: list[int]

Returns a list of end-of-sequence token IDs from the processor and model config.

forward(graphstate: State[Key, VariableState[Any]], graphothers: State[Key, VariableState[Any]], state: ActiveSequenceBatch, iteration_plan: NextIterationPlan) → ModelOutputBatch

Performs a forward pass of the model.

This method executes the compiled continuous_forward function, processing the input state and iteration plan to produce model outputs and update the KV cache storage.

Parameters
  • graphstate – The graph state of the model.

  • graphothers – Other graph components of the model.

  • state – The current active sequence batch state.

  • iteration_plan – The plan for the current inference iteration.

Returns

The output batch from the model.

free_resource(slot: int) → bool

Frees resources associated with a specific inference slot. (Not Implemented)

Parameters

slot – The index of the slot to free.

Returns

Always returns False as it’s not implemented.

get_prefix_destination_sharding() → Any

Returns the shardings necessary to transfer KV cache data between engines.

This method is intended for scenarios involving multiple engines or devices where KV cache data needs to be moved.

Returns

The sharding specification for prefix destinations.

get_state_shardings(is_decode: bool = False)

Returns the sharding specifications for the engine’s state.

Parameters

is_decode – A boolean indicating if the sharding is for the decode state.

Returns

A tuple representing the sharding specification.

init_decode_state(*args, **kwargs) → ActiveSequenceBatch

Initializes the decode state for active sequences.

Parameters
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

An initialized ActiveSequenceBatch instance.

insert(prefix: Any, decode_state: Any, slot: int) → Any

Inserts or updates a generation state for a specific slot.

This is typically used to integrate the results of a prefill step into the ongoing decode process for a particular sequence slot.

Parameters
  • prefix – The generation state from the prefill step.

  • decode_state – The current decode state.

  • slot – The slot index to insert into.

Returns

The updated decode state.

property max_concurrent_decodes

Maximum number of sequences decoded concurrently.

property max_concurrent_prefill
property max_length: int

Maximum total sequence length (prompt + generation).

This defines the size of the KV cache allocated.

property max_prefill_length: int

Maximum allowed length for the initial prompt (prefill phase).

Prompts longer than this will be truncated or handled according to the padding/truncation logic.

property pad_token_id

Returns the pad token ID from the processor.

prefill(graphstate: State[Key, VariableState[Any]], graphothers: State[Key, VariableState[Any]], tokens: Array, valids: Array, true_length: int, temperature: Array, top_p: Array, top_k: Array, rngs: PRNGKey) → tuple[Any, Any]

Performs the prefill step for a batch of prompts.

This involves processing the initial prompt tokens and populating the KV cache.

Parameters
  • graphstate – The graph state of the model.

  • graphothers – Other graph components of the model.

  • tokens – The input tokens for the prompts.

  • valids – A boolean array indicating valid tokens.

  • true_length – The true length of the sequences.

  • temperature – The temperature for sampling.

  • top_p – The top-p value for sampling.

  • top_k – The top-k value for sampling.

  • rngs – The PRNG key for sampling.

Returns

A tuple containing the generation state after prefill and the result tokens.

property prefill_lengths: list[int]

Returns the configured list of max prefill length buckets for the engine.
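Length buckets are typically used so that only a handful of prompt shapes ever get compiled: each prompt is padded up to the smallest bucket that fits it. A sketch under that assumption; right-padding and the helper names pick_prefill_bucket and pad_to_bucket are hypothetical, while the boolean mask mirrors the valids argument of prefill.

```python
import bisect


def pick_prefill_bucket(prompt_len, prefill_lengths):
    """Return the smallest bucket that can hold `prompt_len` tokens (hypothetical helper)."""
    buckets = sorted(prefill_lengths)
    idx = bisect.bisect_left(buckets, prompt_len)
    if idx == len(buckets):
        raise ValueError(f"prompt of {prompt_len} tokens exceeds the largest bucket {buckets[-1]}")
    return buckets[idx]


def pad_to_bucket(token_ids, bucket, pad_token_id):
    """Right-pad `token_ids` to `bucket` and return (tokens, valids)."""
    padding = bucket - len(token_ids)
    return token_ids + [pad_token_id] * padding, [True] * len(token_ids) + [False] * padding


print(pick_prefill_bucket(5, [16, 32, 64, 128]))  # 16
print(pad_to_bucket([11, 12, 13], 4, 0))  # ([11, 12, 13, 0], [True, True, True, False])
```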

property prng_key: PRNGKey

Provides a new PRNG key split from the internal state for sampling.

Each call to this property consumes the current key and returns a new, unique key, ensuring that subsequent sampling operations use different randomness.

Returns

A new JAX PRNGKey.
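The "consume and return a new key" behavior corresponds to the usual JAX split-on-access pattern: split the stored key, keep one half as the new internal state, and hand out the other. A sketch with a hypothetical class KeyedEngineSketch, using jax.random.PRNGKey and jax.random.split; the default seed of 894 mirrors the engine signatures.

```python
import jax


class KeyedEngineSketch:
    """Hypothetical stand-in showing the split-on-access pattern behind `prng_key`."""

    def __init__(self, seed: int = 894):  # 894 mirrors the documented default seed
        self._key = jax.random.PRNGKey(seed)

    @property
    def prng_key(self):
        # Split the stored key: keep one half as the new state, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


engine = KeyedEngineSketch()
k1, k2 = engine.prng_key, engine.prng_key
print(jax.random.uniform(k1), jax.random.uniform(k2))  # draws from two different keys
```

Because the state is a deterministic function of the seed, two engines built with the same seed hand out the same key sequence.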

property samples_per_slot: int

Number of samples generated per inference slot.

This determines how many independent generation results are produced for each logical slot managed by the engine. It’s often 1, but could be higher for techniques like parallel sampling.

class easydel.inference.__init__.vDriver(prefill_engines: Optional[Union[list[easydel.inference.vsurge.engines.vengine.engine.vEngine], vEngine]] = None, decode_engines: Optional[Union[list[easydel.inference.vsurge.engines.vengine.engine.vEngine], vEngine]] = None, interleaved_mode: bool = False, detokenizing_blocks: int = 8)

Bases: AbstractDriver

Drives the prefill and decode engines (vEngine instances) that serve submitted requests.

compile()

Compiles the prefill and decode engines.

property driver_name
get_total_concurrent_requests() → int

Gets the total number of concurrent requests the driver can handle.

place_request_on_prefill_queue(request: ActiveRequest)

Places a new request on the prefill queue for prefilling and subsequent generation.

property processor: Any

Returns the processor/tokenizer associated with the engines.

Assumes all engines (prefill and decode) use the same processor. Raises an error if no engines are configured.

start()
stop()

Stops the driver and all background threads.

submit_request(request: Any)

Submits a new request to the driver’s processing queue.

class easydel.inference.__init__.vEngine(model: Any, processor: Any, max_concurrent_decodes: int | None = None, max_concurrent_prefill: int | None = None, prefill_lengths: int | None = None, max_prefill_length: int | None = None, max_length: int | None = None, batch_size: int | None = None, seed: int = 894)

Bases: AbstractInferenceEngine

Core inference engine for EasyDeL models using NNX graphs.

This engine manages the model state (split into graph definition, state, and other parameters) and provides JIT-compiled functions for the prefill and decode steps of autoregressive generation. It handles KV caching and sampling.

property batch_size: int | None

Returns the configured batch size for the engine, if specified.

bulk_insert(prefix: GenerationState, decode_state: GenerationState, slots: list[int]) → GenerationState

Efficiently inserts multiple prefill results into the decode state.

This function takes a GenerationState (prefix) typically resulting from a batch prefill operation and inserts its relevant components (logits, cache, index, tokens, valids, position IDs, generated tokens) into the main decode_state at multiple specified slots. This is useful for initializing the decode state after processing a batch of prompts simultaneously. Both input states’ caches are donated.

Parameters
  • prefix – The GenerationState containing the results from a prefill operation (or similar initialization). Its cache is marked for donation.

  • decode_state – The target GenerationState (e.g., the main decode loop state) to be updated. Its cache is marked for donation.

  • slots – A list of integer indices indicating the slots within the decode_state’s batch dimension where the corresponding data from the prefix state should be inserted.

Returns

An updated GenerationState (decode_state) with the prefill results inserted at the specified slots.
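Conceptually, bulk_insert is a scatter along the batch dimension: row i of the prefill batch lands in row slots[i] of the decode batch. A plain-Python sketch where each field is a list indexed by batch row; the helper name bulk_insert_rows and the field names are hypothetical. With JAX arrays the inner loop would typically become a single scatter such as x.at[jnp.array(slots)].set(rows), and the real method donates the caches rather than copying them.

```python
import copy


def bulk_insert_rows(prefix_rows, decode_rows, slots):
    """Copy row i of the prefill batch into row slots[i] of the decode batch.

    Both arguments map a field name (e.g. "tokens", "index") to a list whose first
    axis is the batch dimension. Hypothetical helper, not EasyDeL's implementation.
    """
    if len(slots) != len(next(iter(prefix_rows.values()))):
        raise ValueError("need exactly one slot per prefill row")
    updated = copy.deepcopy(decode_rows)  # leave the caller's decode state unchanged
    for field_name, rows in prefix_rows.items():
        for row, slot in zip(rows, slots):
            updated[field_name][slot] = row
    return updated


decode = {"tokens": [0, 0, 0, 0], "index": [0, 0, 0, 0]}
prefill = {"tokens": [101, 202], "index": [7, 3]}
print(bulk_insert_rows(prefill, decode, slots=[2, 0]))
# {'tokens': [202, 0, 101, 0], 'index': [3, 0, 7, 0]}
```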

property colocated_cpus: Optional[list[jaxlib.xla_extension.Device]]

Returns CPU devices colocated with the engine’s accelerator devices.

This information can be useful for optimizing data transfers between host (CPU) and accelerator (GPU/TPU) memory. Currently returns None as the implementation is pending.

Returns

A list of colocated JAX CPU devices, or None if not implemented or available.

decode(graphstate: State[Key, VariableState[Any]], graphothers: State[Key, VariableState[Any]], state: GenerationState, rngs: PRNGKey) → tuple[easydel.inference.vsurge.engines.vengine.utilities.GenerationState, easydel.inference.vsurge.engines._utils.ResultTokens]

Performs a single decode step in the autoregressive generation loop.

Takes the previous GenerationState, generates the next token using the model and KV cache, and updates the state. This function is JIT-compiled and allows the input state’s cache to be modified in-place (donated).

Parameters
  • graphstate – The NNX GraphState (parameters) of the model.

  • graphothers – Other NNX state variables of the model.

  • state – The current GenerationState from the previous step. state.cache is marked for donation.

  • rngs – A JAX PRNG key for sampling the next token.

Returns

A tuple containing:
  • next_generation_state: The updated GenerationState for the next iteration.

  • result: A ResultTokens object containing the newly generated token.

Return type

tuple[GenerationState, ResultTokens]

property eos_token_ids: list[int]

A list of End-of-Sequence token IDs.

free_resource(slot: int) → bool

Frees resources associated with a specific inference slot. (Not Implemented)

Parameters

slot – The index of the slot to free.

Returns

Always returns False as it’s not implemented.

get_prefix_destination_sharding() → Any

Returns the shardings necessary to transfer KV cache data between engines.

Currently returns None, indicating default or no specific sharding.

get_state_shardings(is_decode: bool = False)
init_decode_state(*args, **kwargs) → GenerationState

Initializes the GenerationState for a new sequence.

insert(prefix: GenerationState, decode_state: GenerationState, slot: int) → GenerationState

Inserts or updates a generation state for a specific slot. (JIT-compiled)

This function is intended to merge a prefix state (e.g., from a completed prefill) into the larger batch state (decode_state) at the given slot. The current implementation primarily broadcasts the prefix cache and returns the prefix state, so how decode_state and slot are used is not fully specified. Both input states’ caches are donated.

Parameters
  • prefix – The GenerationState to potentially insert (e.g., from prefill). Its cache is marked for donation.

  • decode_state – The target GenerationState to update (e.g., the main decode loop state). Its cache is marked for donation.

  • slot – The index within the batch where the insertion/update should occur.

Returns

An updated GenerationState. In the current implementation this is the prefix state, with its cache possibly broadcast; the intended merging with decode_state and slot still needs clarification.

property max_concurrent_decodes: int

Maximum number of sequences that can be decoded concurrently.

This determines the batch size used during the decode phase.

property max_length: int

Maximum total sequence length (prompt + generation).

This defines the size of the KV cache allocated.

property max_prefill_length: int

Maximum allowed length for the initial prompt (prefill phase).

Prompts longer than this will be truncated or handled according to the padding/truncation logic.

property pad_token_id

The ID of the padding token.

prefill(graphstate: State[Key, VariableState[Any]], graphothers: State[Key, VariableState[Any]], tokens: Array, valids: Array, true_length: int, temperature: Array, top_p: Array, rngs: PRNGKey) → tuple[easydel.inference.vsurge.engines.vengine.utilities.GenerationState, easydel.inference.vsurge.engines._utils.ResultTokens]

Performs the prefill step for initializing the generation process.

Processes the initial prompt tokens, initializes the KV cache, and generates the first token of the sequence. This function is JIT-compiled.

Parameters
  • graphstate – The NNX GraphState (parameters) of the model.

  • graphothers – Other NNX state variables of the model.

  • tokens – The input prompt token IDs (batch_size, sequence_length).

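  • valids – A boolean array indicating valid token positions in the input, of shape (batch_size, sequence_length) or (batch_size, max_length).

  • true_length – The true length of the sequences.

  • temperature – The temperature for sampling.

  • top_p – The top-p value for sampling.

  • rngs – A JAX PRNG key for sampling the first token.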

Returns

A tuple containing:
  • generation_state: The initial GenerationState for the decode loop.

  • result: A ResultTokens object containing the first generated token.

Return type

tuple[GenerationState, ResultTokens]

property prefill_lengths: list[int]

Returns the configured list of max prefill length buckets for the engine.

property prng_key: PRNGKey

Provides a new PRNG key split from the internal state for sampling.

Each call to this property consumes the current key and returns a new, unique key, ensuring that subsequent sampling operations use different randomness.

Returns

A new JAX PRNGKey.

property samples_per_slot: int

Number of samples generated per inference slot.

This determines how many independent generation results are produced for each logical slot managed by the engine. It’s often 1, but could be higher for techniques like parallel sampling.

class easydel.inference.__init__.vInference(model: None, processor_class: None, generation_config: Optional[vInferenceConfig] = None, seed: Optional[int] = None, input_partition_spec: Optional[PartitionSpec] = None, max_new_tokens: int = 512, inference_name: Optional[str] = None)

Bases: object

Class for performing text generation using a pre-trained language model in EasyDeL.

This class handles the generation process, including initialization, precompilation, and generating text in streaming chunks.

property SEQUENCE_DIM_MAPPING
adjust_kwargs(input_ids: Array, attention_mask: Optional[Array] = None, **model_kwargs)
count_tokens(messages: List[Dict[str, str]], oai_like: bool = False)
count_tokens(text: str, oai_like: bool = False)
execute_decode(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) → SampleState
execute_prefill(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) → SampleState

Executes a single generation step with performance monitoring.

generate(input_ids: Array, attention_mask: Optional[Array] = None, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, sampling_params: Optional[SamplingParams] = None, **model_kwargs) → Generator[Union[SampleState, Any], SampleState, SampleState]

Generates text in streaming chunks with comprehensive input adjustment.

Parameters
  • input_ids – Input token IDs as a JAX array.

  • attention_mask – Optional attention mask for the input.

  • graphstate (nn.GraphState, optional) – Model state to use for this generation, e.g., to generate with updated parameters.

  • graphother (nn.GraphState, optional) – Other model state to use for this generation.

  • sampling_params (SamplingParams, optional) – Sampling parameters for this generation call.

  • **model_kwargs – Additional model-specific keyword arguments.

Returns

Generator yielding SampleState objects containing generation results and metrics.

property inference_name
classmethod load_inference(path: Union[PathLike, str], model: None, processor_class: None)
property metrics
property model
property model_prefill_length: int

Calculates the maximum length available for input prefill by subtracting the maximum number of new tokens from the model’s maximum sequence length.

Returns

The maximum length available for input prefill

Return type

int

Raises

ValueError – If no maximum sequence length configuration is found
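The calculation above is a single subtraction. A sketch, where the helper name and its max_sequence_length argument are hypothetical stand-ins for whatever config field the model exposes; 512 mirrors the max_new_tokens default in the vInference signature.

```python
from typing import Optional


def model_prefill_length(max_sequence_length: Optional[int], max_new_tokens: int) -> int:
    """Tokens left for the prompt once `max_new_tokens` are reserved for generation."""
    if max_sequence_length is None:
        raise ValueError("no maximum sequence length configuration found")
    return max_sequence_length - max_new_tokens


print(model_prefill_length(4096, 512))  # 3584
```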

precompile(config: vInferencePreCompileConfig)

Precompiles the generation functions for a given batch size and input length.

This function checks if the generation functions have already been compiled for the given configuration. If not, it compiles them asynchronously and stores them in a cache.

Returns

True if precompilation was successful, False otherwise.

Return type

bool
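The check-then-compile behavior of precompile is a memoization keyed by the compile configuration. A synchronous sketch (the real method is described as compiling asynchronously); CompileConfig and PrecompileCache are hypothetical stand-ins, and the frozen dataclass is used only so the config is hashable and can serve as a dictionary key.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CompileConfig:
    """Hypothetical stand-in for vInferencePreCompileConfig (hashable cache key)."""
    batch_size: int
    prefill_length: int


class PrecompileCache:
    """Compile once per configuration and reuse the result afterwards."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._cache = {}

    def precompile(self, config: CompileConfig) -> bool:
        if config in self._cache:
            return True  # already compiled for this configuration
        try:
            self._cache[config] = self._compile_fn(config)
        except Exception:
            return False  # report failure instead of raising
        return True

    def get(self, config: CompileConfig):
        return self._cache[config]


calls = []
cache = PrecompileCache(lambda cfg: calls.append(cfg) or f"compiled{cfg.batch_size}x{cfg.prefill_length}")
cfg = CompileConfig(batch_size=1, prefill_length=128)
print(cache.precompile(cfg), cache.precompile(cfg), len(calls))  # True True 1
```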

process_prompt(prompt: Union[str, List[str], List[Dict[str, str]]], sampling_params: Union[SamplingParams, Dict], stream: bool = False) → Union[PromptOutput, List[PromptOutput], Generator[str, None, SampleState], List[Generator[str, None, SampleState]]]

Processes a prompt (string, list of strings, or OpenAI messages) and generates a response.

Parameters
  • prompt – The input prompt. Can be a single string, a list of strings (processed sequentially), or a list of dictionaries representing OpenAI-style chat messages (processed as a single batch).

  • sampling_params – Configuration for the generation process (temperature, top_p, etc.). Can be a SamplingParams object or a dictionary.

  • stream – If True, yields generated tokens incrementally. If False, returns the complete generation(s) at the end.

Returns

  • If input is str or List[Dict] and stream=False: A PromptOutput object containing the full text and metrics.

  • If input is str or List[Dict] and stream=True: A generator yielding string chunks. The generator’s return value (available as StopIteration.value, e.g., in a try/except StopIteration) is the final SampleState.

  • If input is List[str] and stream=False: A list of PromptOutput objects.

  • If input is List[str] and stream=True: A list where each element is a generator as described above for the single stream case.

Return type

Union[PromptOutput, List[PromptOutput], Generator[str, None, SampleState], List[Generator[str, None, SampleState]]]

Raises
  • TypeError – If the prompt format is invalid or processor does not support chat templates.

  • ValueError – If tokenization or processing fails.
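In the streaming case the final SampleState travels as the generator's return value, which Python exposes as StopIteration.value. A sketch of collecting both the chunks and that value; fake_stream is a hypothetical stand-in for process_prompt(..., stream=True), and a plain dict stands in for the SampleState.

```python
def fake_stream(chunks):
    """Hypothetical stand-in for process_prompt(..., stream=True)."""
    for chunk in chunks:
        yield chunk
    return {"generated_tokens": len(chunks)}  # stands in for the final SampleState


def consume(stream):
    """Collect all yielded text and the generator's return value."""
    pieces = []
    while True:
        try:
            pieces.append(next(stream))
        except StopIteration as stop:
            return "".join(pieces), stop.value


text, final_state = consume(fake_stream(["Hel", "lo", "!"]))
print(text, final_state)  # Hello! {'generated_tokens': 3}
```

Inside another generator, final_state = yield from stream re-yields every chunk and captures the same return value.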

async process_prompts_concurrently(prompts: List[Union[str, List[Dict[str, str]]]], max_concurrent_requests: int, sampling_params: Optional[SamplingParams] = None, stream: bool = False, progress_callback: Optional[Callable[[int, int], None]] = None) → Union[List[PromptOutput], AsyncGenerator[Tuple[int, str, Any], None]]

Processes a list of prompts concurrently, supporting both streaming and non-streaming modes.

Parameters
  • prompts – A list of prompts (strings or OpenAI-style message lists).

  • max_concurrent_requests – The maximum number of prompts to process in parallel. If <= 0, processing will be sequential (but still async).

  • sampling_params – Optional sampling parameters to override the default ones for this batch of requests. Passed to process_prompt.

  • stream – If True, returns an async generator yielding tuples of (index, type, data). If False, returns a list of PromptOutput objects.

  • progress_callback – An optional function called after each prompt finishes processing. Receives (completed_count, total_count).

Returns

  • If stream=False: A list of PromptOutput objects containing the full text, metrics, or error information for each prompt, in the original order.

  • If stream=True: An async generator yielding tuples (index, type, data) where:
    • type is ‘text’ and data is the string chunk.

    • type is ‘error’ and data is the error string.

    • type is ‘final’ and data is the final PromptOutput object with metrics.

    The generator finishes when all prompts are processed.

Return type

Union[List[PromptOutput], AsyncGenerator[Tuple[int, str, Any], None]]

Raises
  • ValueError – If input arguments are invalid.

  • RuntimeError – If an unexpected error occurs during processing.
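The non-streaming mode combines two standard asyncio tools: a Semaphore caps how many prompts run at once, and asyncio.gather returns results in argument order regardless of completion order. A sketch under those assumptions; process_all and fake_worker are hypothetical, per-prompt error capture is omitted, and a limit of 1 models the documented sequential fallback for max_concurrent_requests <= 0.

```python
import asyncio


async def process_all(prompts, worker, max_concurrent_requests, progress_callback=None):
    """Run `worker(prompt)` over `prompts` with bounded concurrency, preserving order."""
    total, done = len(prompts), 0
    limit = max_concurrent_requests if max_concurrent_requests > 0 else 1  # <= 0: one at a time
    semaphore = asyncio.Semaphore(limit)

    async def run(prompt):
        nonlocal done
        async with semaphore:
            result = await worker(prompt)
        done += 1  # safe without a lock: all tasks share one event-loop thread
        if progress_callback is not None:
            progress_callback(done, total)
        return result

    # gather returns results in the order of its arguments, not completion order.
    return await asyncio.gather(*(run(p) for p in prompts))


async def fake_worker(prompt):
    await asyncio.sleep(0.01 * len(prompt))  # longer prompts finish later
    return prompt.upper()


progress = []
results = asyncio.run(process_all(["ccc", "a", "bb"], fake_worker, 2,
                                  lambda d, t: progress.append((d, t))))
print(results)   # ['CCC', 'A', 'BB']
print(progress)  # [(1, 3), (2, 3), (3, 3)]
```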

save_inference(path: Union[PathLike, str])
property tokenizer
class easydel.inference.__init__.vInferenceApiServer(inference_map: Union[Dict[str, Any], Any] = None, inference_init_call: Optional[Callable[[], Any]] = None, max_workers: int = 10, allow_parallel_workload: bool = False, oai_like_processor: bool = True)

Bases: object

FastAPI server for serving vInference instances.

This server provides endpoints mimicking the OpenAI API structure for chat completions, liveness/readiness checks, token counting, and listing available models. It handles both streaming and non-streaming requests asynchronously using a thread pool.

async available_inference()

Lists available models (GET /v1/models).

async chat_completions(request: ChatCompletionRequest)

Handles chat completion requests (POST /v1/chat/completions).

Validates the request, retrieves the appropriate vInference model, tokenizes the input, and delegates to streaming or non-streaming handlers.

Parameters

request (ChatCompletionRequest) – The incoming request data.

Returns

The generated response, either a complete JSON object or a streaming event-stream.

Return type

Union[JSONResponse, StreamingResponse]
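Because the server mimics the OpenAI API, a client request would plausibly use the OpenAI chat-completions body. The sketch below only builds the request with the standard-library urllib; the field names follow the OpenAI schema and it is an assumption that ChatCompletionRequest accepts each of them. The model name "my-model" is hypothetical (GET /v1/models lists the real ones), and 11556 is the default port of fire.

```python
import json
import urllib.request


def build_chat_request(model, messages, stream=False, host="localhost", port=11556, **sampling):
    """Build (but do not send) an OpenAI-style POST /v1/chat/completions request."""
    body = {"model": model, "messages": messages, "stream": stream, **sampling}
    return urllib.request.Request(
        f"http://{host}:{port}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_chat_request(
    "my-model",  # hypothetical model name
    [{"role": "user", "content": "Hello!"}],
    max_tokens=64,
    temperature=0.7,
)
print(req.get_method(), req.full_url)  # POST http://localhost:11556/v1/chat/completions
# urllib.request.urlopen(req) would send it to a running server.
```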

async completions(request: CompletionRequest)

Handles completion requests (POST /v1/completions).

Processes the prompt for completion and returns generated text.

Parameters

request (CompletionRequest) – The incoming request data.

Returns

The generated response.

Return type

Union[JSONResponse, StreamingResponse]

async count_tokens(request: CountTokenRequest)

Token counting endpoint (POST /v1/count_tokens).

fire(host='0.0.0.0', port=11556, metrics_port: Optional[int] = None, log_level='info', ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None)

Starts the uvicorn server to run the FastAPI application.

Parameters
  • host (str) – The host address to bind to. Defaults to “0.0.0.0”.

  • port (int) – The port to listen on. Defaults to 11556.

  • metrics_port (tp.Optional[int]) – The port for the Prometheus metrics server. If None, defaults to port + 1. Set to -1 to disable.

  • log_level (str) – The logging level for uvicorn. Defaults to “info”.

  • ssl_keyfile (tp.Optional[str]) – Path to the SSL key file for HTTPS.

  • ssl_certfile (tp.Optional[str]) – Path to the SSL certificate file for HTTPS.
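The metrics_port rules above reduce to a small resolution function. A sketch with a hypothetical helper name, where a None result means the Prometheus metrics server is disabled:

```python
from typing import Optional


def resolve_metrics_port(port: int, metrics_port: Optional[int] = None) -> Optional[int]:
    """Apply the documented metrics_port rules; None in the result means disabled."""
    if metrics_port is None:
        return port + 1  # default: the port right after the API port
    if metrics_port == -1:
        return None      # -1 disables the Prometheus metrics server
    return metrics_port


print(resolve_metrics_port(11556))        # 11557
print(resolve_metrics_port(11556, -1))    # None
print(resolve_metrics_port(11556, 9100))  # 9100
```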

async liveness()

Liveness check endpoint (GET /liveness).

async readiness()

Readiness check endpoint (GET /readiness).

class easydel.inference.__init__.vInferenceConfig(max_new_tokens: int = 64, streaming_chunks: int = 16, num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[PartitionAxis] = None, _loop_rows: Optional[int] = None, sampling_params: Optional[SamplingParams] = None)

Bases: object

Configuration class for the vInference engine, controlling the overall generation process.

This class holds parameters that define how the generation loop behaves, including length constraints, token control, sharding strategies, and sampling settings.

max_new_tokens

The maximum number of new tokens to generate, excluding the initial prompt tokens. Defaults to 64.

Type

int

streaming_chunks

The number of generation steps to compile and execute together as a single unit. Larger chunks can improve performance on TPUs by reducing compilation overhead and kernel launch times, but may increase memory usage. Defaults to 16.

Type

int

num_return_sequences

The number of sequences to generate and return. Can be:
  • An integer: generate this many sequences for all inputs.

  • A dictionary mapping a precompile hash (from vInferencePreCompileConfig) to an integer: generate a specific number of sequences based on the compilation configuration.

Defaults to 1.

Type

Optional[Union[int, Dict[int, int]]]

pad_token_id

The token ID used for padding sequences. If None, the model’s default pad token ID might be used, or padding might not be applied.

Type

Optional[int]

bos_token_id

The token ID representing the beginning-of-sequence. May be used implicitly by the model or generation logic.

Type

Optional[int]

eos_token_id

The token ID(s) representing the end-of-sequence. Generation stops for a sequence when one of these tokens is sampled. Can be a single integer or a list/tuple of integers.

Type

Optional[Union[int, List[int]]]

partition_rules

A tuple of custom sharding rules (regex pattern, PartitionSpec) to apply to the model’s parameters and intermediate states (like the attention cache). If None, default rules based on partition_axis are generated. Example: ((".*kernel.*", PartitionSpec("fsdp", None)), …)

Type

Optional[Tuple[Tuple[str, Any]]]

partition_axis#

A PartitionAxis object defining the logical names of the sharding axes (e.g., ‘batch’, ‘sequence’, ‘head’). Required if partition_rules is None, since it is used to generate the default sharding rules.

Type

Optional[eformer.escale.partition.manager.PartitionAxis]

_loop_rows#

(Internal) The calculated number of iterations needed in the generation loop based on max_new_tokens and streaming_chunks. Automatically computed in __post_init__.

Type

Optional[int]

sampling_params#

A SamplingParams object containing parameters for the sampling process itself (e.g., temperature, top_k, top_p, repetition penalty). If None, a default SamplingParams instance with max_tokens set to max_new_tokens is created in __post_init__.

Type

Optional[easydel.inference.utilities.SamplingParams]
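The two values derived in __post_init__ (_loop_rows and the default sampling_params) can be sketched with stand-in dataclasses. Both classes below are hypothetical reductions of vInferenceConfig and SamplingParams, and the ceiling division used for _loop_rows is an assumption: the source only says the value is computed from max_new_tokens and streaming_chunks.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParamsSketch:
    # Stand-in for SamplingParams with only the field this sketch needs.
    max_tokens: int = 16


@dataclass
class InferenceConfigSketch:
    # Hypothetical stand-in for a subset of vInferenceConfig's fields.
    max_new_tokens: int = 64
    streaming_chunks: int = 16
    sampling_params: Optional[SamplingParamsSketch] = None
    _loop_rows: Optional[int] = None

    def __post_init__(self) -> None:
        # Assumption: ceiling division, so a partial final chunk still gets its own iteration.
        self._loop_rows = (self.max_new_tokens + self.streaming_chunks - 1) // self.streaming_chunks
        if self.sampling_params is None:
            # Documented behavior: the default SamplingParams takes max_tokens from max_new_tokens.
            self.sampling_params = SamplingParamsSketch(max_tokens=self.max_new_tokens)


cfg = InferenceConfigSketch(max_new_tokens=100, streaming_chunks=16)
print(cfg._loop_rows, cfg.sampling_params.max_tokens)
```

With 100 new tokens and 16-step chunks, the loop would need 7 iterations under this assumption, the last one only partially used.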

bos_token_id: Optional[int] = None#
eos_token_id: Optional[Union[int, List[int]]] = None#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

get_partition_rules(runtime_config: Optional[vInferencePreCompileConfig] = None) Tuple[Tuple[str, Any], ...][source]#

Generates or retrieves the sharding partition rules for the vInference engine.

If self.partition_rules is already set (custom rules provided), it returns them directly.

Otherwise, it constructs a default set of partition rules based on the axis names defined in self.partition_axis. These default rules aim to provide sensible sharding for common model components:

  • Input sequences (sequences, running_token) are sharded along the batch and sequence axes.

  • Attention masks and position IDs are sharded similarly.

  • Past key-value states (the attention cache), including common quantized formats (8-bit, NF4), are sharded across the batch, key-sequence, head, and attention-dimension axes.

  • Any parameters/states not matching the specific rules are replicated by default (.*).

Parameters

runtime_config – An optional vInferencePreCompileConfig. Currently unused in the default rule generation but available for potential customization in subclasses or future versions.

Returns

A tuple of partition rules. Each rule is a tuple containing:

  • A regex pattern (string) matching parameter or state names.

  • A jax.sharding.PartitionSpec defining how the matched items should be sharded.

Return type

Tuple[Tuple[str, Any], ...]

Raises

AssertionError – If self.partition_rules is None and self.partition_axis is also None, as axis names are required to generate default rules.
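The rule format and the replicate-by-default catch-all can be illustrated with a small matcher. This is a sketch under assumptions: rules are tried in order and the first pattern found by re.search wins (the exact matching function EasyDeL applies is not stated here), the axis names are hypothetical, and plain tuples stand in for jax.sharding.PartitionSpec so the snippet runs without JAX.

```python
import re
from typing import Any, Optional, Sequence, Tuple

# Stand-ins for PartitionSpec objects; in real code these would be
# PartitionSpec("dp", "sp") and PartitionSpec() respectively.
BATCH_SEQ = ("dp", "sp")  # hypothetical batch and sequence axis names
REPLICATED = ()           # an empty spec means "replicate on every device"

rules: Tuple[Tuple[str, Any], ...] = (
    ("sequences", BATCH_SEQ),
    ("running_token", BATCH_SEQ),
    ("attention_mask", BATCH_SEQ),
    (".*", REPLICATED),  # catch-all: anything not matched above is replicated
)


def match_spec(name: str, rules: Sequence[Tuple[str, Any]]) -> Optional[Any]:
    """Return the spec of the first rule whose pattern matches `name`."""
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    return None


print(match_spec("sequences", rules), match_spec("cache/key", rules))
```

Because the catch-all comes last, rule order matters: placing ".*" first would replicate everything.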

max_new_tokens: int = 64#
num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1#
pad_token_id: Optional[int] = None#
partition_axis: Optional[PartitionAxis] = None#
partition_rules: Optional[Tuple[Tuple[str, Any]]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

sampling_params: Optional[SamplingParams] = None#
streaming_chunks: int = 16#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.__init__.vInferencePreCompileConfig(batch_size: Union[int, List[int]] = 1, prefill_length: Optional[Union[int, List[int]]] = None, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Optional[Union[int, List[int]]] = None, vision_channels: Optional[Union[int, List[int]]] = None, vision_height: Optional[Union[int, List[int]]] = None, vision_width: Optional[Union[int, List[int]]] = None, required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None)[source]#

Bases: object

Configuration class for pre-compiling vInference functions.

This class holds parameters that define the shape and properties of inputs expected by the vInference engine during pre-compilation. It allows specifying different configurations, potentially in lists, to compile for multiple scenarios.

batch_size#

Batch size or list of batch sizes for text generation.

Type

Union[int, List[int]]

prefill_length#

Prefill sequence length or list of lengths. If None, it might be inferred or not used depending on the context.

Type

Optional[Union[int, List[int]]]

vision_included#

Whether vision inputs are included in the model.

Type

Union[bool, List[bool]]

vision_batch_size#

Batch size for vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_channels#

Number of channels for vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_height#

Height of vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_width#

Width of vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

required_props#

Optional dictionary or list of dictionaries specifying required properties for advanced configuration (e.g., specific model arguments).

Type

Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]]

batch_size: Union[int, List[int]] = 1#
classmethod create_optimized_compo(batch_size: Union[int, List[int]] = 1, max_prefill_length: int = 2048, min_prefill_length: int = 64, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Optional[Union[int, List[int]]] = None, vision_channels: Optional[Union[int, List[int]]] = None, vision_height: Optional[Union[int, List[int]]] = None, vision_width: Optional[Union[int, List[int]]] = None, required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None)[source]#
extract() dict[source]#

Converts the configuration instance into a dictionary.

This method is useful for serialization or easily accessing all configuration values.

Returns

A dictionary representation of the vInferencePreCompileConfig instance.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

get_default_hash() int[source]#

Generates a unique integer hash representing the configuration.

This hash is calculated based on the string representation of all configuration attributes, ensuring that identical configurations produce the same hash. This is crucial for caching compiled functions based on their configuration.

Returns

An integer hash value representing the configuration.
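The idea of hashing the string form of every attribute can be sketched as follows. The dataclass is a hypothetical subset of vInferencePreCompileConfig, and this sketch deliberately swaps in hashlib.sha256: Python's built-in hash() of a str is randomized per interpreter process unless PYTHONHASHSEED is fixed, so a digest is the safer choice if a key must survive restarts. Whether EasyDeL's get_default_hash uses hash() or a digest is not stated here.

```python
import hashlib
from dataclasses import dataclass, fields
from typing import List, Optional, Union


@dataclass
class PreCompileSketch:
    # Hypothetical subset of vInferencePreCompileConfig's fields.
    batch_size: Union[int, List[int]] = 1
    prefill_length: Optional[Union[int, List[int]]] = None
    vision_included: Union[bool, List[bool]] = False

    def stable_hash(self) -> int:
        # Join the string form of every attribute, in declaration order.
        text = "-".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self))
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big")  # 64-bit integer cache key


compiled_cache = {PreCompileSketch(batch_size=4, prefill_length=512).stable_hash(): "compiled fn"}
print(PreCompileSketch(batch_size=4, prefill_length=512).stable_hash() in compiled_cache)
```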

get_standalones() List[vInferencePreCompileConfig][source]#

Generates a list of standalone configurations from a potentially multi-value config.

If any attribute in the current configuration is a list (indicating multiple scenarios), this method expands the configuration into multiple individual vInferencePreCompileConfig instances. Each resulting instance represents a single, specific compilation scenario.

If an attribute’s list is shorter than the longest list among all attributes, its last element is repeated to ensure all generated configurations have values for all attributes.

If the original configuration is already standalone (no list attributes), this method returns a list containing only the original instance.

Returns

A list of vInferencePreCompileConfig instances, each representing a single, standalone compilation scenario.
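The expansion rule above (list attributes are zipped, and shorter lists repeat their last element) can be sketched on a reduced stand-in class. PreCompileSketch and this free function are hypothetical; dataclasses.replace plays the role of the class's own replace(**kwargs).

```python
from dataclasses import dataclass, fields, replace
from typing import List, Optional, Union


@dataclass
class PreCompileSketch:
    # Hypothetical subset of vInferencePreCompileConfig's fields.
    batch_size: Union[int, List[int]] = 1
    prefill_length: Optional[Union[int, List[int]]] = None
    vision_included: Union[bool, List[bool]] = False


def get_standalones(cfg: PreCompileSketch) -> List[PreCompileSketch]:
    """Expand list-valued attributes into one config per scenario."""
    list_fields = {
        f.name: getattr(cfg, f.name) for f in fields(cfg) if isinstance(getattr(cfg, f.name), list)
    }
    if not list_fields:
        return [cfg]  # already standalone
    n = max(len(v) for v in list_fields.values())
    out = []
    for i in range(n):
        # Shorter lists repeat their last element (assumes every list is non-empty).
        values = {name: v[min(i, len(v) - 1)] for name, v in list_fields.items()}
        out.append(replace(cfg, **values))
    return out


for s in get_standalones(PreCompileSketch(batch_size=[1, 4, 8], prefill_length=[128, 512])):
    print(s)
```

Here prefill_length has only two entries, so the third scenario reuses 512.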

prefill_length: Optional[Union[int, List[int]]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

vision_batch_size: Optional[Union[int, List[int]]] = None#
vision_channels: Optional[Union[int, List[int]]] = None#
vision_height: Optional[Union[int, List[int]]] = None#
vision_included: Union[bool, List[bool]] = False#
vision_width: Optional[Union[int, List[int]]] = None#
class easydel.inference.__init__.vSurge(driver: Union[vDriver, oDriver], vsurge_name: str | None = None)[source]#

Bases: object

Orchestrates the interaction between client requests and the vDriver.

compile()[source]#
async complete(request: vSurgeRequest) tp.AsyncGenerator[tp.List[ReturnSample]][source]#

Initiates and streams the results of a text completion request.

Creates an ActiveRequest using the plain prompt from the vSurgeRequest, places it on the driver’s prefill queue, and then asynchronously iterates through the results provided by the ActiveRequest’s return_channel.

It handles both client-side and server-side tokenization scenarios, buffering and processing results appropriately before yielding them.

Parameters

request – The vSurgeRequest containing the prompt and generation parameters.

Yields

Processed generation results, similar to the decode method. The format depends on the tokenization mode.

Raises

RuntimeError – If the prefill queue is full when trying to place the request.

count_tokens(text_or_conversation: Union[str, list]) int[source]#

Counts the number of tokens in a given string or conversation list.

Uses the underlying driver’s processor to tokenize the input and returns the count of tokens.

Parameters

text_or_conversation – Either a single string or a list of message dictionaries (like OpenAI chat format).

Returns

The total number of tokens in the input.

Raises

ValueError – If the input type is invalid or tokenization fails.

classmethod create_odriver(model: Any, processor: Any, storage: Optional[PagedAttentionCache] = None, manager: Optional[HBMPageManager] = None, page_size: int = 128, hbm_utilization: float = 0.6, max_concurrent_prefill: int | None = None, max_concurrent_decodes: int | None = None, prefill_lengths: int | None = None, max_prefill_length: int | None = None, max_length: int | None = None, seed: int = 894, vsurge_name: str | None = None) vSurge[source]#
classmethod create_vdriver(model: Any, processor: Any, max_concurrent_decodes: int | None = None, prefill_lengths: int | None = None, max_prefill_length: int | None = None, max_length: int | None = None, seed: int = 894, vsurge_name: str | None = None) vSurge[source]#

Creates a new instance of vSurge with configured vDriver and vEngines.

This class method provides a convenient way to instantiate the vSurge by setting up the necessary prefill and decode engines with the provided model, processor, and configuration parameters.

Parameters
  • model – The EasyDeLBaseModule instance representing the model.

  • processor – The tokenizer/processor instance.

  • max_concurrent_decodes – Maximum number of concurrent decode requests the decode engine can handle.

  • prefill_lengths – A list of prefill lengths to compile for the prefill engine.

  • max_prefill_length – The maximum prefill length for the prefill engine.

  • max_length – The maximum total sequence length for both engines.

  • seed – The random seed for reproducibility.

  • vsurge_name – An optional name for the vSurge instance.

Returns

A new instance of vSurge.

property driver#

Provides access to the underlying vDriver instance.

async generate(prompts: tp.Union[str, tp.Sequence[str]], sampling_params: tp.Optional[tp.Union[SamplingParams, tp.Sequence[SamplingParams]]] = None, stream: bool = False) tp.Union[tp.List[ReturnSample], tp.AsyncGenerator[tp.List[ReturnSample]]][source]#

Generates text completions concurrently for the given prompts.

Parameters
  • prompts – A single prompt string or a list of prompt strings.

  • sampling_params – A single SamplingParams object or a list of SamplingParams objects. If None, default SamplingParams will be used. If a single SamplingParams object is provided with multiple prompts, it will be applied to all prompts. If a list is provided, it must have the same length as the prompts list.

  • stream – If True, yields results (List[ReturnSample]) from any request as they become available. The list corresponds to one generation step from one request. If False, waits for all requests to complete and returns a list containing one aggregated ReturnSample per prompt.

Returns

If stream is True: an async generator yielding lists of ReturnSample as steps complete across concurrent requests.

If stream is False: a list of aggregated ReturnSample objects, one for each input prompt, after all requests have finished.

Return type

tp.Union[tp.List[ReturnSample], tp.AsyncGenerator[tp.List[ReturnSample]]]

Raises
  • ValueError – If the lengths of prompts and sampling_params lists mismatch.

  • RuntimeError – If the underlying driver’s queue is full.
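The pairing rules for prompts and sampling_params can be sketched as a standalone helper. The function name and the two-field SamplingParamsSketch are hypothetical; only the rules themselves (None means defaults, one object is shared, a list must match in length) come from the description above.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union


@dataclass
class SamplingParamsSketch:
    # Stand-in for SamplingParams; only two fields are shown.
    max_tokens: int = 16
    temperature: float = 0.7


def pair_prompts_with_params(
    prompts: Union[str, Sequence[str]],
    sampling_params: Optional[Union[SamplingParamsSketch, Sequence[SamplingParamsSketch]]] = None,
) -> List[Tuple[str, SamplingParamsSketch]]:
    """Hypothetical helper applying the documented pairing rules for generate()."""
    prompt_list = [prompts] if isinstance(prompts, str) else list(prompts)
    if sampling_params is None:
        # No params given: every prompt gets its own default instance.
        params = [SamplingParamsSketch() for _ in prompt_list]
    elif isinstance(sampling_params, SamplingParamsSketch):
        # One object for many prompts: the same object is applied to all of them.
        params = [sampling_params] * len(prompt_list)
    else:
        params = list(sampling_params)
        if len(params) != len(prompt_list):
            raise ValueError(f"Got {len(prompt_list)} prompts but {len(params)} sampling params")
    return list(zip(prompt_list, params))


print(pair_prompts_with_params(["a", "b"], SamplingParamsSketch(max_tokens=8)))
```

With the real class, the corresponding call would be `await surge.generate(["a", "b"], sampling_params=SamplingParams(max_tokens=32))`, where `surge` is a started vSurge instance.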

process_client_side_tokenization_response(response: list[easydel.inference.vsurge.utils.ReturnSample])[source]#

Processes responses when tokenization is handled client-side.

In this case, the response items (ReturnSample) are typically yielded directly without further server-side processing like detokenization or buffering.

Parameters

response – A list of ReturnSample objects from a single generation step.

Returns

The input list of ReturnSample objects, unchanged.

process_server_side_tokenization_response(response: list[easydel.inference.vsurge.utils.ReturnSample], buffered_response_list: list[list[easydel.inference.vsurge.utils.ReturnSample]]) list[easydel.inference.vsurge.utils.ReturnSample][source]#

Processes responses when tokenization/detokenization is server-side.

Combines the text and token IDs from the current response and any buffered previous responses for each sample. It then uses the metrics (TPS, generated token count) from the latest response in the sequence for the final output.

Parameters
  • response – The list of ReturnSample objects from the current step.

  • buffered_response_list – A list containing lists of ReturnSample objects from previous steps that were buffered.

Returns

A list of tuples, where each tuple represents a completed sample and contains: (decoded_string, all_token_ids, latest_tps, latest_num_generated_tokens).
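The merge described above (concatenate text and token IDs across buffered steps, keep metrics from the newest step) can be sketched with a stand-in sample type. SampleSketch and its field names are hypothetical, not ReturnSample's actual attributes, and the sketch assumes every step carries the same number of samples in the same order.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SampleSketch:
    # Hypothetical stand-in for ReturnSample; field names are assumptions.
    text: str
    token_ids: List[int]
    tokens_per_second: float
    num_generated_tokens: int


def merge_buffered(
    response: List[SampleSketch], buffered: List[List[SampleSketch]]
) -> List[Tuple[str, List[int], float, int]]:
    steps = buffered + [response]  # oldest first, current step last
    merged = []
    for i, latest in enumerate(response):
        text = "".join(step[i].text for step in steps)
        ids = [t for step in steps for t in step[i].token_ids]
        # Metrics come from the latest step only.
        merged.append((text, ids, latest.tokens_per_second, latest.num_generated_tokens))
    return merged


print(merge_buffered([SampleSketch("lo", [2], 12.5, 2)], [[SampleSketch("Hel", [1], 10.0, 1)]]))
```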

property processor: Any#

Returns the processor/tokenizer associated with the underlying driver.

should_buffer_response(response: list[easydel.inference.vsurge.utils.ReturnSample]) bool[source]#

Determines if a response needs buffering for server-side detokenization.

Buffering is needed if any sample in the response ends with a byte token (e.g., “<0xAB>”), as this indicates an incomplete multi-byte character that requires subsequent tokens for proper decoding.

Parameters

response – A list of ReturnSample objects from a single generation step.

Returns

True if buffering is required, False otherwise.
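Tokens of the form "<0xAB>" are the byte-fallback pieces emitted by SentencePiece-style tokenizers. The check can be sketched with a regular expression. For simplicity, this hypothetical helper takes one decoded string per sample instead of ReturnSample objects, and the exact pattern EasyDeL uses is an assumption.

```python
import re
from typing import List

# Matches a trailing byte token such as "<0xE2>" (two hex digits).
BYTE_TOKEN = re.compile(r"<0x[0-9A-Fa-f]{2}>$")


def should_buffer(last_pieces: List[str]) -> bool:
    """True if any sample's latest piece ends with a byte token."""
    return any(BYTE_TOKEN.search(piece) for piece in last_pieces)


print(should_buffer(["hello", "<0xE2>"]))
```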

start()[source]#
stop()[source]#
property vsurge_name#
class easydel.inference.__init__.vSurgeApiServer(vsurge_map: Union[Dict[str, vSurge], vSurge] = None, max_workers: int = 10, oai_like_processor: bool = True)[source]#

Bases: object

FastAPI server for serving vSurge instances.

This server provides endpoints mimicking the OpenAI API structure for chat completions, liveness/readiness checks, token counting, and listing available models. It handles both streaming and non-streaming requests asynchronously using a thread pool.

async available_inference()[source]#

Lists available models (GET /v1/models).

async chat_completions(request: ChatCompletionRequest)[source]#

Handles chat completion requests (POST /v1/chat/completions).

Validates the request, retrieves the appropriate vSurge instance, tokenizes the input, and delegates to streaming or non-streaming handlers.

Parameters

request (ChatCompletionRequest) – The incoming request data.

Returns

The generated response, either a complete JSON object or a streaming event-stream.

Return type

Union[JSONResponse, StreamingResponse]
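Because the endpoints mimic the OpenAI API, a client request can be assembled with only the standard library. The sketch below builds, but does not send, a POST to the default port used by fire(). The model name is hypothetical (presumably it must match a key of vsurge_map), and the body fields follow the OpenAI chat format, which is an assumption about what ChatCompletionRequest accepts.

```python
import json
import urllib.request


def build_chat_request(model, messages, host="localhost", port=11556, stream=False, max_tokens=64):
    """Build (but do not send) a POST request for /v1/chat/completions."""
    body = {"model": model, "messages": messages, "stream": stream, "max_tokens": max_tokens}
    return urllib.request.Request(
        url=f"http://{host}:{port}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_chat_request("my-model", [{"role": "user", "content": "Hello!"}])
# urllib.request.urlopen(req) would send it to a running vSurgeApiServer.
print(req.full_url)
```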

async completions(request: CompletionRequest)[source]#

Handles completion requests (POST /v1/completions).

Processes the prompt for completion and returns generated text.

Parameters

request (CompletionRequest) – The incoming request data.

Returns

The generated response.

Return type

Union[JSONResponse, StreamingResponse]

async count_tokens(request: CountTokenRequest)[source]#

Token counting endpoint (POST /v1/count_tokens).

fire(host='0.0.0.0', port=11556, metrics_port: Optional[int] = None, log_level='info', ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None)[source]#

Starts the uvicorn server to run the FastAPI application.

Parameters
  • host (str) – The host address to bind to. Defaults to “0.0.0.0”.

  • port (int) – The port to listen on. Defaults to 11556.

  • metrics_port (tp.Optional[int]) – The port for the Prometheus metrics server. If None, defaults to port + 1. Set to -1 to disable.

  • log_level (str) – The logging level for uvicorn. Defaults to “info”.

  • ssl_keyfile (tp.Optional[str]) – Path to the SSL key file for HTTPS.

  • ssl_certfile (tp.Optional[str]) – Path to the SSL certificate file for HTTPS.
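The metrics_port rule (None means port + 1, -1 disables) is easy to get wrong when wiring up deployment scripts. Here it is as a hypothetical helper. Returning None to signal "disabled" is a choice made by this sketch, not part of the documented API.

```python
from typing import Optional


def resolve_metrics_port(port: int, metrics_port: Optional[int]) -> Optional[int]:
    """Hypothetical helper mirroring the documented metrics_port defaults."""
    if metrics_port is None:
        return port + 1  # documented default
    if metrics_port == -1:
        return None  # metrics server disabled
    return metrics_port


print(resolve_metrics_port(11556, None))
```

For example, `server.fire(port=8000, metrics_port=-1)` would start the API without the Prometheus metrics server.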

async liveness()[source]#

Liveness check endpoint (GET /liveness).

async readiness()[source]#

Readiness check endpoint (GET /readiness).

class easydel.inference.__init__.vSurgeRequest(prompt: str, max_tokens: int, top_p: float = 1.0, top_k: int = 0, min_p: float = 0.0, temperature: float = 0.7, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, metadata: easydel.inference.vsurge.vsurge.vSurgeMetadata | None = None, is_client_side_tokenization: bool = False)[source]#

Bases: object

Represents a request specifically for text completion.

frequency_penalty: float = 0.0#
classmethod from_sampling_params(prompt: str, sampling_params: SamplingParams)[source]#
is_client_side_tokenization: bool = False#
max_tokens: int#
metadata: easydel.inference.vsurge.vsurge.vSurgeMetadata | None = None#
min_p: float = 0.0#
presence_penalty: float = 0.0#
prompt: str#
repetition_penalty: float = 1.0#
temperature: float = 0.7#
top_k: int = 0#
top_p: float = 1.0#
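from_sampling_params has no description here. Judging by the shared field names, it most likely copies the sampling fields onto a request. The stand-in dataclasses below reproduce the documented fields and defaults of SamplingParams and vSurgeRequest, but the mapping itself is an assumption. In particular, suppress_tokens has no vSurgeRequest counterpart and is simply dropped in this sketch.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SamplingParamsSketch:
    # Mirrors the documented SamplingParams fields and defaults.
    max_tokens: int = 16
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 0.7
    top_p: float = 1.0
    top_k: int = 0
    min_p: float = 0.0
    suppress_tokens: List[int] = field(default_factory=list)


@dataclass
class RequestSketch:
    # Mirrors the sampling-related vSurgeRequest fields (metadata omitted).
    prompt: str
    max_tokens: int
    top_p: float = 1.0
    top_k: int = 0
    min_p: float = 0.0
    temperature: float = 0.7
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0

    @classmethod
    def from_sampling_params(cls, prompt: str, sp: SamplingParamsSketch) -> "RequestSketch":
        # Copy every field the two classes share; suppress_tokens is not forwarded.
        return cls(
            prompt=prompt,
            max_tokens=sp.max_tokens,
            top_p=sp.top_p,
            top_k=sp.top_k,
            min_p=sp.min_p,
            temperature=sp.temperature,
            presence_penalty=sp.presence_penalty,
            frequency_penalty=sp.frequency_penalty,
            repetition_penalty=sp.repetition_penalty,
        )


print(RequestSketch.from_sampling_params("Hi", SamplingParamsSketch(max_tokens=32, top_k=50)))
```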
class easydel.inference.__init__.vWhisperInference(model: Any, tokenizer: Any, processor: Any, inference_config: Optional[vWhisperInferenceConfig] = None, dtype: Union[str, type[Any], numpy.dtype, SupportsDType] = jax.numpy.float32)[source]#

Bases: object

Whisper inference pipeline for performing speech-to-text transcription or translation.

Parameters
  • model (WhisperForConditionalGeneration) – The fine-tuned Whisper model to use for inference.

  • tokenizer (WhisperTokenizer) – Tokenizer for Whisper.

  • processor (WhisperProcessor) – Processor for Whisper.

  • inference_config (vWhisperInferenceConfig, optional) – Inference configuration.

  • dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.

generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)[source]#

Transcribe or translate audio input.

Parameters
  • audio_input (tp.Union[str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]]) – Input audio. Can be a local file path, URL, bytes, numpy array, or a dictionary containing the array and sampling rate.

  • chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.

  • stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.

  • batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.

  • language (str, optional) – Language of the input audio. Defaults to the language in inference_config.

  • task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.

Returns

A dictionary containing the transcribed text (“text”) and optionally other information like timestamps or detected language.

Return type

dict
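The chunking parameters translate into sample counts once the sampling rate is fixed. Whisper's feature extractor expects 16 kHz audio. The helper below is a hypothetical sketch: the chunk_length_s / 6 default comes from the description above, while expanding a scalar stride to the same left and right value follows the Hugging Face ASR pipeline convention and is assumed, not confirmed, for vWhisperInference.

```python
from typing import List, Optional, Tuple, Union

SAMPLING_RATE = 16_000  # Whisper feature extractors expect 16 kHz audio


def chunk_geometry(
    chunk_length_s: float = 30.0,
    stride_length_s: Optional[Union[float, List[float]]] = None,
    sampling_rate: int = SAMPLING_RATE,
) -> Tuple[int, int, int]:
    """Return (chunk, left_stride, right_stride) in samples."""
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6  # documented default
    if isinstance(stride_length_s, (int, float)):
        stride_length_s = [stride_length_s, stride_length_s]  # assumption: same on both sides
    chunk = int(round(chunk_length_s * sampling_rate))
    left = int(round(stride_length_s[0] * sampling_rate))
    right = int(round(stride_length_s[1] * sampling_rate))
    return chunk, left, right


print(chunk_geometry())
```

With the defaults, each 30-second chunk spans 480,000 samples and overlaps its neighbours by 5 seconds (80,000 samples) on each side.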

class easydel.inference.__init__.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)[source]#

Bases: object

Configuration class for Whisper inference.

Parameters
  • batch_size (int, optional, defaults to 1) – Batch size used for inference.

  • max_length (int, optional) – Maximum sequence length for generation.

  • generation_config (transformers.GenerationConfig, optional) – Generation configuration object.

  • logits_processor (optional) – Not used.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.

  • task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).

  • language (str, optional) – Language of the input audio.

  • is_multilingual (bool, optional) – Whether the model is multilingual.

batch_size: Optional[int] = 1#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

generation_config: Optional[Any] = None#
is_multilingual = None#
language = None#
logits_processor = None#
max_length: Optional[int] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

return_timestamps = None#
task = None#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.