easydel.inference.__init__


class easydel.inference.__init__.SamplingParams(max_tokens: int = 16, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 0.0, top_p: float = 1.0, top_k: int = 0, min_p: float = 0.0, suppress_tokens: list[int] = <factory>)[source]#

Bases: object

Parameters controlling the sampling process during text generation.

max_tokens#

The maximum number of tokens to generate (excluding the prompt). Defaults to 16.

Type

int

presence_penalty#

Penalty applied to the logits of tokens already present in the generated sequence. Positive values discourage repetition. Defaults to 0.0.

Type

float

frequency_penalty#

Penalty applied to the logits of tokens based on their frequency in the generated sequence so far. Positive values discourage verbatim repetition. Defaults to 0.0.

Type

float

repetition_penalty#

Multiplicative penalty applied to the logits of previously seen tokens. Values > 1.0 discourage repetition, < 1.0 encourage it. Defaults to 1.0.

Type

float
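The three penalties above can be compared side by side in a minimal sketch. It assumes common conventions that this page does not confirm for EasyDeL: presence and frequency penalties are subtracted from the logits, and the repetition penalty divides positive logits and multiplies negative ones. `apply_penalties` is a hypothetical helper, not part of the library.

```python
from collections import Counter


def apply_penalties(logits, generated_ids, presence_penalty=0.0,
                    frequency_penalty=0.0, repetition_penalty=1.0):
    """Return a penalized copy of `logits` (hypothetical helper, not EasyDeL API)."""
    counts = Counter(generated_ids)
    out = list(logits)
    for token_id, count in counts.items():
        value = out[token_id]
        # Multiplicative repetition penalty: lowers the logit whenever the penalty is > 1.0.
        value = value / repetition_penalty if value > 0 else value * repetition_penalty
        # Presence penalty is a flat offset; frequency penalty grows with the count.
        value -= presence_penalty + frequency_penalty * count
        out[token_id] = value
    return out


logits = [2.0, 1.0, -1.0, 0.5]
penalized = apply_penalties(logits, generated_ids=[0, 0, 2],
                            presence_penalty=0.5, frequency_penalty=0.25,
                            repetition_penalty=2.0)
print(penalized)  # tokens 0 and 2 are pushed down, tokens 1 and 3 are unchanged
```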

temperature#

Controls the randomness of the sampling. Higher values (e.g., > 1.0) make the distribution flatter (more random), lower values (e.g., < 1.0) make it peakier (more deterministic). A value of 0.0 is effectively greedy decoding, where the most probable token is always chosen. Defaults to 0.0.

Type

float
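The effect of temperature can be illustrated with a small, dependency-free sketch. The function name `softmax_with_temperature` and the explicit greedy branch for 0.0 are assumptions made for illustration; they are not taken from the EasyDeL implementation.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Turn logits into probabilities; temperature 0.0 falls back to a one-hot greedy choice."""
    if temperature == 0.0:
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, 0.5)   # peakier than the plain softmax
flat = softmax_with_temperature(logits, 2.0)    # flatter than the plain softmax
greedy = softmax_with_temperature(logits, 0.0)  # all mass on the top token
```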

top_p#

Nucleus sampling threshold. If set to a value < 1.0, sampling is restricted to the smallest set of most probable tokens whose cumulative probability reaches top_p. Defaults to 1.0 (no nucleus sampling).

Type

float
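As a rough illustration of nucleus sampling, the sketch below keeps the most probable tokens until their cumulative probability reaches `top_p` and renormalizes the survivors. `top_p_filter` is a hypothetical helper operating on plain Python lists; the library's warper works on arrays and may handle ties and boundaries differently.

```python
def top_p_filter(probs, top_p):
    """Keep the smallest set of most probable tokens whose cumulative mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    # Zero out the tail and renormalize the surviving tokens.
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]


filtered = top_p_filter([0.5, 0.25, 0.125, 0.125], top_p=0.75)
```

With `top_p=0.75`, only the first two tokens survive, since together they already cover 0.75 of the probability mass.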

top_k#

Top-k sampling threshold. If set to a value > 0, only the top_k most probable tokens are considered for sampling. Defaults to 0 (no top-k sampling).

Type

int
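A minimal sketch of top-k filtering on a list of logits follows. `top_k_filter` is a hypothetical name; the masking-to-negative-infinity approach is a common convention and is assumed here rather than confirmed for EasyDeL.

```python
def top_k_filter(logits, top_k):
    """Mask every logit outside the top_k largest; top_k <= 0 disables the filter."""
    if top_k <= 0 or top_k >= len(logits):
        return list(logits)
    # The k-th largest logit is the cut-off; ties at the cut-off are all kept.
    threshold = sorted(logits, reverse=True)[top_k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]


filtered = top_k_filter([1.0, 3.0, 2.0, 0.5], top_k=2)
```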

min_p#

Minimum probability threshold. Filters out tokens with probability less than min_p. Defaults to 0.0 (no minimum probability filtering).

Type

float
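The description above reads as an absolute probability threshold, while many min-p implementations scale the threshold by the probability of the most likely token. Which variant EasyDeL uses is not stated here, so the sketch below supports both through a hypothetical `scale_by_max` flag. `min_p_filter` itself is also a hypothetical helper.

```python
def min_p_filter(probs, min_p, scale_by_max=False):
    """Zero out tokens below the threshold, then renormalize (hypothetical helper)."""
    top = max(probs)
    threshold = min_p * top if scale_by_max else min_p
    # Never filter out the most probable token, so the result stays a valid distribution.
    threshold = min(threshold, top)
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    return [p / total for p in kept]


filtered = min_p_filter([0.5, 0.25, 0.125, 0.125], min_p=0.2)
scaled = min_p_filter([0.5, 0.25, 0.125, 0.125], min_p=0.3, scale_by_max=True)
```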

suppress_tokens#

A list of token IDs that should be completely suppressed (their logits set to -inf) during generation. Defaults to an empty list.

Type

list[int]
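To see why setting a logit to -inf suppresses a token completely, consider this small sketch. Both helpers are hypothetical and use plain Python lists instead of arrays.

```python
import math


def suppress(logits, suppress_tokens):
    """Set the logits of the listed token IDs to negative infinity."""
    blocked = set(suppress_tokens)
    return [float("-inf") if i in blocked else x for i, x in enumerate(logits)]


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(suppress([1.0, 2.0, 3.0], suppress_tokens=[2]))
```

After the softmax, the suppressed token has probability exactly 0.0, so it can never be sampled regardless of temperature.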

frequency_penalty: float = 0.0#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

get_logits_processor()[source]#

Constructs a LogitsProcessorList containing the configured logits processors.

Logits processors modify the logits directly, often used for applying penalties (presence, frequency, repetition) or suppressing specific tokens.

Returns

A LogitsProcessorList containing the enabled logits processors based on the sampling parameters.

get_logits_warper()[source]#

Constructs a LogitsProcessorList containing the configured logits warpers.

Logits warpers modify the probability distribution derived from logits, typically used for techniques like temperature scaling, top-k, top-p, and min-p sampling.

Returns

A LogitsProcessorList containing the enabled logits warpers based on the sampling parameters.

max_tokens: int = 16#
min_p: float = 0.0#
presence_penalty: float = 0.0#
repetition_penalty: float = 1.0#
replace(**kwargs)#
suppress_tokens: list[int]#
temperature: float = 0.0#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

top_k: int = 0#
top_p: float = 1.0#
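The serialization helpers listed above (to_dict, from_dict, to_json, from_json, replace) suggest a round-trip pattern. The sketch below uses a stand-in dataclass, `SamplingParamsSketch`, that copies the documented fields and defaults so the example runs without EasyDeL installed; the method bodies are assumptions about how such helpers typically work, not the library's code.

```python
import json
from dataclasses import asdict, dataclass, field, replace


@dataclass
class SamplingParamsSketch:
    """Stand-in with the documented fields and defaults; not the EasyDeL class."""
    max_tokens: int = 16
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 0.0
    top_p: float = 1.0
    top_k: int = 0
    min_p: float = 0.0
    suppress_tokens: list = field(default_factory=list)

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


# Derive a modified copy, serialize it, and restore it.
params = replace(SamplingParamsSketch(), temperature=0.7, top_k=40)
restored = SamplingParamsSketch.from_json(params.to_json(indent=2))
```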
class easydel.inference.__init__.vInference(model: None, processor_class: None, generation_config: Optional[vInferenceConfig] = None, seed: Optional[int] = None, input_partition_spec: Optional[PartitionSpec] = None, max_new_tokens: int = 512, inference_name: Optional[str] = None)[source]#

Bases: object

Class for performing text generation using a pre-trained language model in EasyDeL.

This class handles the generation process, including initialization, precompilation, and generating text in streaming chunks.

property SEQUENCE_DIM_MAPPING#
adjust_kwargs(input_ids: Array, attention_mask: Optional[Array] = None, **model_kwargs)[source]#
count_tokens(messages: List[Dict[str, str]])[source]#
count_tokens(text: str)
execute_decode(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) SampleState[source]#
execute_prefill(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) SampleState[source]#

Executes a single generation step with performance monitoring.

generate(input_ids: Array, attention_mask: Optional[Array] = None, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, sampling_params: Optional[SamplingParams] = None, **model_kwargs) Generator[Union[SampleState, Any], SampleState, SampleState][source]#

Generates text in streaming chunks with comprehensive input adjustment.

Parameters
  • input_ids – Input token IDs as a JAX array

  • attention_mask – Optional attention mask for the input

  • graphstate (nn.GraphState, optional) – Pass this to override the model state used for generation.

  • graphother (nn.GraphState, optional) – Pass this to override the model's other state used for generation.

  • **model_kwargs – Additional model-specific keyword arguments

Returns

Generator yielding SampleState objects containing generation results and metrics

property inference_name#
classmethod load_inference(path: Union[PathLike, str], model: None, processor_class: None)[source]#
property metrics#
property model#
property model_prefill_length: int#

Calculate the maximum length available for input prefill by subtracting the maximum new tokens from the model’s maximum sequence length.

Returns

The maximum length available for input prefill

Return type

int

Raises

ValueError – If no maximum sequence length configuration is found
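The arithmetic behind this property can be sketched as follows. The function and argument names are hypothetical; the real property reads these values from the model configuration and the inference settings.

```python
def prefill_length(max_sequence_length, max_new_tokens):
    """Budget left for the prompt once room for new tokens is reserved."""
    if max_sequence_length is None:
        raise ValueError("no maximum sequence length configuration found")
    return max_sequence_length - max_new_tokens


# A model with a 4096-token context and max_new_tokens=512 leaves 3584 tokens for the prompt.
budget = prefill_length(4096, 512)
```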

precompile(config: vInferencePreCompileConfig)[source]#

Precompiles the generation functions for a given batch size and input length.

This function checks if the generation functions have already been compiled for the given configuration. If not, it compiles them asynchronously and stores them in a cache.

Returns

True if precompilation was successful, False otherwise.

Return type

bool

save_inference(path: Union[PathLike, str])[source]#
property tokenizer#
class easydel.inference.__init__.vInferenceApiServer(inference_map: Union[Dict[str, Any], Any] = None, inference_init_call: Optional[Callable[[], Any]] = None, max_workers: int = 10)[source]#

Bases: object

FastAPI server for serving vInference instances.

This server provides endpoints mimicking the OpenAI API structure for chat completions, liveness/readiness checks, token counting, and listing available models. It handles both streaming and non-streaming requests asynchronously using a thread pool.

async available_inference()[source]#

Lists available models (GET /v1/models).

async chat_completions(request: ChatCompletionRequest)[source]#

Handles chat completion requests (POST /v1/chat/completions).

Validates the request, retrieves the appropriate vInference model, tokenizes the input, and delegates to streaming or non-streaming handlers.

Parameters

request (ChatCompletionRequest) – The incoming request data.

Returns

The generated response, either a complete JSON object or a streaming event-stream.

Return type

Union[JSONResponse, StreamingResponse]

async completions(request: CompletionRequest)[source]#

Handles completion requests (POST /v1/completions).

Processes the prompt for completion and returns generated text.

Parameters

request (CompletionRequest) – The incoming request data.

Returns

The generated response.

Return type

Union[JSONResponse, StreamingResponse]

async count_tokens(request: CountTokenRequest)[source]#

Token counting endpoint (POST /v1/count_tokens).

fire(host='0.0.0.0', port=11556, metrics_port: Optional[int] = None, log_level='info', ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None)[source]#

Starts the uvicorn server to run the FastAPI application.

Parameters
  • host (str) – The host address to bind to. Defaults to “0.0.0.0”.

  • port (int) – The port to listen on. Defaults to 11556.

  • metrics_port (tp.Optional[int]) – The port for the Prometheus metrics server. If None, defaults to port + 1. Set to -1 to disable.

  • log_level (str) – The logging level for uvicorn. Defaults to “info”.

  • ssl_keyfile (tp.Optional[str]) – Path to the SSL key file for HTTPS.

  • ssl_certfile (tp.Optional[str]) – Path to the SSL certificate file for HTTPS.
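Once the server is running, a client can target the documented POST /v1/chat/completions route on the default port 11556. The sketch below only builds the request with the standard library and does not send it. The payload field names (`model`, `messages`, `stream`) are an assumption based on the OpenAI-style structure the server mimics, and `build_chat_request` and `"my-model"` are hypothetical names.

```python
import json
import urllib.request


def build_chat_request(model, messages, host="127.0.0.1", port=11556, stream=False):
    """Build (but do not send) a chat completion request; field names are assumed."""
    payload = {"model": model, "messages": messages, "stream": stream}
    return urllib.request.Request(
        url=f"http://{host}:{port}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_chat_request("my-model", [{"role": "user", "content": "Hello"}])
# Against a running server, the request could be sent with urllib.request.urlopen(request).
```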

async liveness()[source]#

Liveness check endpoint (GET /liveness).

async readiness()[source]#

Readiness check endpoint (GET /readiness).

class easydel.inference.__init__.vInferenceConfig(max_new_tokens: int = 64, streaming_chunks: int = 16, num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[PartitionAxis] = None, _loop_rows: Optional[int] = None, sampling_params: Optional[SamplingParams] = None)[source]#

Bases: object

Configuration class for the vInference engine, controlling the overall generation process.

This class holds parameters that define how the generation loop behaves, including length constraints, token control, sharding strategies, and sampling settings.

max_new_tokens#

The maximum number of new tokens to generate, excluding the initial prompt tokens. Defaults to 64.

Type

int

streaming_chunks#

The number of generation steps to compile and execute together as a single unit. Larger chunks can improve performance on TPUs by reducing compilation overhead and kernel launch times, but may increase memory usage. Defaults to 16.

Type

int

num_return_sequences#

The number of sequences to generate and return. Can be:

  • An integer: generate this many sequences for all inputs.

  • A dictionary mapping a precompile hash (from vInferencePreCompileConfig) to an integer: generate a specific number of sequences based on the compilation configuration.

Defaults to 1.

Type

Optional[Union[int, Dict[int, int]]]

pad_token_id#

The token ID used for padding sequences. If None, the model’s default pad token ID might be used, or padding might not be applied.

Type

Optional[int]

bos_token_id#

The token ID representing the beginning-of-sequence. May be used implicitly by the model or generation logic.

Type

Optional[int]

eos_token_id#

The token ID(s) representing the end-of-sequence. Generation stops for a sequence when one of these tokens is sampled. Can be a single integer or a list/tuple of integers.

Type

Optional[Union[int, List[int]]]

partition_rules#

A tuple of custom sharding rules (regex pattern, PartitionSpec) to apply to the model’s parameters and intermediate states (like attention cache). If None, default rules based on partition_axis are generated. Example: ((".*kernel.*", PartitionSpec("fsdp", None)), …)

Type

Optional[Tuple[Tuple[str, Any]]]

partition_axis#

A PartitionAxis object defining the logical names for sharding axes (e.g., ‘batch’, ‘sequence’, ‘head’). Required if partition_rules is None, used to generate default sharding rules.

Type

Optional[eformer.escale.partition.constraints.PartitionAxis]

_loop_rows#

(Internal) The calculated number of iterations needed in the generation loop based on max_new_tokens and streaming_chunks. Automatically computed in __post_init__.

Type

Optional[int]
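One plausible way to derive the loop count from max_new_tokens and streaming_chunks is ceiling division, so that a final partial chunk still gets its own iteration. This is an assumption about the computation, and `loop_rows` is a hypothetical helper rather than the library's code.

```python
def loop_rows(max_new_tokens, streaming_chunks):
    """Number of chunked iterations needed to cover max_new_tokens (assumed formula)."""
    # Ceiling division: a trailing partial chunk still needs one iteration.
    return (max_new_tokens + streaming_chunks - 1) // streaming_chunks


default_rows = loop_rows(64, 16)   # the documented defaults divide evenly
uneven_rows = loop_rows(70, 16)    # 4 full chunks plus one partial chunk
```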

sampling_params#

A SamplingParams object containing parameters for the sampling process itself (e.g., temperature, top_k, top_p, repetition penalty). If None, a default SamplingParams instance with max_tokens set to max_new_tokens is created in __post_init__.

Type

Optional[easydel.inference.utilities.SamplingParams]

bos_token_id: Optional[int] = None#
eos_token_id: Optional[Union[int, List[int]]] = None#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

get_partition_rules(runtime_config: Optional[vInferencePreCompileConfig] = None) Tuple[Tuple[str, Any], ...][source]#

Generates or retrieves the sharding partition rules for the vInference engine.

If self.partition_rules is already set (custom rules provided), it returns them directly.

Otherwise, it constructs a default set of partition rules based on the axis names defined in self.partition_axis. These default rules aim to provide sensible sharding for common model components:

  • Input sequences (sequences, running_token) are sharded along batch and sequence axes.

  • Attention masks and position IDs are sharded similarly.

  • Past key-value states (attention cache), including common quantized formats (8-bit, NF4), are sharded across batch, key sequence, head, and attention dimension axes.

  • Any parameters/states not matching the specific rules are replicated by default (.*).

Parameters

runtime_config – An optional vInferencePreCompileConfig. Currently unused in the default rule generation but available for potential customization in subclasses or future versions.

Returns

A tuple of partition rules. Each rule is a tuple containing:

  • A regex pattern (string) matching parameter or state names.

  • A jax.sharding.PartitionSpec defining how the matched items should be sharded.

Return type

Tuple[Tuple[str, Any], ...]

Raises

AssertionError – If self.partition_rules is None and self.partition_axis is also None, as axis names are required to generate default rules.
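To show how (regex pattern, spec) rules of this shape are typically consumed, the sketch below resolves a name against an ordered rule list. It assumes first-match-wins ordering with a trailing catch-all, which is a common convention but is not confirmed by this page. Plain tuples stand in for jax.sharding.PartitionSpec to keep the example dependency-free, and the patterns and `match_rule` are hypothetical.

```python
import re


def match_rule(name, rules):
    """Return the spec of the first rule whose pattern matches `name`."""
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"no partition rule matches {name!r}")


# Tuples of axis names stand in for PartitionSpec objects (illustrative only).
rules = (
    ("(sequences|running_token)", ("batch", "sequence")),
    ("attention_mask", ("batch", "sequence")),
    (".*", ()),  # catch-all: replicate everything else
)

token_spec = match_rule("running_token", rules)
fallback_spec = match_rule("model/layers/0/kernel", rules)
```

Because the catch-all `.*` matches every name, it has to come last; placing it first would shadow every more specific rule.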

max_new_tokens: int = 64#
num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1#
pad_token_id: Optional[int] = None#
partition_axis: Optional[PartitionAxis] = None#
partition_rules: Optional[Tuple[Tuple[str, Any]]] = None#
replace(**kwargs)#
sampling_params: Optional[SamplingParams] = None#
streaming_chunks: int = 16#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

class easydel.inference.__init__.vInferencePreCompileConfig(batch_size: Union[int, List[int]] = 1, prefill_length: Optional[Union[int, List[int]]] = None, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Optional[Union[int, List[int]]] = None, vision_channels: Optional[Union[int, List[int]]] = None, vision_height: Optional[Union[int, List[int]]] = None, vision_width: Optional[Union[int, List[int]]] = None, required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None)[source]#

Bases: object

Configuration class for pre-compiling vInference functions.

This class holds parameters that define the shape and properties of inputs expected by the vInference engine during pre-compilation. It allows specifying different configurations, potentially in lists, to compile for multiple scenarios.

batch_size#

Batch size or list of batch sizes for text generation.

Type

Union[int, List[int]]

prefill_length#

Prefill sequence length or list of lengths. If None, it might be inferred or not used depending on the context.

Type

Optional[Union[int, List[int]]]

vision_included#

Whether vision inputs are included in the model.

Type

Union[bool, List[bool]]

vision_batch_size#

Batch size for vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_channels#

Number of channels for vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_height#

Height of vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_width#

Width of vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

required_props#

Optional dictionary or list of dictionaries specifying required properties for advanced configuration (e.g., specific model arguments).

Type

Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]]

batch_size: Union[int, List[int]] = 1#
extract() dict[source]#

Converts the configuration instance into a dictionary.

This method is useful for serialization or easily accessing all configuration values.

Returns

A dictionary representation of the vInferencePreCompileConfig instance.

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

get_default_hash() int[source]#

Generates a unique integer hash representing the configuration.

This hash is calculated based on the string representation of all configuration attributes, ensuring that identical configurations produce the same hash. This is crucial for caching compiled functions based on their configuration.

Returns

An integer hash value representing the configuration.
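The caching idea can be sketched with a plain dictionary keyed by a configuration hash. Everything here is hypothetical: `config_hash` builds a deterministic integer from a string representation using SHA-256, which may differ from how get_default_hash actually computes its value, and `get_or_compile` stands in for the engine's compile cache.

```python
import hashlib


def config_hash(config: dict) -> int:
    """Deterministic integer hash of a config's string representation (assumed scheme)."""
    text = "-".join(f"{key}={config[key]}" for key in sorted(config))
    return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16)


compiled_cache = {}


def get_or_compile(config, compile_fn):
    """Compile once per distinct configuration and reuse the cached result afterwards."""
    key = config_hash(config)
    if key not in compiled_cache:
        compiled_cache[key] = compile_fn(config)
    return compiled_cache[key]


first = get_or_compile({"batch_size": 1, "prefill_length": 512}, lambda cfg: object())
second = get_or_compile({"prefill_length": 512, "batch_size": 1}, lambda cfg: object())
# `first is second` holds because both configs hash to the same key.
```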

get_standalones() List[vInferencePreCompileConfig][source]#

Generates a list of standalone configurations from a potentially multi-value config.

If any attribute in the current configuration is a list (indicating multiple scenarios), this method expands the configuration into multiple individual vInferencePreCompileConfig instances. Each resulting instance represents a single, specific compilation scenario.

If an attribute’s list is shorter than the longest list among all attributes, its last element is repeated to ensure all generated configurations have values for all attributes.

If the original configuration is already standalone (no list attributes), this method returns a list containing only the original instance.

Returns

A list of vInferencePreCompileConfig instances, each representing a single, standalone compilation scenario.
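The expansion rule described above (short lists repeat their last element) can be sketched on plain dictionaries. `expand_standalones` is a hypothetical helper that mirrors the described behavior, and it assumes every list attribute is non-empty.

```python
def expand_standalones(config: dict) -> list:
    """Expand list-valued attributes into one single-valued config per scenario."""
    lists = {k: v for k, v in config.items() if isinstance(v, list)}
    if not lists:
        return [dict(config)]  # already standalone
    longest = max(len(v) for v in lists.values())
    standalones = []
    for i in range(longest):
        single = {}
        for key, value in config.items():
            if isinstance(value, list):
                # Shorter lists repeat their last element.
                single[key] = value[min(i, len(value) - 1)]
            else:
                single[key] = value
        standalones.append(single)
    return standalones


scenarios = expand_standalones(
    {"batch_size": [1, 2, 4], "prefill_length": [512, 1024], "vision_included": False}
)
```

Here the third scenario pairs `batch_size=4` with `prefill_length=1024`, because the shorter prefill list reuses its last value.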

prefill_length: Optional[Union[int, List[int]]] = None#
replace(**kwargs)#
required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

vision_batch_size: Optional[Union[int, List[int]]] = None#
vision_channels: Optional[Union[int, List[int]]] = None#
vision_height: Optional[Union[int, List[int]]] = None#
vision_included: Union[bool, List[bool]] = False#
vision_width: Optional[Union[int, List[int]]] = None#
class easydel.inference.__init__.vWhisperInference(model: ~typing.Any, tokenizer: ~typing.Any, processor: ~typing.Any, inference_config: ~typing.Optional[~easydel.inference.vwhisper.config.vWhisperInferenceConfig] = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>)[source]#

Bases: object

Whisper inference pipeline for performing speech-to-text transcription or translation.

Parameters
  • model (WhisperForConditionalGeneration) – The fine-tuned Whisper model to use for inference.

  • tokenizer (WhisperTokenizer) – Tokenizer for Whisper.

  • processor (WhisperProcessor) – Processor for Whisper.

  • inference_config (vWhisperInferenceConfig, optional) – Inference configuration.

  • dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.

generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)[source]#

Transcribe or translate audio input.

Parameters
  • audio_input (tp.Union[str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]]) – Input audio. Can be a local file path, URL, bytes, numpy array, or a dictionary containing the array and sampling rate.

  • chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.

  • stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.

  • batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.

  • language (str, optional) – Language of the input audio. Defaults to the language in inference_config.

  • task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.

Returns

A dictionary containing the transcribed text (“text”) and optionally other information like timestamps or detected language.

Return type

dict
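How chunk_length_s and stride_length_s interact can be sketched by computing overlapping windows in seconds. This follows the chunking scheme commonly used by speech pipelines, where consecutive windows overlap by the stride on each side; whether vWhisperInference lays out its windows exactly this way is an assumption, and `chunk_windows` is a hypothetical helper.

```python
def chunk_windows(total_s, chunk_length_s=30.0, stride_length_s=None):
    """Return (start, end) windows in seconds covering `total_s` of audio."""
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6  # documented default
    # Each window advances by the chunk length minus the overlap on both sides.
    step = chunk_length_s - 2 * stride_length_s
    if step <= 0:
        raise ValueError("stride_length_s must be less than half of chunk_length_s")
    windows = []
    start = 0.0
    while True:
        end = min(start + chunk_length_s, total_s)
        windows.append((start, end))
        if end >= total_s:
            break
        start += step
    return windows


windows = chunk_windows(70.0)  # 30 s chunks, 5 s stride, 20 s step
```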

class easydel.inference.__init__.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)[source]#

Bases: object

Configuration class for Whisper inference.

Parameters
  • batch_size (int, optional, defaults to 1) – Batch size used for inference.

  • max_length (int, optional) – Maximum sequence length for generation.

  • generation_config (transformers.GenerationConfig, optional) – Generation configuration object.

  • logits_processor (optional) – Not used.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.

  • task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).

  • language (str, optional) – Language of the input audio.

  • is_multilingual (bool, optional) – Whether the model is multilingual.

batch_size: Optional[int] = 1#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

generation_config: Optional[Any] = None#
is_multilingual = None#
language = None#
logits_processor = None#
max_length: Optional[int] = None#
replace(**kwargs)#
return_timestamps = None#
task = None#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.