easydel.inference.vinference.__init__
- class easydel.inference.vinference.__init__.PromptOutput(*, text: Optional[str] = None, generated_tokens: Optional[int] = None, tokens_per_second: Optional[float] = None, error: Optional[str] = None, finish_reason: Optional[str] = None)
Bases: BaseModel
Structure for holding the output of a processed prompt.
- error: Optional[str]
- finish_reason: Optional[str]
- generated_tokens: Optional[int]
- model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- text: Optional[str]
- tokens_per_second: Optional[float]
- class easydel.inference.vinference.__init__.SampleState(current_length: Union[Array, NamedSharding], sequences: Union[Array, NamedSharding], running_token: Union[Array, NamedSharding], is_sequence_finished: Union[Array, NamedSharding], prng_key: Union[PRNGKey, NamedSharding], model_kwargs: Union[Dict[str, Array], NamedSharding], tokens_per_second: Optional[float] = -inf, generated_tokens: Optional[int] = 0, padded_length: Optional[int] = 0, _time_spent_computing: Optional[float] = 0.0, _compile_config: Optional[vInferencePreCompileConfig] = None)
Bases: object
Represents the state of the sampling process during token generation within the vInference engine.
This class encapsulates all necessary information to pause and resume the generation loop. It tracks the progress of generation, including the tokens generated so far, the current position, completion status, random number generator state, and any model-specific state (like attention caches).
- current_length
The current length of the generated sequences (number of tokens generated so far).
- Type
Union[jax.Array, jaxlib.xla_extension.NamedSharding]
- sequences
The tensor holding the generated token IDs for each sequence in the batch. Shape: (batch_size, max_sequence_length).
- Type
Union[jax.Array, jaxlib.xla_extension.NamedSharding]
- running_token
The most recently generated token for each sequence. Used as input for the next step. Shape: (batch_size, 1).
- Type
Union[jax.Array, jaxlib.xla_extension.NamedSharding]
- is_sequence_finished
A boolean tensor indicating whether each sequence in the batch has reached an end-of-sequence (EOS) token or the maximum generation length. Shape: (batch_size,).
- Type
Union[jax.Array, jaxlib.xla_extension.NamedSharding]
- prng_key
The JAX pseudo-random number generator key used for stochastic sampling.
- Type
Union[jax._src.random.PRNGKey, jaxlib.xla_extension.NamedSharding]
- model_kwargs
A dictionary containing any additional arguments required by the model for the next generation step (e.g., attention cache/past key-values). The structure depends on the specific model implementation.
- Type
Union[Dict[str, jax.Array], jaxlib.xla_extension.NamedSharding]
- tokens_per_second
Estimated generation speed in tokens per second. Defaults to -inf.
- Type
Optional[float]
- generated_tokens
The total count of tokens generated across all sequences in the current generation process up to this state. Defaults to 0.
- Type
Optional[int]
- padded_length
The target length to which sequences are padded. This might be different from max_sequence_length in some scenarios. Defaults to 0.
- Type
Optional[int]
- _time_spent_computing
Internal tracker for the cumulative computation time spent to reach this state. Defaults to 0.0.
- Type
Optional[float]
- _compile_config
The vInferencePreCompileConfig instance used for pre-compiling the functions associated with this generation state. Defaults to None.
- Type
Optional[vInferencePreCompileConfig]
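As an illustration of how these fields interact, the following pure-Python sketch advances a toy state by one decode step. `ToySampleState` and `decode_step` are hypothetical stand-ins, not EasyDeL APIs: plain lists replace sharded `jax.Array`s, the next tokens are passed in rather than sampled from a model, `time_spent_computing` mirrors `_time_spent_computing`, and the `tokens_per_second` formula (generated tokens divided by accumulated compute time) is an assumption.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToySampleState:
    """Hypothetical stand-in for SampleState using plain Python lists."""
    current_length: int                 # filled positions per row (next write index)
    sequences: List[List[int]]          # (batch_size, max_sequence_length)
    running_token: List[int]            # last token per sequence
    is_sequence_finished: List[bool]    # (batch_size,)
    generated_tokens: int = 0
    tokens_per_second: float = float("-inf")
    time_spent_computing: float = 0.0


def decode_step(state: ToySampleState, next_tokens: List[int], eos_token_id: int,
                pad_token_id: int, elapsed: float) -> ToySampleState:
    """Write one token per row and update the bookkeeping fields (sketch)."""
    pos = state.current_length
    sequences, running, finished = [], [], []
    for row, tok, done in zip(state.sequences, next_tokens, state.is_sequence_finished):
        tok = pad_token_id if done else tok  # finished rows only receive padding
        row = list(row)
        row[pos] = tok
        sequences.append(row)
        running.append(tok)
        # a row finishes on EOS or when it reaches the end of its buffer
        finished.append(done or tok == eos_token_id or pos + 1 >= len(row))
    # only rows that were still active produced a real token this step
    generated = state.generated_tokens + sum(not d for d in state.is_sequence_finished)
    total_time = state.time_spent_computing + elapsed
    return replace(
        state,
        current_length=pos + 1,
        sequences=sequences,
        running_token=running,
        is_sequence_finished=finished,
        generated_tokens=generated,
        time_spent_computing=total_time,
        tokens_per_second=generated / total_time if total_time > 0 else float("-inf"),
    )
```

Returning a new instance via `replace` mirrors the immutable, PyTree-style updates that `SampleState.replace(**kwargs)` provides.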
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- generated_tokens: Optional[int] = 0
- padded_length: Optional[int] = 0
- prng_key: Union[PRNGKey, NamedSharding]
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- tokens_per_second: Optional[float] = -inf
- class easydel.inference.vinference.__init__.vInference(model: None, processor_class: None, generation_config: Optional[vInferenceConfig] = None, seed: Optional[int] = None, input_partition_spec: Optional[PartitionSpec] = None, max_new_tokens: int = 512, inference_name: Optional[str] = None)
Bases: object
Class for performing text generation using a pre-trained language model in EasyDeL.
This class handles the generation process, including initialization, precompilation, and generating text in streaming chunks.
- property SEQUENCE_DIM_MAPPING
- count_tokens(messages: List[Dict[str, str]], oai_like: bool = False)
- count_tokens(text: str, oai_like: bool = False)
- execute_decode(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) → SampleState
- execute_prefill(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) → SampleState
Executes a single generation step with performance monitoring.
- generate(input_ids: Array, attention_mask: Optional[Array] = None, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, sampling_params: Optional[SamplingParams] = None, **model_kwargs) → Generator[Union[SampleState, Any], SampleState, SampleState]
Generates text in streaming chunks with comprehensive input adjustment.
- Parameters
input_ids – Input token IDs as a JAX array.
attention_mask – Optional attention mask for the input.
graphstate (nn.GraphState, optional) – Replacement model state to use for this generation, e.g. to generate with updated weights.
graphother (nn.GraphState, optional) – Replacement for the model's other (non-graphstate) state to use for this generation.
sampling_params (SamplingParams, optional) – Sampling parameters to use for this call.
**model_kwargs – Additional model-specific keyword arguments.
- Returns
Generator yielding SampleState objects containing generation results and metrics.
- property inference_name
- property metrics
- property model
- property model_prefill_length: int
Calculate the maximum length available for input prefill by subtracting the maximum new tokens from the model’s maximum sequence length.
- Returns
The maximum length available for input prefill
- Return type
int
- Raises
ValueError – If no maximum sequence length configuration is found
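A minimal sketch of this arithmetic, assuming the maximum sequence length has already been read from the model configuration (`max_sequence_length` is a hypothetical parameter; the real property looks the value up itself):

```python
from typing import Optional


def model_prefill_length(max_sequence_length: Optional[int], max_new_tokens: int) -> int:
    """Tokens left for the prompt after reserving max_new_tokens (hypothetical helper)."""
    if max_sequence_length is None:
        raise ValueError("No maximum sequence length configuration found.")
    return max_sequence_length - max_new_tokens


# e.g. a 4096-token context with vInference's default max_new_tokens=512
print(model_prefill_length(4096, 512))  # 3584
```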
- precompile(config: vInferencePreCompileConfig)
Precompiles the generation functions for a given batch size and input length.
This function checks if the generation functions have already been compiled for the given configuration. If not, it compiles them asynchronously and stores them in a cache.
- Parameters
config – The vInferencePreCompileConfig describing the input shapes to compile for.
- Returns
True if precompilation was successful, False otherwise.
- Return type
bool
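The compile-once-then-cache behaviour described above can be sketched as follows. `ToyCompileConfig` and `ToyPrecompiler` are hypothetical; the sketch compiles synchronously and keys its cache on Python's `hash()` of a frozen dataclass, whereas the engine compiles asynchronously and keys compiled functions on the configuration hash (see `vInferencePreCompileConfig.get_default_hash`).

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class ToyCompileConfig:
    """Hypothetical stand-in for a standalone vInferencePreCompileConfig."""
    batch_size: int
    prefill_length: int


class ToyPrecompiler:
    """Compile once per configuration and reuse the cached result (sketch)."""

    def __init__(self, compile_fn: Callable[[ToyCompileConfig], Any]):
        self._compile_fn = compile_fn
        self._cache: Dict[int, Any] = {}

    def precompile(self, config: ToyCompileConfig) -> bool:
        key = hash(config)  # frozen dataclasses hash by field values
        if key in self._cache:
            return True  # already compiled for this configuration
        try:
            self._cache[key] = self._compile_fn(config)
        except Exception:
            return False  # report failure instead of raising
        return True
```

Calling `precompile` twice with equal configurations invokes `compile_fn` only once.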
- process_prompt(prompt: Union[str, List[str], List[Dict[str, str]]], sampling_params: Union[SamplingParams, Dict], stream: bool = False) → Union[PromptOutput, List[PromptOutput], Generator[str, None, SampleState], List[Generator[str, None, SampleState]]]
Processes a prompt (string, list of strings, or OpenAI messages) and generates a response.
- Parameters
prompt – The input prompt. Can be a single string, a list of strings (processed sequentially), or a list of dictionaries representing OpenAI-style chat messages (processed as a single batch).
sampling_params – Configuration for the generation process (temperature, top_p, etc.). Can be a SamplingParams object or a dictionary.
stream – If True, yields generated tokens incrementally. If False, returns the complete generation(s) at the end.
- Returns
If input is str or List[Dict] and stream=False: A PromptOutput object containing the full text and metrics.
If input is str or List[Dict] and stream=True: A generator yielding string chunks. The generator's return value (available as the value attribute of the StopIteration raised when it is exhausted) is the final SampleState.
If input is List[str] and stream=False: A list of PromptOutput objects.
If input is List[str] and stream=True: A list where each element is a generator as described above for the single stream case.
- Return type
Union[PromptOutput, List[PromptOutput], Generator[str, None, SampleState], List[Generator[str, None, SampleState]]]
- Raises
TypeError – If the prompt format is invalid or the processor does not support chat templates.
ValueError – If tokenization or processing fails.
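Because a plain `for` loop discards a generator's return value, retrieving the final SampleState from a streaming call requires catching `StopIteration` (or delegating with `yield from`). The sketch below shows the pattern with a hypothetical `fake_stream` generator standing in for `process_prompt(..., stream=True)`; a dictionary plays the role of the final SampleState.

```python
from typing import Generator, List, Tuple


def fake_stream() -> Generator[str, None, dict]:
    """Hypothetical stand-in for process_prompt(..., stream=True)."""
    for chunk in ["Hello", ", ", "world"]:
        yield chunk
    return {"generated_tokens": 3}  # plays the role of the final SampleState


def drain(gen: Generator[str, None, dict]) -> Tuple[str, dict]:
    """Collect every text chunk and capture the generator's return value."""
    chunks: List[str] = []
    while True:
        try:
            chunks.append(next(gen))
        except StopIteration as stop:
            return "".join(chunks), stop.value


text, final_state = drain(fake_stream())
```

The same `drain` helper could be applied to a generator returned by a real streaming `process_prompt` call.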
- async process_prompts_concurrently(prompts: List[Union[str, List[Dict[str, str]]]], max_concurrent_requests: int, sampling_params: Optional[SamplingParams] = None, stream: bool = False, progress_callback: Optional[Callable[[int, int], None]] = None) → Union[List[PromptOutput], AsyncGenerator[Tuple[int, str, Any], None]]
Processes a list of prompts concurrently, supporting both streaming and non-streaming modes.
- Parameters
prompts – A list of prompts (strings or OpenAI-style message lists).
max_concurrent_requests – The maximum number of prompts to process in parallel. If <= 0, processing will be sequential (but still async).
sampling_params – Optional sampling parameters to override the default ones for this batch of requests. Passed to process_prompt.
stream – If True, returns an async generator yielding tuples of (index, type, data). If False, returns a list of PromptOutput objects.
progress_callback – An optional function called after each prompt finishes processing. Receives (completed_count, total_count).
- Returns
If stream=False: A list of PromptOutput objects containing the full text, metrics, or error information for each prompt, in the original order.
If stream=True: An async generator yielding tuples (index, type, data), where:
type is 'text' and data is a string chunk;
type is 'error' and data is the error string;
type is 'final' and data is the final PromptOutput object with metrics.
The generator finishes when all prompts have been processed.
- Return type
Union[List[PromptOutput], AsyncGenerator[Tuple[int, str, Any], None]]
- Raises
ValueError – If input arguments are invalid.
RuntimeError – If an unexpected error occurs during processing.
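One common way to enforce a `max_concurrent_requests` limit while keeping results in input order is an `asyncio.Semaphore` combined with `asyncio.gather`. The sketch below illustrates that technique for the non-streaming case only; `run_concurrently` and the `worker` coroutine are hypothetical, and EasyDeL's implementation (including its error capture into PromptOutput and its streaming mode) may differ.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional


async def run_concurrently(
    prompts: List[str],
    worker: Callable[[str], Awaitable[str]],
    max_concurrent_requests: int,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> List[str]:
    """Bounded-concurrency map that preserves input order (hypothetical helper)."""
    # <= 0 means sequential, i.e. at most one request in flight
    limit = max_concurrent_requests if max_concurrent_requests > 0 else 1
    semaphore = asyncio.Semaphore(limit)
    total, completed = len(prompts), 0

    async def run_one(prompt: str) -> str:
        nonlocal completed
        async with semaphore:  # wait for a free slot
            result = await worker(prompt)
        completed += 1
        if progress_callback is not None:
            progress_callback(completed, total)
        return result

    # gather returns results in the order the awaitables were passed in
    return await asyncio.gather(*(run_one(p) for p in prompts))
```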
- property tokenizer
- class easydel.inference.vinference.__init__.vInferenceApiServer(inference_map: Union[Dict[str, Any], Any] = None, inference_init_call: Optional[Callable[[], Any]] = None, max_workers: int = 10, allow_parallel_workload: bool = False, oai_like_processor: bool = True)
Bases: object
FastAPI server for serving vInference instances.
This server provides endpoints mimicking the OpenAI API structure for chat completions, liveness/readiness checks, token counting, and listing available models. It handles both streaming and non-streaming requests asynchronously using a thread pool.
- async chat_completions(request: ChatCompletionRequest)
Handles chat completion requests (POST /v1/chat/completions).
Validates the request, retrieves the appropriate vInference model, tokenizes the input, and delegates to streaming or non-streaming handlers.
- Parameters
request (ChatCompletionRequest) – The incoming request data.
- Returns
The generated response, either a complete JSON object or a streaming event stream.
- Return type
Union[JSONResponse, StreamingResponse]
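Since the endpoints mimic the OpenAI API, a client can use any HTTP library. The standard-library sketch below only builds the request. The body fields (`model`, `messages`, `stream`) follow the OpenAI chat-completions convention and are assumed to be accepted by ChatCompletionRequest, with `model` assumed to name one of the served vInference instances; `build_chat_request` is hypothetical, and the base URL uses the default port of `fire`.

```python
import json
import urllib.request
from typing import Dict, List


def build_chat_request(model: str, messages: List[Dict[str, str]], stream: bool = False,
                       base_url: str = "http://localhost:11556") -> urllib.request.Request:
    """Prepare (but do not send) a POST /v1/chat/completions request (hypothetical helper)."""
    body = {"model": model, "messages": messages, "stream": stream}
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_chat_request("my-model", [{"role": "user", "content": "Hi"}])
```

Passing `req` to `urllib.request.urlopen` would send it to a running server.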
- async completions(request: CompletionRequest)
Handles completion requests (POST /v1/completions).
Processes the prompt for completion and returns generated text.
- Parameters
request (CompletionRequest) – The incoming request data.
- Returns
The generated response.
- Return type
Union[JSONResponse, StreamingResponse]
- async count_tokens(request: CountTokenRequest)
Token counting endpoint (POST /v1/count_tokens).
- fire(host='0.0.0.0', port=11556, metrics_port: Optional[int] = None, log_level='info', ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None)
Starts the uvicorn server to run the FastAPI application.
- Parameters
host (str) – The host address to bind to. Defaults to "0.0.0.0".
port (int) – The port to listen on. Defaults to 11556.
metrics_port (Optional[int]) – The port for the Prometheus metrics server. If None, defaults to port + 1. Set to -1 to disable.
log_level (str) – The logging level for uvicorn. Defaults to "info".
ssl_keyfile (Optional[str]) – Path to the SSL key file for HTTPS.
ssl_certfile (Optional[str]) – Path to the SSL certificate file for HTTPS.
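The documented `metrics_port` defaults reduce to a small rule, sketched here as a hypothetical helper (`resolve_metrics_port` is not part of EasyDeL, and `fire` may structure this differently internally):

```python
from typing import Optional


def resolve_metrics_port(port: int = 11556, metrics_port: Optional[int] = None) -> Optional[int]:
    """None -> port + 1; -1 -> metrics disabled (returned as None); otherwise as given."""
    if metrics_port is None:
        return port + 1
    if metrics_port == -1:
        return None
    return metrics_port


print(resolve_metrics_port())  # 11557
```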
- class easydel.inference.vinference.__init__.vInferenceConfig(max_new_tokens: int = 64, streaming_chunks: int = 16, num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[PartitionAxis] = None, _loop_rows: Optional[int] = None, sampling_params: Optional[SamplingParams] = None)
Bases: object
Configuration class for the vInference engine, controlling the overall generation process.
This class holds parameters that define how the generation loop behaves, including length constraints, token control, sharding strategies, and sampling settings.
- max_new_tokens
The maximum number of new tokens to generate, excluding the initial prompt tokens. Defaults to 64.
- Type
int
- streaming_chunks
The number of generation steps to compile and execute together as a single unit. Larger chunks can improve performance on TPUs by reducing compilation overhead and kernel launch times, but may increase memory usage. Defaults to 16.
- Type
int
- num_return_sequences
The number of sequences to generate and return. Can be:
An integer: generate this many sequences for all inputs.
A dictionary mapping a precompile hash (from vInferencePreCompileConfig) to an integer: generate a specific number of sequences based on the compilation configuration.
Defaults to 1.
- Type
Optional[Union[int, Dict[int, int]]]
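The two accepted forms of num_return_sequences can be resolved with a helper like the one below. This is a sketch: `resolve_num_return_sequences` is hypothetical, and falling back to a default of 1 when the value is None or the precompile hash is missing from the dictionary is an assumption, not documented behaviour.

```python
from typing import Dict, Optional, Union


def resolve_num_return_sequences(value: Optional[Union[int, Dict[int, int]]],
                                 precompile_hash: int, default: int = 1) -> int:
    """int -> use as-is; dict -> look up by precompile hash; otherwise fall back (assumed)."""
    if value is None:
        return default
    if isinstance(value, int):
        return value
    return value.get(precompile_hash, default)
```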
- pad_token_id
The token ID used for padding sequences. If None, the model’s default pad token ID might be used, or padding might not be applied.
- Type
Optional[int]
- bos_token_id
The token ID representing the beginning-of-sequence. May be used implicitly by the model or generation logic.
- Type
Optional[int]
- eos_token_id
The token ID(s) representing the end-of-sequence. Generation stops for a sequence when one of these tokens is sampled. Can be a single integer or a list/tuple of integers.
- Type
Optional[Union[int, List[int]]]
- partition_rules
A tuple of custom sharding rules (regex pattern, PartitionSpec) to apply to the model’s parameters and intermediate states (like attention cache). If None, default rules based on partition_axis are generated. Example: ((".*kernel.*", PartitionSpec("fsdp", None)), ...)
- Type
Optional[Tuple[Tuple[str, Any]]]
- partition_axis
A PartitionAxis object defining the logical names for sharding axes (e.g., 'batch', 'sequence', 'head'). Required if partition_rules is None; used to generate default sharding rules.
- Type
Optional[PartitionAxis]
- _loop_rows
(Internal) The calculated number of iterations needed in the generation loop based on max_new_tokens and streaming_chunks. Automatically computed in __post_init__.
- Type
Optional[int]
- sampling_params
A SamplingParams object containing parameters for the sampling process itself (e.g., temperature, top_k, top_p, repetition penalty). If None, a default SamplingParams instance with max_tokens set to max_new_tokens is created in __post_init__.
- Type
Optional[SamplingParams]
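The relationship between max_new_tokens, streaming_chunks, and _loop_rows can be sketched as below, assuming _loop_rows is the ceiling of max_new_tokens / streaming_chunks (the exact formula is not stated here). Each outer iteration stands for one compiled chunk of decode steps, after which a streaming caller observes progress; `loop_rows` and `chunked_steps` are hypothetical helpers.

```python
from typing import Iterator, List


def loop_rows(max_new_tokens: int, streaming_chunks: int) -> int:
    """Outer iterations needed to cover max_new_tokens (assumed: ceiling division)."""
    return (max_new_tokens + streaming_chunks - 1) // streaming_chunks


def chunked_steps(max_new_tokens: int, streaming_chunks: int) -> Iterator[List[int]]:
    """Yield the decode-step indices of each chunk, one chunk per outer iteration."""
    step = 0
    for _ in range(loop_rows(max_new_tokens, streaming_chunks)):
        chunk: List[int] = []
        for _ in range(streaming_chunks):
            if step >= max_new_tokens:
                break  # do not run past the token budget
            chunk.append(step)  # a real engine would execute one decode step here
            step += 1
        yield chunk  # a streaming caller sees progress once per chunk
```

With the defaults (64 and 16) this gives 4 iterations; max_new_tokens=70 would need 5, the last covering 6 steps.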
- bos_token_id: Optional[int] = None
- eos_token_id: Optional[Union[int, List[int]]] = None
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- get_partition_rules(runtime_config: Optional[vInferencePreCompileConfig] = None) → Tuple[Tuple[str, Any], ...]
Generates or retrieves the sharding partition rules for the vInference engine.
If self.partition_rules is already set (custom rules provided), it returns them directly.
Otherwise, it constructs a default set of partition rules based on the axis names defined in self.partition_axis. These default rules aim to provide sensible sharding for common model components:
Input sequences (sequences, running_token) are sharded along the batch and sequence axes.
Attention masks and position IDs are sharded similarly.
Past key-value states (attention cache), including common quantized formats (8-bit, NF4), are sharded across the batch, key sequence, head, and attention dimension axes.
Any parameters or states not matching these rules are replicated by default (.*).
- Parameters
runtime_config – An optional vInferencePreCompileConfig. Currently unused in the default rule generation but available for potential customization in subclasses or future versions.
- Returns
A tuple of partition rules. Each rule is a tuple containing:
A regex pattern (string) matching parameter or state names.
A jax.sharding.PartitionSpec defining how the matched items should be sharded.
- Return type
Tuple[Tuple[str, Any], ...]
- Raises
AssertionError – If self.partition_rules is None and self.partition_axis is also None, as axis names are required to generate default rules.
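Rules of this (regex, PartitionSpec) shape are typically applied by matching each parameter or state name against the patterns in order, so the trailing `.*` catch-all only applies when nothing more specific matched. The sketch assumes first-match semantics; `match_partition_spec`, the axis names, and the patterns are hypothetical, and plain tuples stand in for `jax.sharding.PartitionSpec` (where an empty `PartitionSpec()` means fully replicated).

```python
import re
from typing import Any, Sequence, Tuple

Rule = Tuple[str, Any]

# Hypothetical rules; tuples of axis names stand in for jax.sharding.PartitionSpec.
RULES: Tuple[Rule, ...] = (
    (r"^(sequences|running_token)$", ("dp", "sp")),
    (r"cached_(key|value)$", ("dp", "sp", "tp", None)),
    (r".*", ()),  # catch-all: an empty spec means fully replicated
)


def match_partition_spec(name: str, rules: Sequence[Rule]) -> Any:
    """Return the spec of the first rule whose regex matches `name` (assumed semantics)."""
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"No partition rule matches {name!r}")
```

Ordering matters: placing `.*` first would replicate everything.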
- max_new_tokens: int = 64
- num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1
- pad_token_id: Optional[int] = None
- partition_axis: Optional[PartitionAxis] = None
- partition_rules: Optional[Tuple[Tuple[str, Any]]] = None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- sampling_params: Optional[SamplingParams] = None
- streaming_chunks: int = 16
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.inference.vinference.__init__.vInferencePreCompileConfig(batch_size: Union[int, List[int]] = 1, prefill_length: Optional[Union[int, List[int]]] = None, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Optional[Union[int, List[int]]] = None, vision_channels: Optional[Union[int, List[int]]] = None, vision_height: Optional[Union[int, List[int]]] = None, vision_width: Optional[Union[int, List[int]]] = None, required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None)
Bases: object
Configuration class for pre-compiling vInference functions.
This class holds parameters that define the shape and properties of inputs expected by the vInference engine during pre-compilation. It allows specifying different configurations, potentially in lists, to compile for multiple scenarios.
- batch_size
Batch size or list of batch sizes for text generation.
- Type
Union[int, List[int]]
- prefill_length
Prefill sequence length or list of lengths. If None, it might be inferred or not used depending on the context.
- Type
Optional[Union[int, List[int]]]
- vision_included
Whether vision inputs are included in the model.
- Type
Union[bool, List[bool]]
- vision_batch_size
Batch size for vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- vision_channels
Number of channels for vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- vision_height
Height of vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- vision_width
Width of vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- required_props
Optional dictionary or list of dictionaries specifying required properties for advanced configuration (e.g., specific model arguments).
- Type
Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]]
- batch_size: Union[int, List[int]] = 1
- classmethod create_optimized_compo(batch_size: Union[int, List[int]] = 1, max_prefill_length: int = 2048, min_prefill_length: int = 64, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Optional[Union[int, List[int]]] = None, vision_channels: Optional[Union[int, List[int]]] = None, vision_height: Optional[Union[int, List[int]]] = None, vision_width: Optional[Union[int, List[int]]] = None, required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None)
- extract() → dict
Converts the configuration instance into a dictionary.
This method is useful for serialization or easily accessing all configuration values.
- Returns
A dictionary representation of the vInferencePreCompileConfig instance.
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- get_default_hash() → int
Generates a unique integer hash representing the configuration.
This hash is calculated based on the string representation of all configuration attributes, ensuring that identical configurations produce the same hash. This is crucial for caching compiled functions based on their configuration.
- Returns
An integer hash value representing the configuration.
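A sketch of the hashing idea, using a hypothetical three-field `ToyPreCompileConfig`. Python's built-in `hash()` of a string is salted per process (unless `PYTHONHASHSEED` is fixed), so a hash built that way is only stable within one process; this sketch swaps in `hashlib.sha256` to obtain a value that is also stable across runs. Whether EasyDeL does the same is not stated here.

```python
import hashlib
from dataclasses import dataclass, fields
from typing import List, Optional, Union


@dataclass
class ToyPreCompileConfig:
    """Hypothetical three-field stand-in for vInferencePreCompileConfig."""
    batch_size: Union[int, List[int]] = 1
    prefill_length: Optional[Union[int, List[int]]] = None
    vision_included: Union[bool, List[bool]] = False

    def get_default_hash(self) -> int:
        """Hash the string representation of every attribute, in field order."""
        text = "-".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self))
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        return int(digest[:16], 16)  # stable across processes, unlike built-in hash()
```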
- get_standalones() → List[vInferencePreCompileConfig]
Generates a list of standalone configurations from a potentially multi-value config.
If any attribute in the current configuration is a list (indicating multiple scenarios), this method expands the configuration into multiple individual vInferencePreCompileConfig instances. Each resulting instance represents a single, specific compilation scenario.
If an attribute’s list is shorter than the longest list among all attributes, its last element is repeated to ensure all generated configurations have values for all attributes.
If the original configuration is already standalone (no list attributes), this method returns a list containing only the original instance.
- Returns
A list of vInferencePreCompileConfig instances, each representing a single, standalone compilation scenario.
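The expansion rule above (shorter lists repeat their last element) can be sketched as follows with a hypothetical three-field `ToyPreCompileConfig`; the real method covers all of the class's fields. Rejecting empty lists is an added assumption of this sketch.

```python
from dataclasses import dataclass, fields
from typing import List, Optional, Union


@dataclass
class ToyPreCompileConfig:
    """Hypothetical three-field stand-in for vInferencePreCompileConfig."""
    batch_size: Union[int, List[int]] = 1
    prefill_length: Optional[Union[int, List[int]]] = None
    vision_included: Union[bool, List[bool]] = False

    def get_standalones(self) -> List["ToyPreCompileConfig"]:
        values = {f.name: getattr(self, f.name) for f in fields(self)}
        lengths = [len(v) for v in values.values() if isinstance(v, list)]
        if not lengths:
            return [self]  # already standalone
        if min(lengths) == 0:
            raise ValueError("List-valued attributes must not be empty.")
        standalones = []
        for i in range(max(lengths)):
            picked = {
                # shorter lists repeat their last element
                name: v[min(i, len(v) - 1)] if isinstance(v, list) else v
                for name, v in values.items()
            }
            standalones.append(ToyPreCompileConfig(**picked))
        return standalones


configs = ToyPreCompileConfig(batch_size=[1, 4], prefill_length=[128, 256, 512]).get_standalones()
# (batch_size, prefill_length) pairs: (1, 128), (4, 256), (4, 512)
```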
- prefill_length: Optional[Union[int, List[int]]] = None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- vision_batch_size: Optional[Union[int, List[int]]] = None
- vision_channels: Optional[Union[int, List[int]]] = None
- vision_height: Optional[Union[int, List[int]]] = None
- vision_included: Union[bool, List[bool]] = False
- vision_width: Optional[Union[int, List[int]]] = None