# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sampling parameters and configurations for text generation.
This module defines the parameters that control text generation behavior,
including temperature, top-k, top-p, penalties, and other sampling strategies.
It provides both host-side Python configurations and JAX-compatible versions
for efficient on-device processing.
Classes:
SamplingType: Enum for sampling strategies (greedy vs random)
RequestOutputKind: Enum for output formats (cumulative, delta, final)
GuidedDecodingParams: Parameters for constrained generation
JitableSamplingParams: JAX-compatible sampling parameters
SamplingParams: High-level sampling configuration
BeamSearchParams: Parameters for beam search decoding
Example:
>>> from easydel.inference import SamplingParams
>>> params = SamplingParams(
... temperature=0.8,
... top_p=0.95,
... max_tokens=100,
... stop=["\n", "END"]
... )
>>> # Convert to JAX-compatible format
>>> jit_params = params.make_jitable()
"""
from __future__ import annotations
import copy
import dataclasses
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
from typing import Annotated, Any
import jax
from eformer.escale import with_sharding_constraint
from eformer.loggings import get_logger
from eformer.pytree import auto_pytree
from eformer.pytree import field as pytree_field
from jax import numpy as jnp
from jax.sharding import PartitionSpec
from transformers import AutoTokenizer
from .logits_process import (
FrequencyPenaltyLogitsProcessor,
LogitsProcessorList,
MinPLogitsWarper,
PresencePenaltyLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
logger = get_logger(__name__)
[docs]class SamplingType(IntEnum):
"""Defines the sampling strategy.
Attributes:
GREEDY: Deterministic selection of highest probability token
RANDOM: Probabilistic sampling from the distribution
"""
GREEDY = 0
RANDOM = 1
[docs]class RequestOutputKind(Enum):
"""Defines the kind of output for a request.
Attributes:
CUMULATIVE: Return the full generated text so far
DELTA: Return only newly generated tokens
FINAL_ONLY: Return only the final complete output
"""
CUMULATIVE = 0
DELTA = 1
FINAL_ONLY = 2
[docs]@auto_pytree
class GuidedDecodingParams:
"""Parameters for guided decoding.
Enables constrained generation to match specific formats or patterns.
Only one guided decoding mode can be active at a time.
Attributes:
json: JSON schema or object for structured output
regex: Regular expression pattern to match
choice: List of allowed string choices
grammar: Context-free grammar specification
json_object: Force output to be valid JSON object
backend: Decoding backend to use
backend_was_auto: Whether backend was auto-selected
disable_fallback: Disable fallback to unconstrained generation
disable_any_whitespace: Disable whitespace in structured output
disable_additional_properties: Restrict JSON to defined properties only
whitespace_pattern: Custom whitespace pattern
structural_tag: Tag for structural elements
Raises:
ValueError: If multiple guided decoding modes are specified
"""
json: str | dict | None = pytree_field(pytree_node=False, default=None)
regex: str | None = pytree_field(pytree_node=False, default=None)
choice: list[str] | None = pytree_field(pytree_node=False, default=None)
grammar: str | None = pytree_field(pytree_node=False, default=None)
json_object: bool | None = pytree_field(pytree_node=False, default=None)
backend: str | None = pytree_field(pytree_node=False, default=None)
backend_was_auto: bool = pytree_field(pytree_node=False, default=False)
disable_fallback: bool = pytree_field(pytree_node=False, default=False)
disable_any_whitespace: bool = pytree_field(pytree_node=False, default=False)
disable_additional_properties: bool = pytree_field(pytree_node=False, default=False)
whitespace_pattern: str | None = pytree_field(pytree_node=False, default=None)
structural_tag: str | None = pytree_field(pytree_node=False, default=None)
def __post_init__(self):
"""Validates that only one guided decoding mode is specified.
Raises:
ValueError: If more than one guided decoding mode is active
"""
guide_count = sum(
(
self.json is not None,
self.regex is not None,
self.choice is not None,
self.grammar is not None,
self.json_object is not None,
)
)
if guide_count > 1:
raise ValueError(
f"Only one guided decoding mode can be used, but multiple were specified: {dataclasses.asdict(self)}"
)
[docs]@auto_pytree(frozen=True)
class JitableSamplingParams:
"""
A JAX-native, device-ready version of sampling parameters.
This class contains only JAX arrays and static information, making it
suitable for passing into jit-compiled functions. All Python-specific
types like strings and lists have been converted or removed.
"""
random_sampling: jax.Array # [1] bool
temperature: jax.Array
top_k: jax.Array
top_p: jax.Array
min_p: jax.Array
repetition_penalty: jax.Array
frequency_penalty: jax.Array
presence_penalty: jax.Array
max_tokens: jax.Array
min_tokens: jax.Array
all_stop_token_ids: jax.Array
bad_words_token_ids: jax.Array # Padded to be rectangular
bad_words_lengths: jax.Array # Stores the true length of each bad word sequence
allowed_token_ids: jax.Array | None = None
[docs] def insert(self, second_sample: JitableSamplingParams, slot: int) -> JitableSamplingParams:
self = self.view_1d()
def update_idx1d(x, y):
sharding = getattr(x, "sharding", PartitionSpec())
return with_sharding_constraint(jax.lax.dynamic_update_slice(x, y, (slot,)), sharding)
return JitableSamplingParams(
random_sampling=update_idx1d(self.random_sampling, second_sample.random_sampling),
temperature=update_idx1d(self.temperature, second_sample.temperature),
top_k=update_idx1d(self.top_k, second_sample.top_k),
top_p=update_idx1d(self.top_p, second_sample.top_p),
min_p=update_idx1d(self.min_p, second_sample.min_p),
repetition_penalty=update_idx1d(self.repetition_penalty, second_sample.repetition_penalty),
frequency_penalty=update_idx1d(self.frequency_penalty, second_sample.frequency_penalty),
presence_penalty=update_idx1d(self.presence_penalty, second_sample.presence_penalty),
max_tokens=update_idx1d(self.max_tokens, second_sample.max_tokens),
min_tokens=update_idx1d(self.min_tokens, second_sample.min_tokens),
all_stop_token_ids=self.all_stop_token_ids,
bad_words_token_ids=self.bad_words_token_ids,
bad_words_lengths=self.bad_words_lengths,
allowed_token_ids=self.allowed_token_ids,
)
[docs] def view_1d(self) -> JitableSamplingParams:
return JitableSamplingParams(
random_sampling=self.random_sampling.reshape(-1),
temperature=self.temperature.reshape(-1),
top_k=self.top_k.reshape(-1),
top_p=self.top_p.reshape(-1),
min_p=self.min_p.reshape(-1),
repetition_penalty=self.repetition_penalty.reshape(-1),
frequency_penalty=self.frequency_penalty.reshape(-1),
presence_penalty=self.presence_penalty.reshape(-1),
max_tokens=self.max_tokens.reshape(-1),
min_tokens=self.min_tokens.reshape(-1),
all_stop_token_ids=self.all_stop_token_ids.reshape(-1),
bad_words_token_ids=self.bad_words_token_ids.reshape(-1),
bad_words_lengths=self.bad_words_lengths.reshape(-1),
allowed_token_ids=self.allowed_token_ids.reshape(-1) if self.allowed_token_ids is not None else None,
)
[docs] def view_2d(self) -> JitableSamplingParams:
return JitableSamplingParams(
random_sampling=self.random_sampling.reshape(-1, 1),
temperature=self.temperature.reshape(-1, 1),
top_k=self.top_k.reshape(-1, 1),
top_p=self.top_p.reshape(-1, 1),
min_p=self.min_p.reshape(-1, 1),
repetition_penalty=self.repetition_penalty.reshape(-1, 1),
frequency_penalty=self.frequency_penalty.reshape(-1, 1),
presence_penalty=self.presence_penalty.reshape(-1, 1),
max_tokens=self.max_tokens.reshape(-1, 1),
min_tokens=self.min_tokens.reshape(-1, 1),
all_stop_token_ids=self.all_stop_token_ids.reshape(-1, 1),
bad_words_token_ids=self.bad_words_token_ids.reshape(-1, 1),
bad_words_lengths=self.bad_words_lengths.reshape(-1, 1),
allowed_token_ids=self.allowed_token_ids.reshape(-1, 1) if self.allowed_token_ids is not None else None,
)
[docs] @classmethod
def init_empty(cls, batch_size: int):
return cls(
random_sampling=jnp.zeros([batch_size], dtype="b1"),
temperature=jnp.zeros([batch_size], dtype="f4"),
top_k=jnp.zeros([batch_size], dtype="i4"),
top_p=jnp.zeros([batch_size], dtype="f4"),
min_p=jnp.zeros([batch_size], dtype="f4"),
repetition_penalty=jnp.zeros([batch_size], dtype="f4"),
frequency_penalty=jnp.zeros([batch_size], dtype="f4"),
presence_penalty=jnp.zeros([batch_size], dtype="f4"),
max_tokens=jnp.zeros([batch_size], dtype="i4"),
min_tokens=jnp.zeros([batch_size], dtype="i4"),
all_stop_token_ids=jnp.array([[]], dtype=jnp.int32),
bad_words_token_ids=jnp.array([[]], dtype=jnp.int32),
bad_words_lengths=jnp.array([[]], dtype=jnp.int32),
)
[docs] @classmethod
def from_host_params(cls, params: SamplingParams) -> JitableSamplingParams:
"""Converts the host-side SamplingParams to a JIT-compatible version."""
if params._bad_words_token_ids:
max_len = max(len(ids) for ids in params._bad_words_token_ids)
lengths = jnp.array([len(ids) for ids in params._bad_words_token_ids], dtype=jnp.int32)
padded_ids = jnp.array(
[ids + [-100] * (max_len - len(ids)) for ids in params._bad_words_token_ids],
dtype=jnp.int32,
)
else:
lengths = jnp.array([], dtype=jnp.int32)
padded_ids = jnp.array([[]], dtype=jnp.int32)
return cls(
random_sampling=jnp.asarray(params.sampling_type.value == SamplingType.RANDOM.value, dtype=jnp.bool),
temperature=jnp.asarray(params.temperature, dtype=jnp.float32),
top_k=jnp.asarray(params.top_k, dtype=jnp.int32),
top_p=jnp.asarray(params.top_p, dtype=jnp.float32),
min_p=jnp.asarray(params.min_p, dtype=jnp.float32),
repetition_penalty=jnp.asarray(params.repetition_penalty, dtype=jnp.float32),
frequency_penalty=jnp.asarray(params.frequency_penalty, dtype=jnp.float32),
presence_penalty=jnp.asarray(params.presence_penalty, dtype=jnp.float32),
max_tokens=jnp.asarray(params.max_tokens if params.max_tokens is not None else -1, dtype=jnp.int32),
min_tokens=jnp.asarray(params.min_tokens, dtype=jnp.int32),
all_stop_token_ids=jnp.asarray(list(params.all_stop_token_ids), dtype=jnp.int32),
bad_words_token_ids=padded_ids,
bad_words_lengths=lengths,
allowed_token_ids=jnp.asarray(params.allowed_token_ids, dtype=jnp.int32)
if params.allowed_token_ids
else None,
)
[docs] def get_logits_warper(self):
"""
Constructs a `LogitsProcessorList` containing the configured logits warpers.
Logits warpers modify the probability distribution derived from logits, typically
used for techniques like temperature scaling, top-k, top-p, and min-p sampling.
Returns:
A `LogitsProcessorList` containing the enabled logits warpers based on the
sampling parameters.
"""
warpers = LogitsProcessorList()
warpers.append(TemperatureLogitsWarper(temperature=self.temperature))
warpers.append(TopKLogitsWarper(top_k=self.top_k, min_tokens_to_keep=1))
warpers.append(TopPLogitsWarper(top_p=self.top_p, min_tokens_to_keep=1))
warpers.append(MinPLogitsWarper(min_p=self.min_p, min_tokens_to_keep=1))
return warpers
[docs] def get_logits_processor(self):
"""
Constructs a `LogitsProcessorList` containing the configured logits processors.
Logits processors modify the logits directly, often used for applying penalties
(presence, frequency, repetition) or suppressing specific tokens.
Returns:
A `LogitsProcessorList` containing the enabled logits processors based on the
sampling parameters.
"""
processors = LogitsProcessorList()
processors.append(PresencePenaltyLogitsProcessor(self.presence_penalty))
processors.append(FrequencyPenaltyLogitsProcessor(self.frequency_penalty))
processors.append(RepetitionPenaltyLogitsProcessor(self.repetition_penalty))
return processors
@cached_property
def logits_processor(self):
return self.get_logits_processor()
@cached_property
def logits_warper(self):
return self.get_logits_warper()
[docs] def make_jitable(self):
return self
[docs]@auto_pytree
class SamplingParams:
"""Sampling parameters for text generation.
Comprehensive configuration for controlling text generation behavior.
Supports both OpenAI-compatible parameters and EasyDeL-specific extensions.
Attributes:
n: Number of sequences to generate
best_of: Sample n sequences and return the best one
presence_penalty: Penalty for tokens based on presence (-2.0 to 2.0)
frequency_penalty: Penalty for tokens based on frequency (-2.0 to 2.0)
repetition_penalty: Multiplicative penalty for repeated tokens
temperature: Controls randomness (0.0 = deterministic, higher = more random)
top_p: Nucleus sampling threshold (0.0 to 1.0)
min_p: Minimum probability threshold relative to top token
top_k: Number of highest probability tokens to consider
seed: Random seed for reproducibility
stop: List of stop strings
stop_token_ids: List of stop token IDs
stop_pattern: Regex pattern string for stopping generation
bad_words: List of strings to avoid generating
ignore_eos: Whether to ignore end-of-sequence token
max_tokens: Maximum number of tokens to generate
min_tokens: Minimum number of tokens to generate
logprobs: Number of log probabilities to return
prompt_logprobs: Number of prompt log probabilities to return
detokenize: Whether to convert tokens to text
skip_special_tokens: Whether to skip special tokens in output
spaces_between_special_tokens: Add spaces between special tokens
include_stop_str_in_output: Include stop string in output
output_kind: Type of output to return
truncate_prompt_tokens: Truncate prompt to this many tokens
guided_decoding: Parameters for constrained generation
logit_bias: Bias to apply to specific token logits
allowed_token_ids: Whitelist of allowed token IDs
extra_args: Additional custom arguments
Example:
>>> params = SamplingParams(
... temperature=0.7,
... top_p=0.9,
... max_tokens=100,
... stop=["\n\n"]
... )
"""
n: int = 1
best_of: int | None = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
min_p: float = 0.0
top_k: int = 0
seed: int | None = None
# Stopping Conditions
stop: list[str] = field(default_factory=list)
stop_token_ids: list[int] = field(default_factory=list)
stop_pattern: str | None = None # Regex pattern for stopping generation
bad_words: list[str] = field(default_factory=list)
ignore_eos: bool = False
max_tokens: int | None = 16
min_tokens: int = 0
# Output Control
logprobs: int | None = None
prompt_logprobs: int | None = None
detokenize: bool = True
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
include_stop_str_in_output: bool = False
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# Advanced & Guided Decoding
truncate_prompt_tokens: Annotated[int, "ge=1"] | None = None
guided_decoding: GuidedDecodingParams | None = None
logit_bias: dict[int, float] | None = None
allowed_token_ids: list[int] | None = None
extra_args: dict[str, Any] = field(default_factory=dict)
# Internal fields computed during initialization or via update methods.
# They are not part of the constructor (`init=False`).
_real_n: int | None = field(default=None, init=False)
_output_text_buffer_length: int = field(default=0, init=False)
_all_stop_token_ids: set[int] = field(default_factory=set, init=False)
_bad_words_token_ids: list[list[int]] | None = field(default=None, init=False)
def __post_init__(self) -> None:
"""
Initializes and validates parameters.
"""
if self.best_of is not None:
if self.best_of < self.n:
raise ValueError(f"best_of ({self.best_of}) must be >= n ({self.n}).")
if not self._real_n:
self._real_n = self.n
self.n = self.best_of
if self.temperature is None:
self.temperature = 1
if 0 < self.temperature < 1e-2:
logger.warning(
f"temperature {self.temperature} is below {1e-2}, which may cause numerical instability. "
f"Clamping to {1e-2}."
)
self.temperature = 1e-2
if self.seed == -1:
self.seed = None
if self.max_tokens is not None and self.max_tokens < 0:
logger.debug("Received negative max_tokens (%s); treating as auto-infer.", self.max_tokens)
self.max_tokens = None
if isinstance(self.stop, str):
self.stop = [self.stop]
if self.logprobs is True:
self.logprobs = 1
if self.prompt_logprobs is True:
self.prompt_logprobs = 1
if self.stop and not self.include_stop_str_in_output:
buffer_len = max(len(s) for s in self.stop) - 1
self._output_text_buffer_length = buffer_len
self._verify_args()
if self.temperature < 1e-5:
self.top_p = 1.0
self.top_k = 0
self.min_p = 0.0
self._verify_greedy_sampling()
self._all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None:
"""Performs detailed validation of parameter values."""
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError(f"presence_penalty must be in [-2, 2], got {self.presence_penalty}.")
if self.temperature < 0.0:
raise ValueError(f"temperature must be non-negative, got {self.temperature}.")
if not 0.0 < self.top_p <= 1.0:
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(f"min_tokens ({self.min_tokens}) must be <= max_tokens ({self.max_tokens}).")
if self.stop and not self.detokenize:
raise ValueError("stop strings require detokenize=True.")
def _verify_greedy_sampling(self) -> None:
"""Validates parameters for greedy sampling."""
if self.n > 1:
raise ValueError(f"n must be 1 for greedy sampling, got {self.n}.")
[docs] def update_with_generation_config(
self,
generation_config: dict[str, Any],
model_eos_token_id: int | None = None,
) -> SamplingParams:
"""
Creates a new `SamplingParams` instance updated with a model's generation_config.
Returns a new instance to maintain immutability.
"""
all_stop_ids = self._all_stop_token_ids.copy()
if model_eos_token_id is not None:
all_stop_ids.add(model_eos_token_id)
new_stop_token_ids = self.stop_token_ids
if (eos_ids := generation_config.get("eos_token_id")) is not None:
eos_ids_set = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if model_eos_token_id is not None:
eos_ids_set.discard(model_eos_token_id)
if eos_ids_set and not self.ignore_eos:
new_stop_token_ids = list(set(self.stop_token_ids) | eos_ids_set)
all_stop_ids.update(eos_ids_set)
self.stop_token_ids = new_stop_token_ids
self._all_stop_token_ids = all_stop_ids
return self
[docs] def update_with_tokenizer(self, tokenizer: AutoTokenizer) -> SamplingParams:
"""
Creates a new `SamplingParams` instance with bad_words encoded into token IDs.
Returns a new instance to maintain immutability.
"""
if not self.bad_words:
return self
bad_words_token_ids = []
for word in self.bad_words:
for add_prefix_space in [False, True]:
text = (" " if add_prefix_space else "") + word.lstrip()
token_ids = tokenizer.encode(text=text, add_special_tokens=False)
if not add_prefix_space or (
add_prefix_space and len(bad_words_token_ids) > 0 and token_ids != bad_words_token_ids[-1]
):
bad_words_token_ids.append(token_ids)
vocab_size = getattr(tokenizer, "vocab_size", tokenizer.model_max_length)
invalid_ids = [tid for ids in bad_words_token_ids for tid in ids if not (0 <= tid < vocab_size)]
if invalid_ids:
raise ValueError(
f"Bad words resulted in invalid token IDs: {invalid_ids}. "
f"All token IDs must be within the vocab size of {vocab_size}."
)
self._bad_words_token_ids = bad_words_token_ids
return self
@cached_property
def sampling_type(self) -> SamplingType:
"""Determines the sampling type based on parameters."""
if self.temperature < 1e-5:
return SamplingType.GREEDY
return SamplingType.RANDOM
@property
def all_stop_token_ids(self) -> set[int]:
"""Returns all stop token IDs, including EOS."""
return self._all_stop_token_ids
@property
def bad_words_token_ids(self) -> list[list[int]] | None:
"""Returns the tokenized versions of bad_words."""
return self._bad_words_token_ids
[docs] def make_jitable(self) -> JitableSamplingParams:
"""
Converts this host-side configuration into a JAX-jittable object.
This method should be called after all pre-processing (like tokenization)
is complete.
"""
if self.bad_words and self._bad_words_token_ids is None:
raise RuntimeError("Must call `with_tokenizer()` before `make_jitable()` when `bad_words` is set.")
return JitableSamplingParams.from_host_params(self)
[docs] def clone(self) -> SamplingParams:
"""Creates a deep copy of the instance."""
return copy.deepcopy(self)
[docs]@auto_pytree
class BeamSearchParams:
"""Beam search parameters for text generation.
Configuration for beam search decoding, which maintains multiple
candidate sequences and selects the best ones based on cumulative probability.
This class is immutable (`frozen=True`) for JAX compatibility.
Attributes:
beam_width: Number of beams to maintain
max_tokens: Maximum tokens to generate
ignore_eos: Whether to ignore end-of-sequence token
temperature: Temperature for beam search (0.0 = greedy)
length_penalty: Penalty factor for sequence length
include_stop_str_in_output: Include stop string in output
Note:
Beam search is typically used when deterministic, high-quality
output is desired, at the cost of increased computation.
"""
beam_width: int
max_tokens: int
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0
include_stop_str_in_output: bool = False