easydel.inference.logits_process
Logits processing and warping utilities for text generation.
This module provides a comprehensive set of logits processors and warpers for controlling text generation behavior in language models. These utilities allow fine-grained control over the probability distributions during sampling.
Key components:
- LogitsProcessor: Base class for modifying logits before sampling.
- LogitsWarper: Base class for rescaling probability distributions.
- Specialized processors and warpers for temperature scaling, top-k, top-p, penalties, and more.
Example
>>> from easydel.inference.logits_process import (
... LogitsProcessorList,
... TemperatureLogitsWarper,
... TopKLogitsWarper
... )
>>> processors = LogitsProcessorList()
>>> processors.append(TemperatureLogitsWarper(temperature=0.7))
>>> processors.append(TopKLogitsWarper(top_k=50))
>>> # input_ids, scores, and cur_len are supplied by the generation loop
>>> processed_scores = processors(input_ids, scores, cur_len)
- class easydel.inference.logits_process.EmptyProcessor
Bases: LogitsProcessor
A placeholder LogitsProcessor that performs no operation.
This processor simply returns the input scores unchanged. It can be useful in configurations where a processor slot needs to be filled but no actual processing is desired at that stage.
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.inference.logits_process.ForceTokensLogitsProcessor(force_token_map)
Bases: LogitsProcessor
LogitsProcessor that forces specific tokens to be generated at predefined positions during generation.
This processor uses a mapping (force_token_map) whose keys are generation indices (0-based, relative to the start of generation) and whose values are the token IDs to force at those indices.
When the current generation step cur_len matches an index in the map, the logit of the corresponding forced token ID is set to 0 (probability 1) and all other logits are set to filter_value (-infinity).
- Parameters
force_token_map – A mapping from generation index to the token ID to force. Can be provided as a dict or a list of (index, token_id) pairs.
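The forcing rule can be sketched in NumPy on a (batch, vocab) array of logits. This is an illustration under assumptions, not EasyDeL's code. The helper name force_token is hypothetical, the map is assumed to be a plain dict keyed by cur_len, and the real processor operates on JAX arrays inside a compiled generation loop. ForcedBOSTokenLogitsProcessor and ForcedEOSTokenLogitsProcessor apply the same kind of mask at their fixed steps.

```python
import numpy as np


def force_token(scores, cur_len, force_token_map):
    """Hypothetical helper: make the forced token the only finite logit at this step."""
    token_id = force_token_map.get(cur_len)
    if token_id is None:
        return scores  # nothing is forced at this step
    forced = np.full_like(scores, -np.inf)  # every other token gets filter_value
    forced[:, token_id] = 0.0  # the lone finite logit gets probability 1 after softmax
    return forced


scores = np.array([[1.0, 2.0, 3.0, 4.0]])
out = force_token(scores, cur_len=2, force_token_map={2: 1})
unchanged = force_token(scores, cur_len=3, force_token_map={2: 1})
```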
- class easydel.inference.logits_process.ForcedBOSTokenLogitsProcessor(bos_token_id: int)
Bases: LogitsProcessor
LogitsProcessor that ensures the beginning-of-sequence (BOS) token is generated as the very first token.
This processor modifies the logits only at the first generation step (cur_len = 1). It sets the logit of bos_token_id to 0 (probability 1) and all other logits to filter_value (-infinity).
- Parameters
bos_token_id – The integer ID of the Beginning-Of-Sequence token.
- bos_token_id: int
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.inference.logits_process.ForcedEOSTokenLogitsProcessor(max_length: int, eos_token_id: int)
Bases: LogitsProcessor
LogitsProcessor that forces the end-of-sequence (EOS) token to be generated when generation reaches the predefined max_length.
This processor modifies the logits only at the step where cur_len equals max_length - 1. It sets the logit of eos_token_id to 0 (probability 1) and all other logits to filter_value (-infinity).
- Parameters
max_length – The maximum allowed sequence length (including the prompt).
eos_token_id – The integer ID of the End-Of-Sequence token.
- eos_token_id: int
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- max_length: int
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.inference.logits_process.FrequencyPenaltyLogitsProcessor(frequency_penalty: float)
Bases: LogitsProcessor
LogitsProcessor that penalizes tokens based on their frequency (number of occurrences) in the sequence generated so far (input_ids).
This processor subtracts a penalty proportional to the token’s count from its logit. The penalty is calculated as count * frequency_penalty. Positive penalties discourage the model from repeating specific tokens frequently.
- Parameters
frequency_penalty – The penalty factor. Must be non-negative. Defaults to 0.0 (no penalty).
- frequency_penalty: float
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
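The count-based penalty can be sketched in NumPy as follows. This is an assumption-labeled illustration, not the library's implementation. The function name apply_frequency_penalty is hypothetical, and EasyDeL computes the equivalent with JAX arrays.

```python
import numpy as np


def apply_frequency_penalty(input_ids, scores, frequency_penalty):
    """Subtract count * frequency_penalty from the logit of every token in input_ids."""
    vocab_size = scores.shape[-1]
    # counts[b, v] = how many times token v occurs in row b of input_ids
    counts = np.stack([np.bincount(row, minlength=vocab_size) for row in input_ids])
    return scores - frequency_penalty * counts


input_ids = np.array([[0, 2, 2, 3]])
scores = np.zeros((1, 5))
penalized = apply_frequency_penalty(input_ids, scores, frequency_penalty=0.5)
# token 2 occurs twice, so it loses 1.0; tokens 0 and 3 lose 0.5; tokens 1 and 4 are untouched
```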
- class easydel.inference.logits_process.LogitsProcessor
Bases: object
Abstract base class for all logit processors.
Logits processors are callable classes that modify the logits predicted by a language model before sampling. They implement decoding strategies and constraints such as forcing specific tokens, applying penalties, or preventing repetition.
Subclasses should implement the __call__ method, which takes input_ids, scores, and cur_len as arguments.
- class easydel.inference.logits_process.LogitsProcessorList(iterable=(), /)
Bases: list
A container class, inheriting from list, that holds a sequence of LogitsProcessor and LogitsWarper objects.
It applies a chain of processors and warpers to a set of logits in order. Its overridden __call__ method iterates through the contained objects and passes the output of each one to the next.
It inspects each object's __call__ signature with inspect.signature and forwards any additional keyword arguments to the processors that require them.
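The sequential application and signature-based dispatch can be sketched in plain Python. ProcessorChain, AddOne, and AddBias are hypothetical names used only for illustration. The rule that "more than the three standard parameters means extra keyword arguments are needed" is an assumption about how the check decides, not a confirmed detail of EasyDeL. The toy processors also show the __call__(input_ids, scores, cur_len) protocol that the base classes describe.

```python
import inspect


class ProcessorChain(list):
    """Hypothetical stand-in for LogitsProcessorList: applies its items in order."""

    def __call__(self, input_ids, scores, cur_len, **kwargs):
        for processor in self:
            # signature of the bound method, so `self` is not counted
            params = inspect.signature(processor.__call__).parameters
            if len(params) > 3:
                # needs more than (input_ids, scores, cur_len): forward the extras
                scores = processor(input_ids, scores, cur_len, **kwargs)
            else:
                scores = processor(input_ids, scores, cur_len)
        return scores


class AddOne:
    """Toy processor that uses only the standard three arguments."""

    def __call__(self, input_ids, scores, cur_len):
        return [s + 1.0 for s in scores]


class AddBias:
    """Toy processor that needs an extra keyword argument."""

    def __call__(self, input_ids, scores, cur_len, bias):
        return [s + bias for s in scores]


chain = ProcessorChain([AddOne(), AddBias()])
result = chain(None, [0.0, 1.0], cur_len=0, bias=10.0)
```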
- class easydel.inference.logits_process.LogitsWarper
Bases: object
Abstract base class for all logit warpers.
Logit warpers are callable classes that modify the logits predicted by a language model after any processors have run but before sampling, typically by rescaling or filtering the probability distribution. They implement techniques such as temperature scaling, top-k, and top-p (nucleus) sampling.
Subclasses should implement the __call__ method, which takes input_ids, scores, and cur_len as arguments.
- class easydel.inference.logits_process.MinLengthLogitsProcessor(min_length: int, eos_token_id: int)
Bases: LogitsProcessor
LogitsProcessor that prevents the generation of the end-of-sequence (EOS) token until a minimum sequence length (min_length) has been reached.
This processor sets the logit of eos_token_id to filter_value (-infinity) while the current sequence length cur_len is less than min_length.
- Parameters
min_length – The minimum number of tokens that must be generated before the EOS token is allowed.
eos_token_id – The integer ID of the End-Of-Sequence token.
- eos_token_id: int
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- min_length: int
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
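A NumPy sketch of the minimum-length rule follows, using the hypothetical helper name apply_min_length. It uses a plain Python `if`. Under jax.jit, cur_len is usually a traced value, so a JAX version would typically select between the masked and unmasked logits with jnp.where instead. Treat this as an illustration, not EasyDeL's code.

```python
import numpy as np


def apply_min_length(scores, cur_len, min_length, eos_token_id):
    """Block EOS while the sequence is still shorter than min_length."""
    if cur_len >= min_length:
        return scores
    out = scores.copy()
    out[:, eos_token_id] = -np.inf  # EOS cannot be sampled yet
    return out


scores = np.array([[0.5, 1.0, 2.0]])
early = apply_min_length(scores, cur_len=3, min_length=5, eos_token_id=2)
late = apply_min_length(scores, cur_len=5, min_length=5, eos_token_id=2)
```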
- class easydel.inference.logits_process.MinPLogitsWarper(min_p: float, filter_value: float = -inf, min_tokens_to_keep: int = 1)
Bases: LogitsWarper
LogitsWarper implementing min-p sampling.
Filters the vocabulary distribution by removing tokens whose probability P(token) is less than min_p times the probability of the most likely token P(max). That is, it keeps tokens where P(token) >= min_p * P(max).
This is an alternative filtering strategy to top-p or top-k.
- Parameters
min_p – The minimum probability threshold relative to the peak probability. Must be in [0, 1]. Setting min_p=0.0 disables the filter.
filter_value – The value assigned to the logits of filtered tokens. Defaults to -infinity.
min_tokens_to_keep – Minimum number of tokens to retain, even if their probability falls below the min_p * P(max) threshold. Defaults to 1.
- filter_value: float = -inf
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- min_p: float
- min_tokens_to_keep: int = 1
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
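Min-p filtering can be sketched in NumPy as below. It shows the P(token) >= min_p * P(max) test and the min_tokens_to_keep override. The helper names softmax and min_p_filter are hypothetical, and this is not EasyDeL's JAX implementation.

```python
import numpy as np


def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def min_p_filter(scores, min_p, filter_value=-np.inf, min_tokens_to_keep=1):
    """Keep tokens with P(token) >= min_p * P(max); mask the rest."""
    probs = softmax(scores)
    keep = probs >= min_p * probs.max(axis=-1, keepdims=True)
    # always keep the min_tokens_to_keep most likely tokens
    top_idx = np.argsort(-scores, axis=-1)[:, :min_tokens_to_keep]
    np.put_along_axis(keep, top_idx, True, axis=-1)
    return np.where(keep, scores, filter_value)


scores = np.log(np.array([[0.5, 0.3, 0.15, 0.05]]))
filtered = min_p_filter(scores, min_p=0.2)
# threshold = 0.2 * 0.5 = 0.1, so only the 0.05 token is masked
```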
- class easydel.inference.logits_process.NoRepeatNGramLogitsProcessor(ngram_size: int)
Bases: LogitsProcessor
LogitsProcessor that prevents the generation of n-grams that have already occurred in the sequence generated so far (input_ids).
At each step, it considers the last ngram_size - 1 tokens generated. It then identifies all tokens in the vocabulary that would complete an n-gram already present in the full input_ids sequence. The logits for these banned tokens are set to filter_value (-infinity).
Reference: [Fairseq Sequence Generator](pytorch/fairseq)
- Parameters
ngram_size – The size of the n-gram to prevent from repeating. Setting ngram_size=0 disables the processor.
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- ngram_size: int
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
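The banned-token search can be sketched with a per-sequence Python loop. This is a readable illustration, not EasyDeL's code. The function names banned_ngram_tokens and apply_no_repeat_ngram are hypothetical, and a JAX implementation would likely be vectorized rather than looping in Python.

```python
import numpy as np


def banned_ngram_tokens(tokens, ngram_size):
    """Tokens that would complete an n-gram already present in `tokens` (one sequence)."""
    if ngram_size <= 0:
        return set()  # ngram_size=0 disables the processor
    prefix = tuple(tokens[len(tokens) - ngram_size + 1:])  # last ngram_size - 1 tokens
    banned = set()
    for i in range(len(tokens) - ngram_size + 1):
        if tuple(tokens[i:i + ngram_size - 1]) == prefix:
            banned.add(tokens[i + ngram_size - 1])
    return banned


def apply_no_repeat_ngram(input_ids, scores, ngram_size):
    out = scores.copy()
    for b, row in enumerate(input_ids):
        banned = banned_ngram_tokens([int(t) for t in row], ngram_size)
        if banned:
            out[b, sorted(banned)] = -np.inf
    return out


input_ids = np.array([[1, 2, 3, 1, 2]])
scores = np.zeros((1, 5))
masked = apply_no_repeat_ngram(input_ids, scores, ngram_size=3)
# (1, 2, 3) already occurred and the sequence ends in (1, 2), so token 3 is banned
```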
- class easydel.inference.logits_process.PresencePenaltyLogitsProcessor(presence_penalty: float)
Bases: LogitsProcessor
LogitsProcessor that penalizes tokens based on their presence in the sequence generated so far (input_ids).
This processor subtracts a fixed presence_penalty value from the logits of all tokens that have appeared at least once in input_ids. Positive penalties discourage the model from reusing tokens, promoting topic diversity.
- Parameters
presence_penalty – The penalty value subtracted from the logits of present tokens. Must be non-negative. Defaults to 0.0 (no penalty).
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- presence_penalty: float
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
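A NumPy sketch of the flat presence penalty follows, using the hypothetical helper name apply_presence_penalty. Unlike the count-based frequency penalty, token 2 below is penalized once even though it occurs twice. EasyDeL performs the equivalent computation on JAX arrays.

```python
import numpy as np


def apply_presence_penalty(input_ids, scores, presence_penalty):
    """Subtract a flat presence_penalty from every token that appears at least once."""
    present = np.zeros(scores.shape, dtype=bool)
    for b, row in enumerate(input_ids):
        present[b, row] = True  # repeated ids just set the same flag again
    return scores - presence_penalty * present


input_ids = np.array([[0, 2, 2, 3]])
scores = np.zeros((1, 5))
penalized = apply_presence_penalty(input_ids, scores, presence_penalty=0.5)
```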
- class easydel.inference.logits_process.RepetitionPenaltyLogitsProcessor(repetition_penalty: float)
Bases: LogitsProcessor
LogitsProcessor that applies a multiplicative penalty to the logits of tokens that have already appeared in the generated sequence (input_ids).
For previously seen tokens, a positive logit is divided by repetition_penalty and a negative logit is multiplied by repetition_penalty.
This aims to discourage repetition.
Reference: [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Keskar et al. (2019).
- Parameters
repetition_penalty – The penalty factor. Must be positive. A value of 1.0 means no penalty, values > 1.0 discourage repetition, and values < 1.0 encourage repetition. Defaults to 1.0.
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- repetition_penalty: float
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
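The sign-dependent CTRL-style rule can be sketched in NumPy as below. Each distinct previously seen token is penalized once. The helper name apply_repetition_penalty is hypothetical, and this sketch is not EasyDeL's code.

```python
import numpy as np


def apply_repetition_penalty(input_ids, scores, repetition_penalty):
    """CTRL-style penalty: shrink positive logits, push negative logits further down."""
    out = scores.copy()
    for b, row in enumerate(input_ids):
        seen = np.unique(row)  # each previously seen token is penalized once
        logits = out[b, seen]
        out[b, seen] = np.where(
            logits > 0, logits / repetition_penalty, logits * repetition_penalty
        )
    return out


input_ids = np.array([[0, 1]])
scores = np.array([[2.0, -1.0, 3.0]])
penalized = apply_repetition_penalty(input_ids, scores, repetition_penalty=2.0)
# token 0: 2.0 -> 1.0, token 1: -1.0 -> -2.0, token 2 (unseen) stays 3.0
```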
- class easydel.inference.logits_process.SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens: list, begin_index: int)
Bases: LogitsProcessor
LogitsProcessor that suppresses a specified list of tokens only at a specific early step in the generation process.
This is useful for preventing certain tokens (like BOS) from being generated immediately after the prompt.
The suppression occurs only when cur_len equals begin_index. The logits of the begin_suppress_tokens are set to filter_value (-infinity) at that step.
- Parameters
begin_suppress_tokens – A list or tuple of token IDs to suppress at the start.
begin_index – The generation step index (0-based relative to the start of generation) at which to apply the suppression.
- begin_index: int
- begin_suppress_tokens: list
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.inference.logits_process.SuppressTokensLogitsProcessor(suppress_tokens: list)
Bases: LogitsProcessor
LogitsProcessor that suppresses a specified list of tokens throughout the entire generation process.
This processor sets the logits of suppress_tokens to filter_value (-infinity) at every generation step. An empty list leaves the scores unchanged.
- Parameters
suppress_tokens – A list or tuple of token IDs to suppress consistently.
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- suppress_tokens: list
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
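The two suppression processors differ only in when the mask applies. SuppressTokensLogitsProcessor masks at every step, and SuppressTokensAtBeginLogitsProcessor masks only when cur_len equals begin_index. The NumPy sketch below uses the hypothetical helper names suppress and suppress_at_begin and is not the library's code.

```python
import numpy as np


def suppress(scores, suppress_tokens):
    """Set the logits of suppress_tokens to -inf; an empty list is a no-op."""
    out = scores.copy()
    if len(suppress_tokens) > 0:
        out[:, list(suppress_tokens)] = -np.inf
    return out


def suppress_at_begin(scores, cur_len, begin_index, begin_suppress_tokens):
    """Only suppress on the single step where cur_len == begin_index."""
    if cur_len != begin_index:
        return scores
    return suppress(scores, begin_suppress_tokens)


scores = np.zeros((1, 4))
always = suppress(scores, [0, 3])
at_start = suppress_at_begin(scores, cur_len=2, begin_index=2, begin_suppress_tokens=[1])
later = suppress_at_begin(scores, cur_len=3, begin_index=2, begin_suppress_tokens=[1])
```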
- class easydel.inference.logits_process.TemperatureLogitsWarper(temperature: Array)
Bases: LogitsWarper
LogitsWarper that applies temperature scaling to the logits distribution.
Divides the logits by the temperature value. Temperatures below 1.0 make the distribution sharper (less random), while temperatures above 1.0 make it flatter (more random). A temperature of 1.0 leaves the logits unchanged, and a temperature of 0.0 is also treated as no change rather than as a division by zero.
- Parameters
temperature – The temperature value for scaling. Must be non-negative. Setting it to 0.0 effectively disables the warper.
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
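The effect of temperature on the sampling distribution can be checked numerically with the NumPy sketch below. The helper apply_temperature is hypothetical. It follows the stated rule that 0.0 disables scaling and otherwise divides the logits by the temperature.

```python
import numpy as np


def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def apply_temperature(scores, temperature):
    """Divide logits by temperature; 0.0 is treated as disabled, as TemperatureLogitsWarper describes."""
    if temperature == 0.0:
        return scores
    return scores / temperature


scores = np.array([[1.0, 2.0, 3.0]])
sharp = softmax(apply_temperature(scores, 0.5))  # doubles the gaps between logits
flat = softmax(apply_temperature(scores, 2.0))   # halves them
base = softmax(scores)
```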
- class easydel.inference.logits_process.TopKLogitsWarper(top_k: int, filter_value: float = -inf, min_tokens_to_keep: int = 1)
Bases: LogitsWarper
LogitsWarper that implements top-k sampling.
Filters the vocabulary distribution by keeping only the top_k tokens with the highest probabilities (logits). The logits of the filtered tokens are set to filter_value.
- Parameters
top_k – The number of highest-probability tokens to keep. Setting top_k=0 disables the filter.
filter_value – The value assigned to the logits of filtered tokens. Defaults to -infinity.
min_tokens_to_keep – Minimum number of tokens to retain, overriding top_k if top_k is smaller. Defaults to 1.
- filter_value: float = -inf
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- min_tokens_to_keep: int = 1
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- top_k: int
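A NumPy sketch of top-k filtering follows, including the min_tokens_to_keep override. The helper name top_k_filter is hypothetical. In this sketch, ties at the cut-off are all kept, which may differ from how EasyDeL breaks ties.

```python
import numpy as np


def top_k_filter(scores, top_k, filter_value=-np.inf, min_tokens_to_keep=1):
    """Keep the top_k largest logits per row (at least min_tokens_to_keep); mask the rest."""
    if top_k <= 0:
        return scores  # top_k=0 disables the filter
    k = min(max(top_k, min_tokens_to_keep), scores.shape[-1])
    kth_largest = np.sort(scores, axis=-1)[:, -k][:, None]  # per-row cut-off logit
    # ties at the cut-off are all kept, so a row can retain slightly more than k tokens
    return np.where(scores >= kth_largest, scores, filter_value)


scores = np.array([[1.0, 4.0, 3.0, 2.0]])
filtered = top_k_filter(scores, top_k=2)
```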
- class easydel.inference.logits_process.TopPLogitsWarper(top_p: float, filter_value: float = -inf, min_tokens_to_keep: int = 1)
Bases: LogitsWarper
LogitsWarper that implements top-p (nucleus) sampling.
Filters the vocabulary distribution by keeping only the smallest set of tokens whose cumulative probability mass exceeds the threshold top_p. The logits of the filtered tokens are set to filter_value.
Reference: [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) by Holtzman et al. (2019).
- Parameters
top_p – The cumulative probability threshold. Must be in (0, 1]. Setting top_p=1.0 disables the filter.
filter_value – The value assigned to the logits of filtered tokens. Defaults to -infinity.
min_tokens_to_keep – Minimum number of tokens to retain, even if their cumulative probability exceeds top_p. Defaults to 1.
- filter_value: float = -inf
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- min_tokens_to_keep: int = 1
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- top_p: float
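Nucleus filtering can be sketched in NumPy as below. The sketch sorts by probability, keeps each token while the mass ranked ahead of it is still below top_p, and then scatters the mask back to vocabulary order. The helper name top_p_filter is hypothetical, and EasyDeL's JAX implementation may differ in detail, for example in how it handles ties.

```python
import numpy as np


def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def top_p_filter(scores, top_p, filter_value=-np.inf, min_tokens_to_keep=1):
    """Keep the smallest most-likely prefix whose cumulative probability reaches top_p."""
    if top_p >= 1.0:
        return scores  # top_p=1.0 disables the filter
    order = np.argsort(-scores, axis=-1)  # token ids, most likely first
    sorted_probs = softmax(np.take_along_axis(scores, order, axis=-1))
    mass_before = np.cumsum(sorted_probs, axis=-1) - sorted_probs
    # a token is kept while the mass of the tokens ranked ahead of it is below top_p
    keep_sorted = mass_before < top_p
    keep_sorted[:, :min_tokens_to_keep] = True
    keep = np.zeros_like(keep_sorted)
    np.put_along_axis(keep, order, keep_sorted, axis=-1)  # undo the sort
    return np.where(keep, scores, filter_value)


scores = np.log(np.array([[0.15, 0.5, 0.05, 0.3]]))
filtered = top_p_filter(scores, top_p=0.7)
# sorted probabilities 0.5, 0.3, ...: 0.5 alone is below 0.7, 0.5 + 0.3 exceeds it
```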
- class easydel.inference.logits_process.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)
Bases: LogitsProcessor
A specialized LogitsProcessor for handling timestamp tokens during generation with Whisper-style models used for Automatic Speech Recognition (ASR).
It enforces several constraints specific to timestamp prediction:
1. Suppress `<|notimestamps|>`: Prevents the model from predicting the token that indicates the absence of timestamps.
2. Alternating tokens: Enforces that text tokens and timestamp tokens generally alternate. If the last generated token was a timestamp, it biases against predicting another timestamp immediately after (except at the very beginning and in certain edge cases).
3. Initial timestamp limit: Restricts the maximum value of the first predicted timestamp token using max_initial_timestamp_index.
4. Timestamp probability check: If the total probability mass assigned to all valid timestamp tokens is higher than the probability of the single most likely non-timestamp token, it forces the model to sample a timestamp token by suppressing all non-timestamp tokens.
Note
This processor assumes the existence of specific token IDs related to timestamps (e.g., eos_token_id, no_timestamps_token_id, timestamp_begin), which are typically defined in the model’s generation configuration.
- Parameters
generate_config – Configuration object containing Whisper-specific generation parameters such as eos_token_id, no_timestamps_token_id, is_multilingual, and max_initial_timestamp_index.
model_config – The model’s configuration (used for vocab_size as a fallback).
decoder_input_length – The length of the initial input sequence provided to the decoder (e.g., the prompt length).
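The timestamp probability check described above can be sketched in NumPy by comparing log-probabilities. The sketch simplifies the vocabulary layout and assumes that every id at or above timestamp_begin is a timestamp token and every id below it is a text token. The helper name force_timestamp_if_likely is hypothetical, and the real processor combines this check with the other timestamp rules.

```python
import numpy as np


def force_timestamp_if_likely(scores, timestamp_begin):
    """Timestamp probability check: if timestamps jointly beat the best text token, mask text."""
    log_probs = scores - np.logaddexp.reduce(scores, axis=-1, keepdims=True)  # log-softmax
    # log of the total probability of all timestamp tokens (ids >= timestamp_begin)
    timestamp_logprob = np.logaddexp.reduce(log_probs[:, timestamp_begin:], axis=-1)
    max_text_logprob = log_probs[:, :timestamp_begin].max(axis=-1)
    force = timestamp_logprob > max_text_logprob
    out = scores.copy()
    out[force, :timestamp_begin] = -np.inf  # only rows that pass the check are touched
    return out


# ids 0-2 are text tokens, ids 3-5 are timestamps (timestamp_begin = 3)
scores = np.array([
    [2.0, 0.0, 0.0, 1.5, 1.5, 1.5],  # 3 * e^1.5 > e^2: timestamps win
    [5.0, 0.0, 0.0, 1.0, 1.0, 1.0],  # e^5 > 3 * e^1: text wins
])
out = force_timestamp_if_likely(scores, timestamp_begin=3)
```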
- easydel.inference.logits_process.add_start_docstrings(*docstr)
A decorator that prepends a given docstring section to the decorated function’s docstring.
This is useful for adding standard documentation sections (like parameter descriptions) to multiple functions without repetition.
- Parameters
*docstr – One or more strings that will be joined and prepended to the decorated function’s existing docstring.
- Returns
A decorator function.
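A decorator of this kind can be sketched in a few lines of plain Python. The name prepend_docs is hypothetical, and joining the strings with no separator is an assumption, not a confirmed detail of add_start_docstrings.

```python
def prepend_docs(*docstr):
    """Hypothetical sketch: prepend the joined strings to the decorated function's docstring."""
    def decorator(fn):
        # a missing docstring (None) is treated as empty
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return decorator


SHARED = "Shared parameter docs.\n"


@prepend_docs(SHARED)
def double(x):
    """Doubles x."""
    return 2 * x
```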