easydel.inference.logits_process

class easydel.inference.logits_process.EmptyProcessor[source]#

Bases: LogitsProcessor

A placeholder LogitsProcessor that performs no operation.

This processor simply returns the input scores unchanged. It can be useful in configurations where a processor slot needs to be filled but no actual processing is desired at that stage.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.ForceTokensLogitsProcessor(force_token_map)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that forces specific tokens to be generated at predefined positions during the generation.

This processor uses a mapping (force_token_map) where keys are the generation indices (0-based, relative to the start of generation) and values are the token IDs to be forced at those indices.

When the current generation step cur_len matches an index in the map, the logit of the corresponding forced token ID is set to 0 and all other logits are set to filter_value (-infinity), so the forced token receives probability 1 after softmax.

Parameters

force_token_map – A mapping from generation index to the token ID to force. Can be provided as a dict or a list of (index, token_id) pairs.
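
The forcing rule can be sketched in plain Python. This is an illustration under stated assumptions, not EasyDeL's implementation: the helper name force_tokens is hypothetical, and Python lists stand in for the arrays the real processor operates on.

```python
import math

def force_tokens(scores, cur_len, force_token_map, filter_value=-math.inf):
    """Hypothetical helper: pin one token when cur_len appears in the map."""
    # dict() accepts both a dict and a list of (index, token_id) pairs.
    forced = dict(force_token_map).get(cur_len)
    if forced is None:
        return list(scores)  # nothing is forced at this step
    out = [filter_value] * len(scores)
    out[forced] = 0.0  # the only finite logit, so softmax gives it probability 1
    return out

scores = [0.3, -1.2, 2.0, 0.5]
step0 = force_tokens(scores, cur_len=0, force_token_map=[(0, 3), (2, 1)])
step1 = force_tokens(scores, cur_len=1, force_token_map=[(0, 3), (2, 1)])
```

Here step0 forces token 3, while step1 is left untouched because index 1 is not in the map.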

class easydel.inference.logits_process.ForcedBOSTokenLogitsProcessor(bos_token_id: int)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that ensures the beginning-of-sequence (BOS) token is generated as the very first token.

This processor modifies the logits only at the first generation step (cur_len = 1). It sets the logit of the bos_token_id to 0 (probability 1) and all other logits to filter_value (-infinity).

Parameters

bos_token_id – The integer ID of the Beginning-Of-Sequence token.

bos_token_id: int#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.ForcedEOSTokenLogitsProcessor(max_length: int, eos_token_id: int)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that forces the end-of-sequence (EOS) token to be generated when the generation process reaches the predefined max_length.

This processor modifies the logits only at the step where cur_len equals max_length - 1. It sets the logit of the eos_token_id to 0 (probability 1) and all other logits to filter_value (-infinity).

Parameters
  • max_length – The maximum allowed sequence length (including prompt).

  • eos_token_id – The integer ID of the End-Of-Sequence token.

eos_token_id: int#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

max_length: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.FrequencyPenaltyLogitsProcessor(frequency_penalty: float)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that penalizes tokens based on their frequency (number of occurrences) in the sequence generated so far (input_ids).

This processor subtracts a penalty proportional to the token’s count from its logit. The penalty is calculated as count * frequency_penalty. Positive penalties discourage the model from repeating specific tokens frequently.

Parameters

frequency_penalty – The penalty factor. Must be non-negative. A value of 0.0 applies no penalty.
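
A minimal sketch of the count * frequency_penalty rule, written in plain Python for illustration. The function name apply_frequency_penalty is an assumption, and lists stand in for the arrays used by the real processor.

```python
from collections import Counter

def apply_frequency_penalty(scores, input_ids, frequency_penalty):
    """Hypothetical helper: subtract count * penalty from each token's logit."""
    counts = Counter(input_ids)  # unseen tokens have a count of 0
    return [s - counts[token] * frequency_penalty for token, s in enumerate(scores)]

scores = [1.0, 2.0, 3.0, 4.0]
# Token 2 occurred twice and token 0 once in the sequence so far.
penalized = apply_frequency_penalty(scores, input_ids=[2, 2, 0], frequency_penalty=0.5)
```

Token 2 loses 1.0 (two occurrences), token 0 loses 0.5, and unseen tokens keep their original logits.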

frequency_penalty: float#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.LogitsProcessor[source]#

Bases: object

Abstract base class for all logit processors.

Logits processors are callable classes that modify the logits predicted by a language model before sampling. They are used to implement various decoding strategies and constraints, such as forcing specific tokens, applying penalties, or preventing repetitions.

Inheriting classes should implement the __call__ method taking input_ids, scores, and cur_len as arguments.

class easydel.inference.logits_process.LogitsProcessorList(iterable=(), /)[source]#

Bases: list

A container class, inheriting from list, designed to hold a sequence of LogitsProcessor and LogitsWarper objects.

The primary purpose of this class is to provide a convenient way to apply a chain of processors/warpers sequentially to a set of logits. It overrides the __call__ method to iterate through the contained objects and apply each one to the logits.

It uses inspect.signature to inspect each processor’s __call__ signature, so processors that require additional keyword arguments are handled alongside those that take only the standard ones.
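
One way such signature-based dispatch could work is sketched below. ProcessorChain is a hypothetical stand-in, not the EasyDeL class, and the policy shown (forward only the extra arguments a processor names, raise if one is missing) is an assumption.

```python
import inspect

STANDARD_ARGS = ("input_ids", "scores", "cur_len")

class ProcessorChain(list):
    """Hypothetical list subclass that applies its callables in order."""

    def __call__(self, input_ids, scores, cur_len, **kwargs):
        for processor in self:
            params = inspect.signature(processor).parameters
            # Any parameter beyond the three standard ones counts as an extra.
            extras = [name for name in params if name not in STANDARD_ARGS]
            missing = [name for name in extras if name not in kwargs]
            if missing:
                raise ValueError(f"{processor!r} needs extra arguments: {missing}")
            scores = processor(input_ids, scores, cur_len,
                               **{name: kwargs[name] for name in extras})
        return scores

def add_one(input_ids, scores, cur_len):
    return [s + 1.0 for s in scores]

def scale(input_ids, scores, cur_len, factor):
    return [s * factor for s in scores]

chain = ProcessorChain([add_one, scale])
result = chain([0], [1.0, 2.0], 1, factor=2.0)
```

Only scale receives factor; add_one is called with the three standard arguments.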

class easydel.inference.logits_process.LogitsWarper[source]#

Bases: object

Abstract base class for all logit warpers.

Logit warpers are callable classes that modify the logits predicted by a language model after potential processing but before sampling, typically by re-scaling or filtering the probability distribution. They are used for techniques like temperature scaling, top-k, and top-p (nucleus) sampling.

Inheriting classes should implement the __call__ method taking input_ids, scores, and cur_len as arguments.

class easydel.inference.logits_process.MinLengthLogitsProcessor(min_length: int, eos_token_id: int)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that prevents the generation of the end-of-sequence (EOS) token until a minimum sequence length (min_length) has been reached.

This processor sets the logit of the eos_token_id to filter_value (-infinity) if the current sequence length cur_len is less than min_length.

Parameters
  • min_length – The minimum number of tokens that must be generated before the EOS token is allowed.

  • eos_token_id – The integer ID of the End-Of-Sequence token.
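
The masking condition can be illustrated with a short plain-Python sketch. The helper name block_eos_until_min_length is hypothetical, and lists stand in for the real arrays.

```python
import math

def block_eos_until_min_length(scores, cur_len, min_length, eos_token_id,
                               filter_value=-math.inf):
    """Hypothetical helper: mask EOS while cur_len is below min_length."""
    out = list(scores)
    if cur_len < min_length:
        out[eos_token_id] = filter_value
    return out

scores = [0.1, 0.2, 0.3]
early = block_eos_until_min_length(scores, cur_len=2, min_length=5, eos_token_id=2)
late = block_eos_until_min_length(scores, cur_len=5, min_length=5, eos_token_id=2)
```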

eos_token_id: int#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

min_length: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.MinPLogitsWarper(min_p: float, filter_value: float = -inf, min_tokens_to_keep: int = 1)[source]#

Bases: LogitsWarper

[LogitsWarper] implementing min-p sampling.

Filters the vocabulary distribution by removing tokens whose probability P(token) is less than min_p times the probability of the most likely token P(max). That is, it keeps tokens where P(token) >= min_p * P(max).

This is an alternative filtering strategy to top-p or top-k.

Parameters
  • min_p – The minimum probability threshold relative to the peak probability. Must be in [0, 1]. Setting min_p=0.0 disables the filter.

  • filter_value – The value assigned to the logits of filtered tokens. Defaults to -infinity.

  • min_tokens_to_keep – Minimum number of tokens to retain, even if their probability falls below the min_p * P(max) threshold. Defaults to 1.
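
The rule P(token) >= min_p * P(max) can be sketched as follows. This is a plain-Python illustration, not the library's code; softmax and min_p_filter are hypothetical helpers, and the way min_tokens_to_keep is honored here (always protecting the top-ranked tokens) is an assumption.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def min_p_filter(scores, min_p, filter_value=-math.inf, min_tokens_to_keep=1):
    """Hypothetical helper: keep tokens with P(token) >= min_p * P(max)."""
    probs = softmax(scores)
    threshold = min_p * max(probs)
    # The min_tokens_to_keep most likely tokens are never filtered.
    order = sorted(range(len(scores)), key=lambda i: probs[i], reverse=True)
    protected = set(order[:min_tokens_to_keep])
    return [s if (probs[i] >= threshold or i in protected) else filter_value
            for i, s in enumerate(scores)]

# Logits chosen so that the probabilities are roughly 0.5, 0.3, 0.15, 0.05.
scores = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
filtered = min_p_filter(scores, min_p=0.5)  # threshold is about 0.25
kept = [i for i, s in enumerate(filtered) if s != -math.inf]
```

With min_p=0.5 only the two most likely tokens survive; with min_p=0.0 the threshold is 0 and nothing is filtered.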

filter_value: float = -inf#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

min_p: float#
min_tokens_to_keep: int = 1#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.NoRepeatNGramLogitsProcessor(ngram_size: int)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that prevents the generation of n-grams that have already occurred in the sequence generated so far (input_ids).

At each step, it considers the last ngram_size - 1 tokens generated. It then identifies all tokens in the vocabulary that would complete an n-gram already present in the full input_ids sequence. The logits for these banned tokens are set to filter_value (-infinity).

Reference: [Fairseq Sequence Generator](pytorch/fairseq)

Parameters

ngram_size – The size of the n-gram to prevent from repeating. Setting ngram_size=0 disables the processor.
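
The banning step can be sketched in plain Python: take the last ngram_size - 1 tokens as a prefix, then collect every token that followed that prefix earlier in the sequence. The helper names are hypothetical and the loop is written for clarity, not for speed.

```python
import math

def banned_next_tokens(input_ids, ngram_size):
    """Hypothetical helper: tokens that would complete an already-seen n-gram."""
    if ngram_size <= 0 or len(input_ids) < ngram_size:
        return set()
    # The last ngram_size - 1 tokens (empty when ngram_size == 1).
    prefix = tuple(input_ids[len(input_ids) - (ngram_size - 1):])
    banned = set()
    for start in range(len(input_ids) - ngram_size + 1):
        ngram = tuple(input_ids[start:start + ngram_size])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned

def no_repeat_ngram(scores, input_ids, ngram_size, filter_value=-math.inf):
    banned = banned_next_tokens(input_ids, ngram_size)
    return [filter_value if token in banned else s for token, s in enumerate(scores)]

# The sequence ends in (1, 2), and the trigram (1, 2, 3) has already occurred.
filtered = no_repeat_ngram([0.0] * 5, input_ids=[1, 2, 3, 1, 2], ngram_size=3)
```

Only token 3 is banned, since generating it would repeat the trigram (1, 2, 3).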

forward(input_ids: Array, scores: Array, cur_len: int) Array[source]#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

ngram_size: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.PresencePenaltyLogitsProcessor(presence_penalty: float)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that penalizes tokens based on their presence in the sequence generated so far (input_ids).

This processor subtracts a fixed presence_penalty value from the logits of all tokens that have appeared at least once in the input_ids. Positive penalties discourage the model from reusing tokens, promoting topic diversity.

Parameters

presence_penalty – The penalty value subtracted from the logits of present tokens. Must be non-negative. A value of 0.0 applies no penalty.
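
Unlike a count-based penalty, the presence penalty is flat: a token that appeared five times is penalized the same as one that appeared once. A plain-Python sketch, with the hypothetical helper apply_presence_penalty:

```python
def apply_presence_penalty(scores, input_ids, presence_penalty):
    """Hypothetical helper: subtract a fixed penalty from every seen token."""
    seen = set(input_ids)
    return [s - presence_penalty if token in seen else s
            for token, s in enumerate(scores)]

scores = [1.0, 2.0, 3.0, 4.0]
# Token 2 occurred twice, but it is penalized only once.
penalized = apply_presence_penalty(scores, input_ids=[2, 2, 0], presence_penalty=0.5)
```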

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

presence_penalty: float#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.RepetitionPenaltyLogitsProcessor(repetition_penalty: float)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that applies a multiplicative penalty to the logits of tokens that have already appeared in the generated sequence (input_ids).

For previously seen tokens:
  • If the original logit is positive, it’s divided by repetition_penalty.

  • If the original logit is negative, it’s multiplied by repetition_penalty.

This aims to discourage repetition.

Reference: [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Keskar et al. (2019).

Parameters

repetition_penalty – The penalty factor. Must be positive.
  • 1.0 means no penalty.

  • Values > 1.0 discourage repetition.

  • Values < 1.0 encourage repetition.
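
The sign-dependent rule exists because dividing a negative logit by a factor above 1 would raise it; multiplying pushes it further down instead. A plain-Python sketch with a hypothetical helper name:

```python
def apply_repetition_penalty(scores, input_ids, repetition_penalty):
    """Hypothetical helper: divide positive and multiply negative seen logits."""
    seen = set(input_ids)
    out = []
    for token, s in enumerate(scores):
        if token in seen:
            s = s / repetition_penalty if s > 0 else s * repetition_penalty
        out.append(s)
    return out

# Tokens 0 and 1 were seen; token 2 was not.
penalized = apply_repetition_penalty([2.0, -1.0, 4.0], input_ids=[0, 1],
                                     repetition_penalty=2.0)
```

Both seen tokens end up less likely: 2.0 becomes 1.0 and -1.0 becomes -2.0.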

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

repetition_penalty: float#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens: list, begin_index: int)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that suppresses a specified list of tokens only at a specific early step in the generation process.

This is useful for preventing certain tokens (like BOS) from being generated immediately after the prompt.

The suppression occurs only when cur_len equals begin_index. The logits of the begin_suppress_tokens are set to filter_value (-infinity) at that step.

Parameters
  • begin_suppress_tokens – A list or tuple of token IDs to suppress at the start.

  • begin_index – The generation step index (0-based relative to the start of generation) at which to apply the suppression.
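
A short plain-Python sketch of the single-step suppression described above. The helper name suppress_at_begin and the token IDs are illustrative assumptions.

```python
import math

def suppress_at_begin(scores, cur_len, begin_suppress_tokens, begin_index,
                      filter_value=-math.inf):
    """Hypothetical helper: mask the given tokens only when cur_len == begin_index."""
    out = list(scores)
    if cur_len == begin_index:
        for token in begin_suppress_tokens:
            out[token] = filter_value
    return out

scores = [0.5, 1.5, -0.5, 2.0]
at_begin = suppress_at_begin(scores, cur_len=3, begin_suppress_tokens=[0, 3], begin_index=3)
later = suppress_at_begin(scores, cur_len=4, begin_suppress_tokens=[0, 3], begin_index=3)
```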

begin_index: int#
begin_suppress_tokens: list#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.SuppressTokensLogitsProcessor(suppress_tokens: list)[source]#

Bases: LogitsProcessor

[LogitsProcessor] that suppresses a specified list of tokens throughout the entire generation process.

This processor sets the logits of the suppress_tokens to filter_value (-infinity) at every generation step, provided the list is not empty.

Parameters

suppress_tokens – A list or tuple of token IDs to suppress consistently.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

suppress_tokens: list#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.TemperatureLogitsWarper(temperature: float)[source]#

Bases: LogitsWarper

[LogitsWarper] that applies temperature scaling to the logits distribution.

Divides the logits by the temperature value. A temperature of 1.0 leaves the logits unchanged, and a temperature of 0.0 is also treated as no change. Temperatures below 1.0 make the distribution sharper (less random), while temperatures above 1.0 make it flatter (more random).

Parameters

temperature – The temperature value for scaling. Must be non-negative. A value of 0.0 effectively disables the warper.
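
The effect of the scaling is easiest to see on the resulting probabilities. The sketch below is plain Python for illustration; apply_temperature and softmax are hypothetical helpers, and the guard for 0.0 simply mirrors the behavior stated above.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def apply_temperature(scores, temperature):
    """Hypothetical helper: divide logits by temperature; 0.0 and 1.0 are no-ops."""
    if temperature in (0.0, 1.0):
        return list(scores)
    return [s / temperature for s in scores]

scores = [2.0, 1.0, 0.0]
sharp = softmax(apply_temperature(scores, 0.5))  # more mass on the top token
base = softmax(apply_temperature(scores, 1.0))
flat = softmax(apply_temperature(scores, 2.0))   # mass spread more evenly
```

The top token's probability is highest at temperature 0.5 and lowest at 2.0.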

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

temperature: float#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.inference.logits_process.TopKLogitsWarper(top_k: int, filter_value: float = -inf, min_tokens_to_keep: int = 1)[source]#

Bases: LogitsWarper

[LogitsWarper] that implements top-k sampling.

Filters the vocabulary distribution by keeping only the top_k tokens with the highest logits (and therefore the highest probabilities). The logits of the filtered tokens are set to filter_value.

Parameters
  • top_k – The number of highest probability tokens to keep. Setting top_k=0 disables the filter.

  • filter_value – The value assigned to the logits of filtered tokens. Defaults to -infinity.

  • min_tokens_to_keep – Minimum number of tokens to retain, overriding top_k if top_k is smaller. Ensures at least this many tokens are considered. Defaults to 1.
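
A plain-Python sketch of the filter, including the interaction between top_k and min_tokens_to_keep. The helper top_k_filter is hypothetical; the real warper works on arrays.

```python
import math

def top_k_filter(scores, top_k, filter_value=-math.inf, min_tokens_to_keep=1):
    """Hypothetical helper: keep the k highest logits and mask the rest."""
    if top_k <= 0:
        return list(scores)  # top_k=0 disables the filter
    # min_tokens_to_keep wins when it is larger; k never exceeds the vocab size.
    k = min(max(top_k, min_tokens_to_keep), len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:k])
    return [s if i in keep else filter_value for i, s in enumerate(scores)]

scores = [0.1, 2.0, -1.0, 1.5]
filtered = top_k_filter(scores, top_k=2)
```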

filter_value: float = -inf#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

min_tokens_to_keep: int = 1#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

top_k: int#

class easydel.inference.logits_process.TopPLogitsWarper(top_p: float, filter_value: float = -inf, min_tokens_to_keep: int = 1)[source]#

Bases: LogitsWarper

[LogitsWarper] that implements top-p (nucleus) sampling.

Filters the vocabulary distribution by keeping only the smallest set of tokens whose cumulative probability mass exceeds the threshold top_p. The logits of the filtered tokens are set to filter_value.

Reference: [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) by Holtzman et al. (2019).

Parameters
  • top_p – The cumulative probability threshold. Must be in (0, 1]. Setting top_p=1.0 disables the filter.

  • filter_value – The value assigned to the logits of filtered tokens. Defaults to -infinity.

  • min_tokens_to_keep – Minimum number of tokens to retain, even if their cumulative probability exceeds top_p. Defaults to 1.
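
Nucleus filtering can be sketched in plain Python: sort tokens by probability, then keep each token as long as the mass accumulated before it is still below top_p. The helpers are hypothetical, and the exact boundary handling in the library may differ.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(scores, top_p, filter_value=-math.inf, min_tokens_to_keep=1):
    """Hypothetical helper: keep the smallest top set whose mass reaches top_p."""
    probs = softmax(scores)
    order = sorted(range(len(scores)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for rank, i in enumerate(order):
        # cumulative is the mass of all tokens ranked above this one.
        if cumulative < top_p or rank < min_tokens_to_keep:
            keep.add(i)
        cumulative += probs[i]
    return [s if i in keep else filter_value for i, s in enumerate(scores)]

# Logits chosen so that the probabilities are roughly 0.5, 0.3, 0.15, 0.05.
scores = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
filtered = top_p_filter(scores, top_p=0.7)
kept = [i for i, s in enumerate(filtered) if s != -math.inf]
```

With top_p=0.7 the first two tokens (about 0.8 of the mass) are kept; with top_p=1.0 every token passes.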

filter_value: float = -inf#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

min_tokens_to_keep: int = 1#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

top_p: float#

class easydel.inference.logits_process.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)[source]#

Bases: LogitsProcessor

A specialized [LogitsProcessor] tailored for handling timestamp tokens during generation with Whisper-style models used for Automatic Speech Recognition (ASR).

It enforces several constraints specific to timestamp prediction:
  1. Suppresses `<|notimestamps|>`: Prevents the model from predicting the token that indicates the absence of timestamps.

  2. Alternating Tokens: Enforces that text tokens and timestamp tokens generally alternate. If the last generated token was a timestamp, it biases against predicting another timestamp immediately after (unless it’s the very beginning or certain edge cases).

  3. Initial Timestamp Limit: Restricts the maximum value of the first timestamp token predicted using max_initial_timestamp_index.

  4. Timestamp Probability Check: If the total probability mass assigned to all valid timestamp tokens is higher than the probability of the single most likely non-timestamp token, it forces the model to sample a timestamp token by suppressing all non-timestamp tokens.

Note

This processor assumes the existence of specific token IDs related to timestamps (e.g., eos_token_id, no_timestamps_token_id, timestamp_begin) which are typically defined in the model’s generation configuration.

Parameters
  • generate_config – Configuration object containing Whisper-specific generation parameters like eos_token_id, no_timestamps_token_id, is_multilingual, max_initial_timestamp_index.

  • model_config – The model’s configuration (used for vocab_size as a fallback).

  • decoder_input_length – The length of the initial input sequence provided to the decoder (e.g., the prompt length).
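
The timestamp probability check can be illustrated in isolation. The sketch below is plain Python, not the processor's code; it assumes that timestamp tokens occupy every ID from timestamp_begin upward, and the helper names are hypothetical.

```python
import math

def log_softmax(scores):
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - lse for s in scores]

def prefer_timestamps(scores, timestamp_begin, filter_value=-math.inf):
    """Hypothetical helper: mask text tokens when timestamps carry more mass."""
    logprobs = log_softmax(scores)
    # Total probability of all timestamp tokens (IDs >= timestamp_begin).
    timestamp_mass = sum(math.exp(lp) for lp in logprobs[timestamp_begin:])
    # Probability of the single most likely non-timestamp token.
    best_text_prob = math.exp(max(logprobs[:timestamp_begin]))
    if timestamp_mass > best_text_prob:
        return [filter_value if i < timestamp_begin else s
                for i, s in enumerate(scores)]
    return list(scores)

# Probabilities are roughly 0.3, 0.25, 0.2, 0.15, 0.1.
scores = [math.log(p) for p in (0.3, 0.25, 0.2, 0.15, 0.1)]
forced = prefer_timestamps(scores, timestamp_begin=2)     # 0.45 > 0.3
unchanged = prefer_timestamps(scores, timestamp_begin=4)  # 0.1 < 0.3
```

In the first call the timestamp tokens jointly outweigh the best text token, so both text tokens are suppressed; in the second call the scores pass through unchanged.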

easydel.inference.logits_process.add_start_docstrings(*docstr)[source]#

A decorator that prepends a given docstring section to the decorated function’s docstring.

This is useful for adding standard documentation sections (like parameter descriptions) to multiple functions without repetition.

Parameters

*docstr – One or more strings that will be joined and prepended to the decorated function’s existing docstring.

Returns

A decorator function.
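
A minimal sketch of a decorator with this behavior, using the hypothetical name prepend_docstrings so it is not mistaken for the library's own function:

```python
def prepend_docstrings(*docstr):
    """Hypothetical decorator factory: prepend shared text to a docstring."""
    def decorator(fn):
        # Join the shared sections and put them before the existing docstring.
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return decorator

SHARED_ARGS = "Args:\n    scores: logits for the next token.\n"

@prepend_docstrings(SHARED_ARGS)
def identity(scores):
    """Returns the scores unchanged."""
    return scores
```

After decoration, identity.__doc__ begins with the shared Args section and ends with the function's original one-line description.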