easydel.inference.sampling_params#
Sampling parameters and configurations for text generation.
This module defines the parameters that control text generation behavior, including temperature, top-k, top-p, penalties, and other sampling strategies. It provides both host-side Python configurations and JAX-compatible versions for efficient on-device processing.
- Classes:
  - SamplingType: Enum for sampling strategies (greedy vs random)
  - RequestOutputKind: Enum for output formats (cumulative, delta, final)
  - GuidedDecodingParams: Parameters for constrained generation
  - JitableSamplingParams: JAX-compatible sampling parameters
  - SamplingParams: High-level sampling configuration
  - BeamSearchParams: Parameters for beam search decoding
Example
>>> from easydel.inference import SamplingParams
>>> params = SamplingParams(
...     temperature=0.8,
...     top_p=0.95,
...     max_tokens=100,
...     stop=["\n", "END"],
... )
>>> # Convert to JAX-compatible format
>>> jit_params = params.make_jitable()
- class easydel.inference.sampling_params.BeamSearchParams(beam_width: int, max_tokens: int, ignore_eos: bool = False, temperature: float = 0.0, length_penalty: float = 1.0, include_stop_str_in_output: bool = False)[source]#
Bases: object
Beam search parameters for text generation.
Configuration for beam search decoding, which maintains multiple candidate sequences and selects the best ones based on cumulative probability. This class is immutable (frozen=True) for JAX compatibility.
- beam_width#
Number of beams to maintain
- Type
int
- max_tokens#
Maximum tokens to generate
- Type
int
- ignore_eos#
Whether to ignore end-of-sequence token
- Type
bool
- temperature#
Temperature for beam search (0.0 = greedy)
- Type
float
- length_penalty#
Penalty factor for sequence length
- Type
float
- include_stop_str_in_output#
Include stop string in output
- Type
bool
Note
Beam search is typically used when deterministic, high-quality output is desired, at the cost of increased computation.
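As a rough illustration of what beam_width and length_penalty control, the sketch below runs a toy beam search in plain Python. The transition table, the `beam_search` helper, and the length-normalized score (cumulative log-probability divided by length raised to `length_penalty`) are assumptions made for this example; EasyDeL's actual scoring is not described in this section.

```python
import math

# Hypothetical toy "model": next-token probabilities depend only on the last token.
# Token 0 plays the role of the end-of-sequence (EOS) token.
EOS = 0
TRANSITIONS = {
    None: {1: math.log(0.6), 2: math.log(0.4)},
    1: {EOS: math.log(0.55), 2: math.log(0.45)},
    2: {EOS: math.log(0.9), 1: math.log(0.1)},
}


def beam_search(beam_width, max_tokens, length_penalty=1.0):
    """Return the best (tokens, cumulative_logprob) pair found by a simple beam search."""
    beams = [((), 0.0)]  # each beam: (token tuple, cumulative log-probability)
    finished = []
    for _ in range(max_tokens):
        candidates = []
        for tokens, logprob in beams:
            last = tokens[-1] if tokens else None
            for tok, lp in TRANSITIONS[last].items():
                candidates.append((tokens + (tok,), logprob + lp))
        # Keep only the `beam_width` most probable continuations.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, logprob in candidates[:beam_width]:
            target = finished if tokens[-1] == EOS else beams
            target.append((tokens, logprob))
        if not beams:
            break
    finished.extend(beams)  # sequences cut off by max_tokens

    def score(item):
        # Assumed convention: normalize the log-probability by length ** length_penalty.
        tokens, logprob = item
        return logprob / (len(tokens) ** length_penalty)

    return max(finished, key=score)


narrow_tokens, _ = beam_search(beam_width=1, max_tokens=3)
wide_tokens, _ = beam_search(beam_width=2, max_tokens=3)
```

With a single beam the search commits to the locally best first token and ends on `(1, 0)`; with two beams it keeps the second-best start alive and finds the more probable sequence `(2, 0)`, which is the extra computation the note refers to.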
- beam_width: int#
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- ignore_eos: bool = False#
- include_stop_str_in_output: bool = False#
- length_penalty: float = 1.0#
- max_tokens: int#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- temperature: float = 0.0#
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
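The replace, to_dict, from_dict, to_json, and from_json members follow a common frozen-dataclass pattern. The sketch below reproduces that pattern with a hypothetical `ToyBeamParams` class built on the standard library only; it mirrors the documented field names but is not EasyDeL's implementation.

```python
import dataclasses
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBeamParams:
    """Hypothetical stand-in that mirrors the BeamSearchParams fields."""

    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False

    def replace(self, **kwargs):
        # Frozen instances cannot be mutated, so build a modified copy instead.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self):
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


base = ToyBeamParams(beam_width=4, max_tokens=64)
wider = base.replace(beam_width=8)  # `base` itself is left unchanged
restored = ToyBeamParams.from_json(base.to_json())
```

Because the instance is frozen, `replace` is the only way to vary a field, and the original object stays safe to share between requests.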
- class easydel.inference.sampling_params.GuidedDecodingParams(json: str | dict | None = None, regex: str | None = None, choice: list[str] | None = None, grammar: str | None = None, json_object: bool | None = None, backend: str | None = None, backend_was_auto: bool = False, disable_fallback: bool = False, disable_any_whitespace: bool = False, disable_additional_properties: bool = False, whitespace_pattern: str | None = None, structural_tag: str | None = None)[source]#
Bases: object
Parameters for guided decoding.
Enables constrained generation to match specific formats or patterns. Only one guided decoding mode can be active at a time.
- json#
JSON schema or object for structured output
- Type
str | dict | None
- regex#
Regular expression pattern to match
- Type
str | None
- choice#
List of allowed string choices
- Type
list[str] | None
- grammar#
Context-free grammar specification
- Type
str | None
- json_object#
Force output to be valid JSON object
- Type
bool | None
- backend#
Decoding backend to use
- Type
str | None
- backend_was_auto#
Whether backend was auto-selected
- Type
bool
- disable_fallback#
Disable fallback to unconstrained generation
- Type
bool
- disable_any_whitespace#
Disable whitespace in structured output
- Type
bool
- disable_additional_properties#
Restrict JSON to defined properties only
- Type
bool
- whitespace_pattern#
Custom whitespace pattern
- Type
str | None
- structural_tag#
Tag for structural elements
- Type
str | None
- Raises
ValueError – If multiple guided decoding modes are specified
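The one-mode-at-a-time rule can be pictured as a check that runs when the object is constructed. The sketch below uses a hypothetical `ToyGuidedParams` dataclass and assumes a mode counts as active when its field is not None; the real class may use a different test.

```python
from dataclasses import dataclass
from typing import Optional, Union

# Names of the mutually exclusive modes, taken from the attribute list above.
MODES = ("json", "regex", "choice", "grammar", "json_object")


@dataclass
class ToyGuidedParams:
    """Hypothetical stand-in for GuidedDecodingParams (mode fields only)."""

    json: Optional[Union[str, dict]] = None
    regex: Optional[str] = None
    choice: Optional[list] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None

    def __post_init__(self):
        # Assumption: a mode is "active" when its field is not None.
        active = [name for name in MODES if getattr(self, name) is not None]
        if len(active) > 1:
            raise ValueError(
                f"Only one guided decoding mode may be set, got: {', '.join(active)}"
            )


digits_only = ToyGuidedParams(regex=r"\d+")  # a single mode is accepted

try:
    ToyGuidedParams(regex=r"\d+", choice=["yes", "no"])
    conflict_error = None
except ValueError as exc:
    conflict_error = str(exc)
```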
- backend_was_auto: bool = False#
- disable_additional_properties: bool = False#
- disable_any_whitespace: bool = False#
- disable_fallback: bool = False#
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.inference.sampling_params.JitableSamplingParams(random_sampling: Array, temperature: Array, top_k: Array, top_p: Array, min_p: Array, repetition_penalty: Array, frequency_penalty: Array, presence_penalty: Array, max_tokens: Array, min_tokens: Array, all_stop_token_ids: Array, bad_words_token_ids: Array, bad_words_lengths: Array, allowed_token_ids: jax.Array | None = None)[source]#
Bases: object
A JAX-native, device-ready version of sampling parameters.
This class contains only JAX arrays and static information, making it suitable for passing into jit-compiled functions. All Python-specific types like strings and lists have been converted or removed.
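Jit-compiled functions need inputs with fixed shapes, which is one reason variable-length Python lists have to be converted. The sketch below shows one plausible way to do that with plain lists: token IDs are right-padded, and bad-word sequences are packed into a fixed grid with a separate lengths vector, echoing the bad_words_token_ids and bad_words_lengths fields. The helper names and the padding value of -1 are assumptions; in practice the results would then be turned into arrays, for example with `jnp.asarray`.

```python
PAD_ID = -1  # assumed padding value; the real sentinel is not documented here


def pad_row(token_ids, length, pad_id=PAD_ID):
    """Right-pad a list of token IDs to exactly `length` entries."""
    if len(token_ids) > length:
        raise ValueError(f"{len(token_ids)} token IDs do not fit in {length} slots")
    return list(token_ids) + [pad_id] * (length - len(token_ids))


def pack_bad_words(bad_words, max_words, max_len, pad_id=PAD_ID):
    """Pack variable-length token sequences into a fixed grid plus their true lengths."""
    if len(bad_words) > max_words:
        raise ValueError("too many bad-word sequences")
    rows = [pad_row(seq, max_len, pad_id) for seq in bad_words]
    lengths = [len(seq) for seq in bad_words]
    # Fill unused rows so the grid always has shape (max_words, max_len).
    for _ in range(max_words - len(bad_words)):
        rows.append([pad_id] * max_len)
        lengths.append(0)
    return rows, lengths


stop_ids = pad_row([2, 50256], length=4)
grid, lengths = pack_bad_words([[17, 42], [7]], max_words=3, max_len=2)
```

Keeping the true lengths alongside the padded grid lets on-device code ignore padding without inspecting Python objects.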
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_host_params(params: SamplingParams) JitableSamplingParams[source]#
Converts the host-side SamplingParams to a JIT-compatible version.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- get_logits_processor()[source]#
Constructs a LogitsProcessorList containing the configured logits processors.
Logits processors modify the logits directly, often used for applying penalties (presence, frequency, repetition) or suppressing specific tokens.
- Returns
A LogitsProcessorList containing the enabled logits processors based on the sampling parameters.
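To make the penalty step concrete, here is a minimal sketch of how the three penalties are commonly applied to a logits vector. It uses the widely used formulations (a multiplicative repetition penalty, plus subtractive frequency and presence penalties in the OpenAI style) and a hypothetical `apply_penalties` helper; the processors returned by get_logits_processor may differ in detail, for example in whether prompt tokens are counted.

```python
from collections import Counter


def apply_penalties(
    logits,
    generated_ids,
    repetition_penalty=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """Return a penalized copy of `logits` given the tokens generated so far."""
    counts = Counter(generated_ids)
    out = list(logits)
    for tok, count in counts.items():
        value = out[tok]
        # Repetition penalty: shrink positive logits, push negative ones further down.
        value = value / repetition_penalty if value > 0 else value * repetition_penalty
        # Frequency penalty grows with how often the token has appeared.
        value -= frequency_penalty * count
        # Presence penalty is a flat cost for having appeared at all.
        value -= presence_penalty
        out[tok] = value
    return out


penalized = apply_penalties(
    logits=[2.0, 1.0, -1.0],
    generated_ids=[0, 0, 2],
    repetition_penalty=2.0,
    frequency_penalty=0.5,
    presence_penalty=0.1,
)
```

Token 1 never appeared, so its logit is untouched, while tokens 0 and 2 are both pushed down.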
- get_logits_warper()[source]#
Constructs a LogitsProcessorList containing the configured logits warpers.
Logits warpers modify the probability distribution derived from logits, typically used for techniques like temperature scaling, top-k, top-p, and min-p sampling.
- Returns
A LogitsProcessorList containing the enabled logits warpers based on the sampling parameters.
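The warping techniques named above can be sketched in a few lines of plain Python. The `warp` helper below is hypothetical and follows the usual definitions: temperature divides the logits, top-k keeps the k most probable tokens, top-p keeps the smallest prefix whose cumulative probability reaches the threshold, and min-p drops tokens below a fraction of the top token's probability. The order in which EasyDeL applies these steps is not stated in this section.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def warp(logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
    """Return a renormalized distribution after temperature, top-k, top-p and min-p."""
    if temperature <= 0:
        raise ValueError("temperature must be positive in this sketch")
    probs = softmax([x / temperature for x in logits])
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])
    if top_p < 1.0:
        cumulative, nucleus = 0.0, set()
        for i in order:
            nucleus.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        keep &= nucleus
    if min_p > 0.0:
        threshold = min_p * probs[order[0]]
        keep &= {i for i in order if probs[i] >= threshold}
    kept_total = sum(probs[i] for i in keep)
    return [probs[i] / kept_total if i in keep else 0.0 for i in range(len(probs))]


logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
top2 = warp(logits, top_k=2)
nucleus = warp(logits, top_p=0.7)
floor = warp(logits, min_p=0.2)
```

In this example `top_k=2` and `top_p=0.7` both keep only the first two tokens, while `min_p=0.2` removes just the least likely one.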
- insert(second_sample: JitableSamplingParams, slot: int) JitableSamplingParams[source]#
- property logits_processor#
- property logits_warper#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- view_1d() JitableSamplingParams[source]#
- view_2d() JitableSamplingParams[source]#
- class easydel.inference.sampling_params.RequestOutputKind(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases: Enum
Defines the kind of output for a request.
- CUMULATIVE#
Return the full generated text so far
- DELTA#
Return only newly generated tokens
- FINAL_ONLY#
Return only the final complete output
- CUMULATIVE = 0#
- DELTA = 1#
- FINAL_ONLY = 2#
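The difference between the three output kinds is easiest to see on a stream of text pieces. The sketch below uses a hypothetical `ToyOutputKind` enum with the same member names and values, and a `stream_outputs` helper that lists what a client would receive in each mode; it illustrates the descriptions above and is not the engine's streaming code.

```python
from enum import Enum


class ToyOutputKind(Enum):
    """Hypothetical mirror of RequestOutputKind."""

    CUMULATIVE = 0
    DELTA = 1
    FINAL_ONLY = 2


def stream_outputs(pieces, kind):
    """Return the payloads a client would see as generated text pieces arrive."""
    outputs = []
    text = ""
    for index, piece in enumerate(pieces):
        text += piece
        is_last = index == len(pieces) - 1
        if kind is ToyOutputKind.CUMULATIVE:
            outputs.append(text)  # everything generated so far
        elif kind is ToyOutputKind.DELTA:
            outputs.append(piece)  # only what is new in this step
        elif is_last:
            outputs.append(text)  # FINAL_ONLY: a single payload at the end
    return outputs


pieces = ["Hel", "lo", " world"]
cumulative = stream_outputs(pieces, ToyOutputKind.CUMULATIVE)
delta = stream_outputs(pieces, ToyOutputKind.DELTA)
final_only = stream_outputs(pieces, ToyOutputKind.FINAL_ONLY)
```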
- class easydel.inference.sampling_params.SamplingParams(n: int = 1, best_of: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 1.0, top_p: float = 1.0, min_p: float = 0.0, top_k: int = 0, seed: int | None = None, stop: list[str] = <factory>, stop_token_ids: list[int] = <factory>, stop_pattern: str | None = None, bad_words: list[str] = <factory>, ignore_eos: bool = False, max_tokens: int | None = 16, min_tokens: int = 0, logprobs: int | None = None, prompt_logprobs: int | None = None, detokenize: bool = True, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, include_stop_str_in_output: bool = False, output_kind: easydel.inference.sampling_params.RequestOutputKind = RequestOutputKind.CUMULATIVE, truncate_prompt_tokens: typing.Optional[int] = None, guided_decoding: easydel.inference.sampling_params.GuidedDecodingParams | None = None, logit_bias: dict[int, float] | None = None, allowed_token_ids: list[int] | None = None, extra_args: dict[str, typing.Any] = <factory>)[source]#
Bases: object
Sampling parameters for text generation.
Comprehensive configuration for controlling text generation behavior. Supports both OpenAI-compatible parameters and EasyDeL-specific extensions.
- Attributes:
  - n: Number of sequences to generate
  - best_of: Sample n sequences and return the best one
  - presence_penalty: Penalty for tokens based on presence (-2.0 to 2.0)
  - frequency_penalty: Penalty for tokens based on frequency (-2.0 to 2.0)
  - repetition_penalty: Multiplicative penalty for repeated tokens
  - temperature: Controls randomness (0.0 = deterministic, higher = more random)
  - top_p: Nucleus sampling threshold (0.0 to 1.0)
  - min_p: Minimum probability threshold relative to top token
  - top_k: Number of highest probability tokens to consider
  - seed: Random seed for reproducibility
  - stop: List of stop strings
  - stop_token_ids: List of stop token IDs
  - stop_pattern: Regex pattern string for stopping generation
  - bad_words: List of strings to avoid generating
  - ignore_eos: Whether to ignore end-of-sequence token
  - max_tokens: Maximum number of tokens to generate
  - min_tokens: Minimum number of tokens to generate
  - logprobs: Number of log probabilities to return
  - prompt_logprobs: Number of prompt log probabilities to return
  - detokenize: Whether to convert tokens to text
  - skip_special_tokens: Whether to skip special tokens in output
  - spaces_between_special_tokens: Add spaces between special tokens
  - include_stop_str_in_output: Include stop string in output
  - output_kind: Type of output to return
  - truncate_prompt_tokens: Truncate prompt to this many tokens
  - guided_decoding: Parameters for constrained generation
  - logit_bias: Bias to apply to specific token logits
  - allowed_token_ids: Whitelist of allowed token IDs
  - extra_args: Additional custom arguments
- Example:
  >>> params = SamplingParams(
  ...     temperature=0.7,
  ...     top_p=0.9,
  ...     max_tokens=100,
  ...     stop=["\n"],
  ... )
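The interaction between stop and include_stop_str_in_output can be shown with a small text-level helper. The `truncate_at_stop` function below is a hypothetical sketch that cuts at the earliest matching stop string and optionally keeps it; how the engine detects stop strings during streaming is not covered in this section.

```python
def truncate_at_stop(text, stop, include_stop_str_in_output=False):
    """Cut `text` at the earliest occurrence of any stop string.

    Returns (possibly truncated text, whether a stop string was found).
    """
    best = None  # (index, stop string) of the earliest match so far
    for stop_str in stop:
        if not stop_str:
            continue  # an empty stop string would match everywhere
        idx = text.find(stop_str)
        if idx != -1 and (best is None or idx < best[0]):
            best = (idx, stop_str)
    if best is None:
        return text, False
    idx, stop_str = best
    end = idx + len(stop_str) if include_stop_str_in_output else idx
    return text[:end], True


generated = "line one\nline two END"
trimmed, hit = truncate_at_stop(generated, ["\n", "END"])
kept, _ = truncate_at_stop(generated, ["\n", "END"], include_stop_str_in_output=True)
untouched, missed = truncate_at_stop(generated, ["STOP"])
```

Here the newline occurs before "END", so it wins: `trimmed` is `"line one"` and `kept` is `"line one\n"`.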
- property all_stop_token_ids: set[int]#
Returns all stop token IDs, including EOS.
- bad_words: list[str]#
- clone() SamplingParams[source]#
Creates a deep copy of the instance.
- detokenize: bool = True#
- extra_args: dict[str, Any]#
- frequency_penalty: float = 0.0#
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- guided_decoding: easydel.inference.sampling_params.GuidedDecodingParams | None = None#
- ignore_eos: bool = False#
- include_stop_str_in_output: bool = False#
- make_jitable() JitableSamplingParams[source]#
Converts this host-side configuration into a JAX-jittable object.
This method should be called after all pre-processing (like tokenization) is complete.
- min_p: float = 0.0#
- min_tokens: int = 0#
- n: int = 1#
- output_kind: RequestOutputKind = 0#
- presence_penalty: float = 0.0#
- repetition_penalty: float = 1.0#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- property sampling_type: SamplingType#
Determines the sampling type based on parameters.
- skip_special_tokens: bool = True#
- spaces_between_special_tokens: bool = True#
- stop: list[str]#
- stop_token_ids: list[int]#
- temperature: float = 1.0#
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- top_k: int = 0#
- top_p: float = 1.0#
- truncate_prompt_tokens: Optional[int] = None#
- update_with_generation_config(generation_config: dict[str, Any], model_eos_token_id: int | None = None) SamplingParams[source]#
Creates a new SamplingParams instance updated with a model’s generation_config. Returns a new instance to maintain immutability.
- update_with_tokenizer(tokenizer: AutoTokenizer) SamplingParams[source]#
Creates a new SamplingParams instance with bad_words encoded into token IDs. Returns a new instance to maintain immutability.
- class easydel.inference.sampling_params.SamplingType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases: IntEnum
Defines the sampling strategy.
- GREEDY#
Deterministic selection of highest probability token
- RANDOM#
Probabilistic sampling from the distribution
- GREEDY = 0#
- RANDOM = 1#
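A small sketch may help show how the two strategies differ at selection time. It assumes that a temperature at or near zero maps to GREEDY and anything larger to RANDOM, which matches the "0.0 = deterministic" description of temperature; the threshold `_EPS`, the `ToySamplingType` enum, and both helper functions are hypothetical, and the real sampling_type property may consider other parameters as well.

```python
import random
from enum import IntEnum


class ToySamplingType(IntEnum):
    """Hypothetical mirror of SamplingType."""

    GREEDY = 0
    RANDOM = 1


_EPS = 1e-5  # assumed threshold below which temperature is treated as zero


def sampling_type_for(temperature):
    """Pick a strategy from the temperature alone (an assumption of this sketch)."""
    return ToySamplingType.GREEDY if temperature < _EPS else ToySamplingType.RANDOM


def pick_token(probs, kind, rng):
    """Select a token index from a probability list under the given strategy."""
    if kind is ToySamplingType.GREEDY:
        # Deterministic: always the highest-probability token.
        return max(range(len(probs)), key=probs.__getitem__)
    # Probabilistic: draw one index in proportion to its probability.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = [0.1, 0.7, 0.2]
rng = random.Random(0)
greedy_choice = pick_token(probs, sampling_type_for(0.0), rng)
random_choice = pick_token(probs, sampling_type_for(0.8), rng)
```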