easydel.inference.sampling_params


Sampling parameters and configurations for text generation.

This module defines the parameters that control text generation behavior, including temperature, top-k, top-p, penalties, and other sampling strategies. It provides both host-side Python configurations and JAX-compatible versions for efficient on-device processing.

Classes:

SamplingType: Enum for sampling strategies (greedy vs. random)
RequestOutputKind: Enum for output formats (cumulative, delta, final only)
GuidedDecodingParams: Parameters for constrained generation
JitableSamplingParams: JAX-compatible sampling parameters
SamplingParams: High-level sampling configuration
BeamSearchParams: Parameters for beam search decoding

Example

>>> from easydel.inference import SamplingParams
>>> params = SamplingParams(
...     temperature=0.8,
...     top_p=0.95,
...     max_tokens=100,
...     stop=["\n", "END"],
... )
>>> # Convert to JAX-compatible format
>>> jit_params = params.make_jitable()

class easydel.inference.sampling_params.BeamSearchParams(beam_width: int, max_tokens: int, ignore_eos: bool = False, temperature: float = 0.0, length_penalty: float = 1.0, include_stop_str_in_output: bool = False)

Bases: object

Beam search parameters for text generation.

Configuration for beam search decoding, which maintains multiple candidate sequences and selects the best ones based on cumulative probability. This class is immutable (frozen=True) for JAX compatibility.

Attributes:

beam_width (int): Number of beams to maintain.
max_tokens (int): Maximum number of tokens to generate.
ignore_eos (bool): Whether to ignore the end-of-sequence token.
temperature (float): Temperature for beam search (0.0 = greedy).
length_penalty (float): Penalty factor for sequence length.
include_stop_str_in_output (bool): Whether to include the stop string in the output.

Note

Beam search is typically used when deterministic, high-quality output is desired, at the cost of increased computation.
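The `length_penalty` field is usually applied when ranking finished beams. The sketch below assumes the convention used by Hugging Face Transformers and vLLM, where a beam's score is its cumulative log-probability divided by `length ** length_penalty`; whether EasyDeL uses exactly this formula is an assumption, and `beam_score` and `best_beam` are hypothetical helpers, not EasyDeL API.

```python
import math


def beam_score(cumulative_logprob: float, length: int, length_penalty: float = 1.0) -> float:
    # Assumed convention: divide by length ** length_penalty. Log-probs are
    # negative, so a larger length_penalty pulls long beams' scores closer to
    # zero and favors longer outputs.
    return cumulative_logprob / (length ** length_penalty)


def best_beam(beams, length_penalty: float):
    # beams: list of (token_list, cumulative_logprob) pairs.
    return max(beams, key=lambda b: beam_score(b[1], len(b[0]), length_penalty))


# Two hypothetical finished beams built from made-up per-token probabilities.
short = (["The", "cat"], math.log(0.5) + math.log(0.4))
long_ = (["The", "cat", "sat", "down"],
         math.log(0.5) + math.log(0.4) + 2 * math.log(0.9))
beams = [short, long_]

print(best_beam(beams, length_penalty=0.0)[0])  # raw log-prob decides
print(best_beam(beams, length_penalty=1.0)[0])  # per-token average decides
```

With `length_penalty=0.0` the raw cumulative log-probability decides and the shorter beam wins; with `1.0` the per-token average decides and the longer, consistently likely beam wins.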

beam_width: int
classmethod from_dict(data: dict[str, Any]) -> T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) -> T

Deserializes a JSON string into a PyTree object.

ignore_eos: bool = False
include_stop_str_in_output: bool = False
length_penalty: float = 1.0
max_tokens: int
replace(**kwargs)

Creates a new instance with the specified fields replaced.

temperature: float = 0.0
to_dict() -> dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) -> str

Serializes the PyTree object to a JSON string.
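Because instances are frozen, fields are changed through `replace`, which returns a new object, while `to_dict`/`from_dict` and the JSON variants round-trip the values. The stand-in below reproduces those semantics with the standard library only; `BeamSearchParamsSketch` is hypothetical and is not how EasyDeL implements its PyTree base class.

```python
import dataclasses
import json
from typing import Any


@dataclasses.dataclass(frozen=True)
class BeamSearchParamsSketch:
    # Hypothetical mirror of the documented BeamSearchParams fields.
    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False

    def replace(self, **kwargs) -> "BeamSearchParamsSketch":
        # Frozen instances cannot be mutated, so build a modified copy.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "BeamSearchParamsSketch":
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "BeamSearchParamsSketch":
        return cls.from_dict(json.loads(json_str))


base = BeamSearchParamsSketch(beam_width=4, max_tokens=64)
wider = base.replace(beam_width=8)      # new object; `base` is unchanged
restored = BeamSearchParamsSketch.from_json(wider.to_json())
```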

class easydel.inference.sampling_params.GuidedDecodingParams(json: str | dict | None = None, regex: str | None = None, choice: list[str] | None = None, grammar: str | None = None, json_object: bool | None = None, backend: str | None = None, backend_was_auto: bool = False, disable_fallback: bool = False, disable_any_whitespace: bool = False, disable_additional_properties: bool = False, whitespace_pattern: str | None = None, structural_tag: str | None = None)

Bases: object

Parameters for guided decoding.

Enables constrained generation to match specific formats or patterns. Only one guided decoding mode can be active at a time.

Attributes:

json (str | dict | None): JSON schema or object for structured output.
regex (str | None): Regular expression pattern the output must match.
choice (list[str] | None): List of allowed string choices.
grammar (str | None): Context-free grammar specification.
json_object (bool | None): Force the output to be a valid JSON object.
backend (str | None): Decoding backend to use.
backend_was_auto (bool): Whether the backend was auto-selected.
disable_fallback (bool): Disable fallback to unconstrained generation.
disable_any_whitespace (bool): Disable whitespace in structured output.
disable_additional_properties (bool): Restrict JSON output to the properties defined in the schema.
whitespace_pattern (str | None): Custom whitespace pattern.
structural_tag (str | None): Tag for structural elements.

Raises

ValueError – If multiple guided decoding modes are specified
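A minimal sketch of the "only one mode at a time" rule, assuming the mutually exclusive modes are `json`, `regex`, `choice`, `grammar`, and `json_object`; the exact set of fields EasyDeL checks, and its error message, are assumptions. `GuidedDecodingSketch` is a hypothetical stand-in for the real class.

```python
from __future__ import annotations

from dataclasses import dataclass

# Assumed set of mutually exclusive constraint modes.
_MODES = ("json", "regex", "choice", "grammar", "json_object")


@dataclass(frozen=True)
class GuidedDecodingSketch:
    # Hypothetical subset of the documented GuidedDecodingParams fields.
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None

    def __post_init__(self) -> None:
        # Collect every mode that was set; more than one is ambiguous.
        active = [name for name in _MODES if getattr(self, name) is not None]
        if len(active) > 1:
            raise ValueError(
                "Only one guided decoding mode may be set, got: " + ", ".join(active)
            )


ok = GuidedDecodingSketch(regex=r"\d{3}-\d{4}")   # a single mode is accepted
try:
    GuidedDecodingSketch(regex="yes|no", choice=["yes", "no"])
except ValueError as err:
    print(err)  # two modes are rejected
```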

backend: str | None = None
backend_was_auto: bool = False
choice: list[str] | None = None
disable_additional_properties: bool = False
disable_any_whitespace: bool = False
disable_fallback: bool = False
classmethod from_dict(data: dict[str, Any]) -> T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) -> T

Deserializes a JSON string into a PyTree object.

grammar: str | None = None
json: str | dict | None = None
json_object: bool | None = None
regex: str | None = None
replace(**kwargs)

Creates a new instance with the specified fields replaced.

structural_tag: str | None = None
to_dict() -> dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) -> str

Serializes the PyTree object to a JSON string.

whitespace_pattern: str | None = None
class easydel.inference.sampling_params.JitableSamplingParams(random_sampling: Array, temperature: Array, top_k: Array, top_p: Array, min_p: Array, repetition_penalty: Array, frequency_penalty: Array, presence_penalty: Array, max_tokens: Array, min_tokens: Array, all_stop_token_ids: Array, bad_words_token_ids: Array, bad_words_lengths: Array, allowed_token_ids: jax.Array | None = None)

Bases: object

A JAX-native, device-ready version of sampling parameters.

This class contains only JAX arrays and static information, making it suitable for passing into jit-compiled functions. All Python-specific types like strings and lists have been converted or removed.
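Jit-compiled functions need fixed shapes, so variable-length Python lists (such as the token-id sequences behind `bad_words_token_ids`, which is paired with a separate `bad_words_lengths` array) must become rectangular arrays. The sketch below shows one common way to do that: pad with a sentinel and record each row's true length. The pad value `-1` and the helper `pad_ragged` are assumptions, not EasyDeL's actual conversion code; it uses NumPy so it runs without JAX, and `jax.numpy.asarray` accepts the result.

```python
from __future__ import annotations

import numpy as np


def pad_ragged(seqs: list[list[int]], pad_value: int = -1) -> tuple[np.ndarray, np.ndarray]:
    """Pack variable-length token-id lists into a fixed-shape 2-D array.

    Returns (padded, lengths): padded has shape (len(seqs), max_len) and
    lengths[i] is the number of real (non-padding) tokens in row i.
    """
    lengths = np.array([len(s) for s in seqs], dtype=np.int32)
    max_len = int(lengths.max()) if len(seqs) else 0
    padded = np.full((len(seqs), max_len), pad_value, dtype=np.int32)
    for i, s in enumerate(seqs):
        padded[i, : len(s)] = s  # copy real tokens; the rest stays as padding
    return padded, lengths


# Hypothetical tokenized bad words of different lengths.
bad_words_ids, bad_words_lens = pad_ragged([[101, 7], [42]])
print(bad_words_ids)   # shape (2, 2), second row padded with -1
print(bad_words_lens)  # true lengths per row
```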

all_stop_token_ids: Array
allowed_token_ids: jax.Array | None = None
bad_words_lengths: Array
bad_words_token_ids: Array
frequency_penalty: Array
classmethod from_dict(data: dict[str, Any]) -> T

Deserializes a dictionary into a PyTree object.

classmethod from_host_params(params: SamplingParams) -> JitableSamplingParams

Converts the host-side SamplingParams to a JIT-compatible version.

classmethod from_json(json_str: str) -> T

Deserializes a JSON string into a PyTree object.

get_logits_processor()

Constructs a LogitsProcessorList containing the configured logits processors.

Logits processors modify the logits directly, often used for applying penalties (presence, frequency, repetition) or suppressing specific tokens.

Returns

A LogitsProcessorList containing the enabled logits processors based on the sampling parameters.
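The three penalties are commonly implemented as follows: OpenAI-style frequency and presence penalties subtract `frequency_penalty * count + presence_penalty` from the logit of every token already seen, and the CTRL-style repetition penalty (as in Hugging Face's `RepetitionPenaltyLogitsProcessor`) divides positive logits and multiplies negative ones. This plain-Python sketch assumes EasyDeL follows these conventions and applies them in this order; `apply_penalties` is a hypothetical helper, not part of the EasyDeL API.

```python
from collections import Counter


def apply_penalties(logits, seen_ids, presence_penalty=0.0,
                    frequency_penalty=0.0, repetition_penalty=1.0):
    """Return a penalized copy of `logits` (a list indexed by token id)."""
    counts = Counter(seen_ids)
    out = list(logits)
    for tok, count in counts.items():
        # Repetition penalty: make the token less likely whatever the logit's sign.
        if out[tok] > 0:
            out[tok] /= repetition_penalty
        else:
            out[tok] *= repetition_penalty
        # Frequency penalty grows with the count; presence penalty is a flat
        # cost applied once the token has appeared at all.
        out[tok] -= frequency_penalty * count + presence_penalty
    return out


penalized = apply_penalties(
    [2.0, -1.0, 0.5],   # logits for a toy 3-token vocabulary
    [0, 0, 1],          # token 0 seen twice, token 1 once, token 2 never
    presence_penalty=0.5,
    frequency_penalty=0.25,
    repetition_penalty=2.0,
)
print(penalized)
```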

get_logits_warper()

Constructs a LogitsProcessorList containing the configured logits warpers.

Logits warpers modify the probability distribution derived from logits, typically used for techniques like temperature scaling, top-k, top-p, and min-p sampling.

Returns

A LogitsProcessorList containing the enabled logits warpers based on the sampling parameters.
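The warpers named above can be sketched on a single logit vector in plain Python: temperature scaling and a softmax, then top-k, top-p (nucleus), and min-p filtering, then renormalization. The "disabled" values (`top_k=0`, `top_p=1.0`, `min_p=0.0`) are inferred from the `SamplingParams` defaults and are an assumption, as is the idea that greedy decoding (temperature 0) bypasses this path; `warp` is a hypothetical helper, not EasyDeL's implementation.

```python
import math


def warp(logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
    """Return renormalized probabilities after the common logits warpers.

    temperature must be > 0; greedy decoding is assumed to skip this step.
    """
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])  # only the k most likely tokens
    if top_p < 1.0:
        # Smallest prefix of sorted tokens whose cumulative mass reaches top_p.
        cum, nucleus = 0.0, set()
        for i in order:
            nucleus.add(i)
            cum += probs[i]
            if cum >= top_p:
                break
        keep &= nucleus
    if min_p > 0.0:
        # Drop tokens less likely than min_p times the most likely token.
        threshold = min_p * probs[order[0]]
        keep &= {i for i in order if probs[i] >= threshold}

    mass = sum(probs[i] for i in keep)
    return [probs[i] / mass if i in keep else 0.0 for i in range(len(probs))]


logits = [math.log(0.5), math.log(0.3), math.log(0.2)]
print(warp(logits, top_k=2))  # third token removed, remaining mass renormalized
```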

classmethod init_empty(batch_size: int)
insert(second_sample: JitableSamplingParams, slot: int) -> JitableSamplingParams
property logits_processor
property logits_warper
make_jitable()
max_tokens: Array
min_p: Array
min_tokens: Array
presence_penalty: Array
random_sampling: Array
repetition_penalty: Array
replace(**kwargs)

Creates a new instance with the specified fields replaced.

temperature: Array
to_dict() -> dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) -> str

Serializes the PyTree object to a JSON string.

top_k: Array
top_p: Array
view_1d() -> JitableSamplingParams
view_2d() -> JitableSamplingParams
class easydel.inference.sampling_params.RequestOutputKind(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Defines the kind of output for a request.

Attributes:

CUMULATIVE: Return the full generated text so far.
DELTA: Return only the newly generated tokens.
FINAL_ONLY: Return only the final, complete output.

CUMULATIVE = 0
DELTA = 1
FINAL_ONLY = 2
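The difference between the three kinds is easiest to see on a stream. In this sketch the generated text arrives in three hypothetical chunks; `OutputKindSketch` and `stream_outputs` are illustrative stand-ins, and the granularity at which EasyDeL actually emits updates is not specified here.

```python
from enum import Enum


class OutputKindSketch(Enum):
    # Hypothetical stand-in with the same members as RequestOutputKind.
    CUMULATIVE = 0
    DELTA = 1
    FINAL_ONLY = 2


def stream_outputs(chunks, kind):
    """Return the sequence of payloads a client would see for `kind`."""
    payloads = []
    text = ""
    for step, chunk in enumerate(chunks):
        text += chunk
        if kind is OutputKindSketch.CUMULATIVE:
            payloads.append(text)   # everything generated so far
        elif kind is OutputKindSketch.DELTA:
            payloads.append(chunk)  # only the new piece
        elif step == len(chunks) - 1:
            payloads.append(text)   # FINAL_ONLY: once, when generation ends
    return payloads


chunks = ["Hel", "lo", "!"]  # hypothetical decoding steps
for kind in OutputKindSketch:
    print(kind.name, stream_outputs(chunks, kind))
```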
class easydel.inference.sampling_params.SamplingParams(n: int = 1, best_of: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 1.0, top_p: float = 1.0, min_p: float = 0.0, top_k: int = 0, seed: int | None = None, stop: list[str] = <factory>, stop_token_ids: list[int] = <factory>, stop_pattern: str | None = None, bad_words: list[str] = <factory>, ignore_eos: bool = False, max_tokens: int | None = 16, min_tokens: int = 0, logprobs: int | None = None, prompt_logprobs: int | None = None, detokenize: bool = True, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, include_stop_str_in_output: bool = False, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, truncate_prompt_tokens: int | None = None, guided_decoding: GuidedDecodingParams | None = None, logit_bias: dict[int, float] | None = None, allowed_token_ids: list[int] | None = None, extra_args: dict[str, Any] = <factory>)

Bases: object

Sampling parameters for text generation.

Comprehensive configuration for controlling text generation behavior. Supports both OpenAI-compatible parameters and EasyDeL-specific extensions.

Attributes:

n: Number of output sequences to return
best_of: Number of candidate sequences to sample; the best n of them are returned
presence_penalty: Penalty for tokens that have already appeared (-2.0 to 2.0)
frequency_penalty: Penalty for tokens in proportion to how often they have appeared (-2.0 to 2.0)
repetition_penalty: Multiplicative penalty for repeated tokens
temperature: Controls randomness (0.0 = deterministic, higher = more random)
top_p: Nucleus sampling threshold (0.0 to 1.0)
min_p: Minimum probability threshold, relative to the most likely token
top_k: Number of highest-probability tokens to consider
seed: Random seed for reproducibility
stop: List of stop strings
stop_token_ids: List of stop token IDs
stop_pattern: Regex pattern for stopping generation
bad_words: List of strings to avoid generating
ignore_eos: Whether to ignore the end-of-sequence token
max_tokens: Maximum number of tokens to generate
min_tokens: Minimum number of tokens to generate
logprobs: Number of log probabilities to return per generated token
prompt_logprobs: Number of log probabilities to return per prompt token
detokenize: Whether to convert tokens to text
skip_special_tokens: Whether to skip special tokens in the output
spaces_between_special_tokens: Whether to add spaces between special tokens
include_stop_str_in_output: Whether to include the stop string in the output
output_kind: Type of output to return (see RequestOutputKind)
truncate_prompt_tokens: Truncate the prompt to this many tokens
guided_decoding: Parameters for constrained generation
logit_bias: Bias to apply to the logits of specific token IDs
allowed_token_ids: Whitelist of allowed token IDs
extra_args: Additional custom arguments

Example:

>>> params = SamplingParams(
...     temperature=0.7,
...     top_p=0.9,
...     max_tokens=100,
...     stop=["\n"],
... )

property all_stop_token_ids: set[int]

Returns all stop token IDs, including EOS.

allowed_token_ids: list[int] | None = None
bad_words: list[str]
property bad_words_token_ids: list[list[int]] | None

Returns the tokenized versions of bad_words.

best_of: int | None = None
clone() -> SamplingParams

Creates a deep copy of the instance.

detokenize: bool = True
extra_args: dict[str, Any]
frequency_penalty: float = 0.0
classmethod from_dict(data: dict[str, Any]) -> T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) -> T

Deserializes a JSON string into a PyTree object.

guided_decoding: GuidedDecodingParams | None = None
ignore_eos: bool = False
include_stop_str_in_output: bool = False
logit_bias: dict[int, float] | None = None
logprobs: int | None = None
make_jitable() -> JitableSamplingParams

Converts this host-side configuration into a JAX-jittable object.

This method should be called after all pre-processing (like tokenization) is complete.

max_tokens: int | None = 16
min_p: float = 0.0
min_tokens: int = 0
n: int = 1
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
presence_penalty: float = 0.0
prompt_logprobs: int | None = None
repetition_penalty: float = 1.0
replace(**kwargs)

Creates a new instance with the specified fields replaced.

property sampling_type: SamplingType

Determines the sampling type based on parameters.
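A plausible sketch of this rule, assuming (as vLLM does) that a temperature at or near zero selects greedy decoding and any other temperature selects random sampling. The epsilon `1e-5` and the helper `infer_sampling_type` are assumptions, not taken from EasyDeL's source.

```python
from enum import IntEnum


class SamplingTypeSketch(IntEnum):
    # Mirrors the documented SamplingType values.
    GREEDY = 0
    RANDOM = 1


SAMPLING_EPS = 1e-5  # assumed threshold below which temperature counts as zero


def infer_sampling_type(temperature: float) -> SamplingTypeSketch:
    # A temperature near zero collapses the distribution onto the argmax
    # token, so sampling is equivalent to greedy selection.
    if temperature < SAMPLING_EPS:
        return SamplingTypeSketch.GREEDY
    return SamplingTypeSketch.RANDOM


print(infer_sampling_type(0.0).name)  # greedy
print(infer_sampling_type(0.7).name)  # random
```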

seed: int | None = None
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
stop: list[str]
stop_pattern: str | None = None
stop_token_ids: list[int]
temperature: float = 1.0
to_dict() -> dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) -> str

Serializes the PyTree object to a JSON string.

top_k: int = 0
top_p: float = 1.0
truncate_prompt_tokens: int | None = None
update_with_generation_config(generation_config: dict[str, Any], model_eos_token_id: int | None = None) -> SamplingParams

Creates a new SamplingParams instance updated with a model’s generation_config. Returns a new instance to maintain immutability.

update_with_tokenizer(tokenizer: AutoTokenizer) -> SamplingParams

Creates a new SamplingParams instance with bad_words encoded into token IDs. Returns a new instance to maintain immutability.
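The immutable update pattern can be sketched with a frozen dataclass and a toy encoder. `ParamsSketch`, `toy_encode`, and the `encode` callable are hypothetical; for simplicity the sketch stores `bad_words_token_ids` as a field, whereas the real class exposes it as a property. With a Hugging Face tokenizer one would typically pass something like `lambda w: tokenizer.encode(w, add_special_tokens=False)` so that special tokens do not leak into the banned sequences (an assumption about what EasyDeL does internally).

```python
from __future__ import annotations

import dataclasses
from typing import Callable


@dataclasses.dataclass(frozen=True)
class ParamsSketch:
    # Hypothetical subset of SamplingParams relevant to bad words.
    bad_words: list[str] = dataclasses.field(default_factory=list)
    bad_words_token_ids: list[list[int]] | None = None

    def update_with_tokenizer(self, encode: Callable[[str], list[int]]) -> ParamsSketch:
        # Encode each banned string into its token-id sequence and return a
        # new instance; the original object is left unchanged.
        ids = [encode(word) for word in self.bad_words]
        return dataclasses.replace(self, bad_words_token_ids=ids)


def toy_encode(text: str) -> list[int]:
    # Stand-in tokenizer: one id per character.
    return [ord(c) for c in text]


params = ParamsSketch(bad_words=["hi", "no"])
updated = params.update_with_tokenizer(toy_encode)
print(updated.bad_words_token_ids)
```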

class easydel.inference.sampling_params.SamplingType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: IntEnum

Defines the sampling strategy.

Attributes:

GREEDY: Deterministic selection of the highest-probability token.
RANDOM: Probabilistic sampling from the distribution.

GREEDY = 0
RANDOM = 1