easydel.inference.utilities

class easydel.inference.utilities.SamplingParams(max_tokens: int = 16, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 0.7, top_p: float = 1.0, top_k: int = 0, min_p: float = 0.0, suppress_tokens: list[int] = <factory>)

Bases: object

Parameters controlling the sampling process during text generation.

max_tokens

The maximum number of tokens to generate (excluding the prompt). Defaults to 16.

Type

int

presence_penalty

Penalty applied to the logits of tokens already present in the generated sequence. Positive values discourage repetition. Defaults to 0.0.

Type

float

frequency_penalty

Penalty applied to the logits of tokens based on their frequency in the generated sequence so far. Positive values discourage verbatim repetition. Defaults to 0.0.

Type

float
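The presence and frequency penalties are often implemented with the additive, OpenAI-style formula: subtract presence_penalty once for any token that has appeared, and frequency_penalty once per occurrence. The sketch below shows that general technique in JAX. It assumes EasyDeL uses the same additive form, which this page does not confirm, and `apply_presence_frequency_penalty` is a hypothetical helper.

```python
import jax.numpy as jnp


def apply_presence_frequency_penalty(logits, generated_ids, presence_penalty, frequency_penalty):
    """Additive presence/frequency penalties on one row of next-token logits.

    logits: float array of shape [vocab_size].
    generated_ids: 1-D int array of the tokens generated so far.
    """
    vocab_size = logits.shape[-1]
    # How many times each vocabulary id occurs in the generated sequence.
    counts = jnp.bincount(generated_ids, length=vocab_size).astype(logits.dtype)
    # Presence: a flat penalty once a token has appeared at least once.
    presence = (counts > 0).astype(logits.dtype)
    # Frequency: a penalty that grows linearly with the occurrence count.
    return logits - presence_penalty * presence - frequency_penalty * counts


logits = jnp.zeros(5)
generated_ids = jnp.array([1, 1, 3])
penalized = apply_presence_frequency_penalty(logits, generated_ids, 0.5, 0.25)
```

Token 1 appeared twice (0.5 + 2 × 0.25 = 1.0 subtracted), token 3 once (0.5 + 0.25 = 0.75 subtracted), and unseen tokens are unchanged.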

repetition_penalty

Multiplicative penalty applied to the logits of previously seen tokens. Values > 1.0 discourage repetition, < 1.0 encourage it. Defaults to 1.0.

Type

float
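A common multiplicative scheme (the CTRL-style penalty that Hugging Face's RepetitionPenaltyLogitsProcessor also uses) divides positive logits and multiplies negative logits of already-seen tokens by the penalty, so a value > 1.0 lowers their score regardless of sign. The sketch below assumes EasyDeL follows the same convention; `apply_repetition_penalty` is hypothetical.

```python
import jax.numpy as jnp


def apply_repetition_penalty(logits, generated_ids, penalty):
    """CTRL-style repetition penalty on one row of next-token logits."""
    seen = jnp.bincount(generated_ids, length=logits.shape[-1]) > 0
    # With penalty > 1, dividing a positive logit or multiplying a negative one
    # both lower the score, so previously seen tokens become less likely.
    penalized = jnp.where(logits > 0, logits / penalty, logits * penalty)
    return jnp.where(seen, penalized, logits)


logits = jnp.array([2.0, -2.0, 1.0])
out = apply_repetition_penalty(logits, jnp.array([0, 1]), penalty=2.0)
```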

temperature

Controls the randomness of sampling. Higher values (e.g., > 1.0) flatten the distribution (more random); lower values (e.g., < 1.0) sharpen it (more deterministic). A value of 0.0 amounts to greedy decoding. Defaults to 0.7.

Type

float
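Temperature scaling divides the logits by T before the softmax, and T = 0 is treated as the greedy limit. A minimal JAX sketch of the technique (the helper name is hypothetical, and the special-casing of 0.0 is an assumption about how such code usually avoids dividing by zero):

```python
import jax
import jax.numpy as jnp


def sample_with_temperature(key, logits, temperature):
    """Draw one token id from temperature-scaled logits."""
    if temperature == 0.0:
        # The T -> 0 limit puts all mass on the highest logit: greedy decoding.
        return jnp.argmax(logits, axis=-1)
    # Dividing by T > 1 flattens the softmax; T < 1 sharpens it.
    return jax.random.categorical(key, logits / temperature, axis=-1)


key = jax.random.PRNGKey(0)
logits = jnp.array([0.1, 3.0, 0.2])
greedy_token = sample_with_temperature(key, logits, 0.0)
sampled_token = sample_with_temperature(key, logits, 0.7)
```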

top_p

Nucleus sampling threshold. If set to a value < 1.0, sampling is restricted to the smallest set of most probable tokens whose cumulative probability reaches top_p. Defaults to 1.0 (no nucleus sampling).

Type

float
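Nucleus filtering can be sketched as: sort by probability, accumulate, and keep tokens until the running mass reaches top_p. This is a generic JAX illustration with a hypothetical `top_p_filter` helper, not EasyDeL's warper; masking with -inf is an assumption about how excluded tokens are represented.

```python
import jax
import jax.numpy as jnp


def top_p_filter(logits, top_p):
    """Mask everything outside the nucleus with -inf."""
    if top_p >= 1.0:
        return logits  # 1.0 disables nucleus sampling
    order = jnp.argsort(-logits)  # token ids, most probable first
    sorted_probs = jax.nn.softmax(logits[order])
    cumulative = jnp.cumsum(sorted_probs)
    # Keep a token while the mass accumulated *before* it is still below top_p.
    # This keeps the smallest prefix reaching top_p and always keeps the top token.
    keep_sorted = (cumulative - sorted_probs) < top_p
    # Scatter the mask back from sorted order to vocabulary order.
    keep = jnp.zeros_like(keep_sorted).at[order].set(keep_sorted)
    return jnp.where(keep, logits, -jnp.inf)


logits = jnp.log(jnp.array([0.05, 0.5, 0.15, 0.3]))
filtered = top_p_filter(logits, 0.75)
```

With top_p = 0.75, tokens 1 (0.5) and 3 (0.3) form the nucleus, since 0.5 alone falls short of 0.75 but 0.5 + 0.3 reaches it.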

top_k

Top-k sampling threshold. If set to a value > 0, only the top_k most probable tokens are considered for sampling. Defaults to 0 (no top-k sampling).

Type

int
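Top-k filtering keeps only the k highest logits. A minimal JAX sketch using `jax.lax.top_k`; the helper name is hypothetical and tie handling (ties with the k-th value survive) is a choice made for this illustration.

```python
import jax
import jax.numpy as jnp


def top_k_filter(logits, top_k):
    """Keep the top_k highest logits and mask the rest with -inf."""
    if top_k <= 0:
        return logits  # 0 disables top-k
    top_k = min(top_k, logits.shape[-1])
    # jax.lax.top_k returns (values, indices); the last value is the k-th largest.
    kth_value = jax.lax.top_k(logits, top_k)[0][-1]
    # Ties with the k-th value are kept, so slightly more than k tokens may survive.
    return jnp.where(logits < kth_value, -jnp.inf, logits)


logits = jnp.array([1.0, 4.0, 3.0, 2.0])
filtered = top_k_filter(logits, 2)
```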

min_p

Minimum probability threshold. Filters out tokens with probability less than min_p. Defaults to 0.0 (no minimum probability filtering).

Type

float
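The description above reads as an absolute probability threshold. Some libraries (for example Hugging Face's MinPLogitsWarper) instead scale min_p by the probability of the most likely token. Which variant EasyDeL applies is not stated here, so the hypothetical `min_p_filter` below sketches both behind a flag; always keeping the top token is also an assumption, added so at least one candidate survives.

```python
import jax
import jax.numpy as jnp


def min_p_filter(logits, min_p, scale_by_max=False):
    """Mask tokens whose probability falls below a min-p threshold."""
    if min_p <= 0.0:
        return logits  # 0.0 disables the filter
    probs = jax.nn.softmax(logits)
    # Literal reading of the description: an absolute probability threshold.
    # With scale_by_max=True, the threshold is relative to the top token instead.
    threshold = min_p * probs.max() if scale_by_max else min_p
    # Always keep the most probable token so at least one candidate survives.
    is_top = jnp.arange(logits.shape[-1]) == jnp.argmax(logits)
    keep = (probs >= threshold) | is_top
    return jnp.where(keep, logits, -jnp.inf)


logits = jnp.log(jnp.array([0.6, 0.3, 0.08, 0.02]))
absolute = min_p_filter(logits, 0.1)
relative = min_p_filter(logits, 0.6, scale_by_max=True)  # threshold 0.6 * 0.6 = 0.36
```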

suppress_tokens

A list of token IDs that should be completely suppressed (their logits set to -inf) during generation. Defaults to an empty list.

Type

list[int]
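Suppressing tokens amounts to writing -inf into their logits, which gives them zero probability after the softmax. A minimal JAX sketch with a hypothetical helper name:

```python
import jax.numpy as jnp


def apply_suppress_tokens(logits, suppress_tokens):
    """Set the logits of the listed token ids to -inf so they can never be sampled."""
    if not suppress_tokens:
        return logits  # the default empty list is a no-op
    return logits.at[jnp.asarray(suppress_tokens)].set(-jnp.inf)


logits = jnp.zeros(4)
out = apply_suppress_tokens(logits, [0, 2])
```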

frequency_penalty: float = 0.0
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

get_logits_processor()

Constructs a LogitsProcessorList containing the configured logits processors.

Logits processors modify the logits directly, often used for applying penalties (presence, frequency, repetition) or suppressing specific tokens.

Returns

A LogitsProcessorList containing the enabled logits processors based on the sampling parameters.
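The "enabled processors" pattern can be illustrated as a list that only receives a processor when its parameter differs from the no-op default, and that applies its members in order. The sketch below uses a hypothetical `ChainedLogitsFns` class and `build_processors` helper; neither is EasyDeL's LogitsProcessorList, and the `(input_ids, logits)` call signature is an assumption modeled on Hugging Face's processors.

```python
import jax.numpy as jnp


class ChainedLogitsFns(list):
    """Hypothetical stand-in for a processor list: applies callables in order."""

    def __call__(self, input_ids, logits):
        for fn in self:
            logits = fn(input_ids, logits)
        return logits


def build_processors(presence_penalty=0.0, suppress_tokens=()):
    """Append only the processors whose settings differ from their no-op defaults."""
    processors = ChainedLogitsFns()
    if presence_penalty != 0.0:
        def presence(input_ids, logits):
            seen = jnp.bincount(input_ids, length=logits.shape[-1]) > 0
            return logits - presence_penalty * seen
        processors.append(presence)
    if suppress_tokens:
        ids = jnp.asarray(list(suppress_tokens))
        processors.append(lambda input_ids, logits: logits.at[ids].set(-jnp.inf))
    return processors


processors = build_processors(presence_penalty=1.0, suppress_tokens=[3])
out = processors(jnp.array([0, 0]), jnp.zeros(4))
```

With every parameter at its default, the list is empty and the logits pass through unchanged.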

get_logits_warper()

Constructs a LogitsProcessorList containing the configured logits warpers.

Logits warpers modify the probability distribution derived from logits, typically used for techniques like temperature scaling, top-k, top-p, and min-p sampling.

Returns

A LogitsProcessorList containing the enabled logits warpers based on the sampling parameters.
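Because warpers reshape the distribution that is actually sampled from, they typically run after the processors and immediately before a token is drawn. The sketch below shows a warper stage (temperature, then top-k) followed by sampling with `jax.random.categorical`. The ordering and the hypothetical `warp_and_sample` helper are assumptions for illustration, not EasyDeL's code.

```python
import jax
import jax.numpy as jnp


def warp_and_sample(key, logits, temperature=0.7, top_k=0):
    """Hypothetical warper chain (temperature, then top-k) followed by sampling."""
    if temperature > 0.0:
        logits = logits / temperature
    if top_k > 0:
        kth_value = jax.lax.top_k(logits, top_k)[0][-1]
        logits = jnp.where(logits < kth_value, -jnp.inf, logits)
    if temperature == 0.0:
        return jnp.argmax(logits)  # greedy
    # categorical samples from softmax(logits); -inf entries get zero probability.
    return jax.random.categorical(key, logits)


logits = jnp.array([0.0, 5.0, 4.9, -1.0])
tokens = [int(warp_and_sample(jax.random.PRNGKey(i), logits, 1.0, top_k=2)) for i in range(20)]
```

With top_k = 2, every draw lands on token 1 or token 2, because the other two tokens were masked before sampling.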

property logits_processor
property logits_warper
max_tokens: int = 16
min_p: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.0
replace(**kwargs)

Creates a new instance with specified fields replaced.
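`replace` follows the copy-with-changes pattern of Python's `dataclasses.replace`. The `<factory>` default in the signature suggests SamplingParams is a dataclass, but whether its `replace` delegates to `dataclasses.replace` is an assumption. The sketch uses a hypothetical `ToySamplingParams` to show the behavior.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySamplingParams:
    """Hypothetical stand-in with a few of the SamplingParams fields."""

    max_tokens: int = 16
    temperature: float = 0.7
    suppress_tokens: list = dataclasses.field(default_factory=list)


base = ToySamplingParams()
# replace() builds a new object; the original is left untouched.
greedy = dataclasses.replace(base, temperature=0.0, max_tokens=64)
```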

suppress_tokens: list[int]
temperature: float = 0.7
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
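A dict/JSON round trip for a flat parameter object can be sketched with `dataclasses.asdict` and the standard `json` module. The hypothetical `ToySamplingParams` below only mirrors the method names listed on this page; it does not reproduce EasyDeL's PyTree serialization, and forwarding `**kwargs` to `json.dumps` is an assumption.

```python
import dataclasses
import json


@dataclasses.dataclass
class ToySamplingParams:
    """Hypothetical stand-in; method names mirror the ones documented above."""

    max_tokens: int = 16
    temperature: float = 0.7
    top_k: int = 0
    suppress_tokens: list = dataclasses.field(default_factory=list)

    def to_dict(self):
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        # Extra keyword arguments (e.g. indent) are forwarded to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


params = ToySamplingParams(max_tokens=128, suppress_tokens=[0, 2])
restored = ToySamplingParams.from_json(params.to_json(indent=2))
```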

top_k: int = 0
top_p: float = 1.0
easydel.inference.utilities.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)

Lowers a JAX function and compiles it into an executable, optionally configured for distributed execution.

This function first lowers the JAX function using lower_function and then calls .compile() on the lowered representation to produce an executable.

Parameters
  • func – The JAX function to compile.

  • func_input_args – A tuple of positional arguments for the function.

  • func_input_kwargs – A dictionary of keyword arguments for the function.

  • mesh – An optional jax.sharding.Mesh object for distributed execution.

  • in_shardings – Optional sharding specifications for the input arguments.

  • out_shardings – Optional sharding specifications for the output.

  • static_argnums – Indices of static positional arguments.

  • donate_argnums – Indices of positional arguments to donate.

Returns

A compiled JAX function (typically a jax.stages.Compiled object).
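The lower-then-compile sequence described above corresponds to JAX's ahead-of-time API. The minimal sketch below calls `jax.jit(...).lower(...).compile()` directly rather than `compile_function`, and `scaled_sum` is a made-up example function.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x, scale):
    return jnp.sum(x) * scale


x = jnp.arange(4.0)
scale = jnp.float32(2.0)
# Lowering traces the function for these argument shapes and dtypes...
lowered = jax.jit(scaled_sum).lower(x, scale)
# ...and compile() turns the lowered module into an executable.
compiled = lowered.compile()
result = compiled(x, scale)  # arguments must match the lowered shapes and dtypes
```

Compiling ahead of time moves XLA compilation out of the first real call, which is useful when warming up a generation loop.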

easydel.inference.utilities.lower_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)

Lowers a JAX function to its HLO (High-Level Optimizer) representation, optionally configuring sharding and device mesh for distributed execution.

Lowering separates the definition of the computation from its compilation, allowing for inspection or manipulation of the HLO before final compilation.

Parameters
  • func – The JAX function to lower.

  • func_input_args – A tuple of positional arguments for the function.

  • func_input_kwargs – A dictionary of keyword arguments for the function.

  • mesh – An optional jax.sharding.Mesh object specifying the device topology for distributed execution. If provided, jax.jit is called within the mesh context.

  • in_shardings – Optional sharding specifications for the input arguments. Can be a PyTree matching the structure of func_input_args and func_input_kwargs, containing jax.sharding.PartitionSpec objects.

  • out_shardings – Optional sharding specifications for the output of the function. Can be a PyTree matching the function’s output structure.

  • static_argnums – Indices of positional arguments that should be treated as static (compile-time constants).

  • donate_argnums – Indices of positional arguments whose underlying buffers can be donated to save memory: XLA may reuse them for the outputs, so the donated arrays should not be used after the call.

Returns

A jax.stages.Lowered object representing the HLO computation.
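Lowering with explicit shardings can be sketched directly with `jax.jit`. This sketch uses `NamedSharding` objects, which carry their own mesh, rather than bare PartitionSpec values inside a mesh context. The `matmul` function and the `"dp"` axis name are illustrative choices, and the mesh spans whatever devices are available (a single CPU device also works). `as_text()` exposes the lowered module for inspection before `compile()`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def matmul(a, b):
    return a @ b


# A 1-D device mesh over all available devices.
devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("dp",))
# Shard the rows of `a` across "dp", replicate `b`, keep the output row-sharded.
in_shardings = (NamedSharding(mesh, P("dp", None)), NamedSharding(mesh, P()))
out_shardings = NamedSharding(mesh, P("dp", None))

n = len(devices)
a = jax.device_put(jnp.ones((4 * n, 8)), in_shardings[0])
b = jax.device_put(jnp.ones((8, 2)), in_shardings[1])

lowered = jax.jit(matmul, in_shardings=in_shardings, out_shardings=out_shardings).lower(a, b)
module_text = lowered.as_text()  # textual IR, available before compilation
out = lowered.compile()(a, b)
```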