easydel.inference.logits_process

class easydel.inference.logits_process.FlaxForceTokensLogitsProcessor(force_token_map)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] that takes a list of integer pairs mapping generation indices to token ids; at each of those indices, the mapped token is forced before sampling. The processor sets the forced token's log prob to 0 and all other tokens to -inf, so the forced token is sampled at its corresponding index.

Parameters

force_token_map (list) – List of (generation index, token id) pairs giving the token to force at each of those generation indices.
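A minimal JAX sketch of the forcing rule, written as standalone functions rather than easydel's class API. The helper names `build_force_token_array` and `force_tokens` are hypothetical, and the sketch assumes `scores` has shape (batch_size, vocab_size) and `cur_len` is the current generation index. Storing the map as a dense array (with -1 meaning "nothing forced") keeps the lookup traceable under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def build_force_token_array(force_token_map):
    """Turn [(generation_index, token_id), ...] into a lookup table (-1 = nothing forced)."""
    mapping = dict(force_token_map)
    table = [-1] * (max(mapping) + 1)
    for index, token in mapping.items():
        table[index] = token
    return jnp.array(table, dtype=jnp.int32)


def force_tokens(scores, cur_len, force_token_array):
    """At step cur_len, give the mapped token a score of 0 and every other token -inf."""
    size = force_token_array.shape[0]
    # Steps past the end of the table behave as "nothing forced".
    safe_len = jnp.minimum(cur_len, size - 1)
    token = jnp.where(cur_len < size, force_token_array[safe_len], -1)
    forced = jnp.full(scores.shape, -jnp.inf, dtype=scores.dtype)
    forced = forced.at[:, token].set(0.0)
    # Keep the original scores when no token is forced at this step.
    return jnp.where(token >= 0, forced, scores)


scores = jnp.zeros((2, 5))
table = build_force_token_array([(1, 3), (2, 0)])
step_two = jax.jit(force_tokens)(scores, 2, table)  # token 0 forced for every row
```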

class easydel.inference.logits_process.FlaxForcedBOSTokenLogitsProcessor(bos_token_id: int)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] that enforces the specified token as the first generated token.

Parameters

bos_token_id (int) – The id of the token to force as the first generated token.
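The BOS rule can be sketched as a single masked update. This standalone function (`force_bos` is a hypothetical name) assumes position 0 holds the decoder start token, so the first generated token is chosen when `cur_len == 1`; that indexing convention is borrowed from the Hugging Face Flax processors and is an assumption here.

```python
import jax.numpy as jnp


def force_bos(scores, cur_len, bos_token_id):
    """Force bos_token_id as the first generated token (assumed to be chosen at cur_len == 1)."""
    forced = jnp.full(scores.shape, -jnp.inf, dtype=scores.dtype)
    forced = forced.at[:, bos_token_id].set(0.0)
    # Position 0 is assumed to hold the decoder start token.
    return jnp.where(cur_len == 1, forced, scores)
```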

class easydel.inference.logits_process.FlaxForcedEOSTokenLogitsProcessor(max_length: int, eos_token_id: int)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] that enforces the specified token as the last generated token when max_length is reached.

Parameters
  • max_length (int) – The maximum length of the sequence to be generated.

  • eos_token_id (int) – The id of the token to force as the last generated token when max_length is reached.
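A sketch of the EOS rule, assuming the token chosen at `cur_len == max_length - 1` fills the last slot of a `max_length`-long sequence (an indexing convention assumed here, not confirmed for easydel). `force_eos` is a hypothetical helper.

```python
import jax.numpy as jnp


def force_eos(scores, cur_len, max_length, eos_token_id):
    """Force eos_token_id when the next token will fill the last of max_length slots."""
    forced = jnp.full(scores.shape, -jnp.inf, dtype=scores.dtype)
    forced = forced.at[:, eos_token_id].set(0.0)
    # Assumption: the token chosen at cur_len == max_length - 1 is the final one.
    return jnp.where(cur_len == max_length - 1, forced, scores)
```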

class easydel.inference.logits_process.FlaxLogitsProcessor

Bases: object

Abstract base class for all logit processors that can be applied during generation.

class easydel.inference.logits_process.FlaxLogitsProcessorList(iterable=(), /)

Bases: list

This class can be used to create a list of [FlaxLogitsProcessor] or [FlaxLogitsWarper] to subsequently process a scores input tensor. This class inherits from list and adds a specific __call__ method to apply each [FlaxLogitsProcessor] or [FlaxLogitsWarper] to the inputs.
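A minimal sketch of how such a list can chain processors and warpers. It assumes each element is a callable with the signature `(input_ids, scores, cur_len) -> scores`, as in the Hugging Face Flax processors; `LogitsProcessorListSketch` and the two toy processors are hypothetical, and the real class may also forward extra keyword arguments.

```python
import jax.numpy as jnp


class LogitsProcessorListSketch(list):
    """Hypothetical stand-in for FlaxLogitsProcessorList: calling it threads
    `scores` through every processor or warper in list order."""

    def __call__(self, input_ids, scores, cur_len):
        for processor in self:
            scores = processor(input_ids, scores, cur_len)
        return scores


def double_scores(input_ids, scores, cur_len):
    # Toy warper: scale every logit.
    return scores * 2.0


def ban_token_zero(input_ids, scores, cur_len):
    # Toy processor: token 0 can never be sampled.
    return scores.at[:, 0].set(-jnp.inf)


pipeline = LogitsProcessorListSketch([double_scores, ban_token_zero])
new_scores = pipeline(jnp.zeros((1, 1), dtype=jnp.int32), jnp.ones((1, 3)), 1)
```

Because the class subclasses `list`, the usual `append`, `extend`, and indexing operations remain available for building the pipeline.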

class easydel.inference.logits_process.FlaxLogitsWarper

Bases: object

Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.

class easydel.inference.logits_process.FlaxMinLengthLogitsProcessor(min_length: int, eos_token_id: int)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] enforcing a minimum length by setting the EOS token's probability to 0 until min_length is reached.

Parameters
  • min_length (int) – The minimum length below which the score of eos_token_id is set to -float("inf").

  • eos_token_id (int) – The id of the end-of-sequence token.
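A sketch of the minimum-length rule as a standalone function (`min_length_mask` is a hypothetical name). It assumes `cur_len` is the current sequence length and masks EOS while `cur_len < min_length`, following the parameter description above; the exact boundary easydel uses is not stated here.

```python
import jax.numpy as jnp


def min_length_mask(scores, cur_len, min_length, eos_token_id):
    """While the sequence is shorter than min_length, EOS gets -inf (probability 0)."""
    banned = scores.at[:, eos_token_id].set(-jnp.inf)
    return jnp.where(cur_len < min_length, banned, scores)
```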

class easydel.inference.logits_process.FlaxNoRepeatNGramLogitsProcessor(ngram_size: int)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] that enforces no repetition of n-grams. See Fairseq (pytorch/fairseq).

Parameters

ngram_size (int) – All ngrams of size ngram_size can only occur once.

get_banned_tokens_mask(latest_tokens: Array, previous_ngrams) → Array

Determines which tokens must be banned given latest tokens and the previously seen ngrams.

get_previous_ngrams(input_ids: Array, vocab_size: int, cur_len: int)

Get a matrix of shape (batch_size,) + (vocab_size,) * n (for n-grams) that represents the n-grams that occurred previously. The BCOO (batched coordinate) sparse representation stores only the few non-zero entries instead of the full (huge) dense matrix.
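The banning rule itself is easiest to see on a single Python list. The sketch below is a dense, loop-based illustration with hypothetical helper names; it is not jit-compatible, which is why the processor stores previous n-grams as a sparse BCOO tensor instead. For each (n-1)-token prefix it records which tokens followed it, then bans those followers when the sequence currently ends with that prefix.

```python
import jax.numpy as jnp


def banned_next_tokens(tokens, ngram_size):
    """Tokens that would complete an n-gram already present in `tokens` (one sequence)."""
    n = ngram_size
    # Map each (n-1)-token prefix to the set of tokens that followed it.
    followers = {}
    for i in range(len(tokens) - n + 1):
        prefix = tuple(tokens[i : i + n - 1])
        followers.setdefault(prefix, set()).add(tokens[i + n - 1])
    # The prefix the next token would extend is the last n-1 tokens.
    latest = tuple(tokens[len(tokens) - n + 1 :])
    return followers.get(latest, set())


def ban_repeated_ngrams(scores, batch_tokens, ngram_size):
    """Set the score of every banned next token to -inf, row by row (not jit-compatible)."""
    for row, tokens in enumerate(batch_tokens):
        banned = sorted(banned_next_tokens(tokens, ngram_size))
        if banned:
            scores = scores.at[row, jnp.array(banned)].set(-jnp.inf)
    return scores
```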

class easydel.inference.logits_process.FlaxStaticForceTokensLogitsProcessor(force_token_map)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] that takes a list of integer pairs mapping generation indices to token ids; at each of those indices, the mapped token is forced before sampling. The processor sets the forced token's log prob to 0 and all other tokens to -inf, so the forced token is sampled at its corresponding index. This is a static version of the transformers logits processor [FlaxForceTokensLogitsProcessor] that is compatible with sharded forced tokens.

Parameters

force_token_map (list) – Map giving token ids and indices where they will be forced to be sampled.

class easydel.inference.logits_process.FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] suppressing a list of tokens at the step where generation begins, i.e. when the current sequence length equals begin_index. This ensures that the tokens in begin_suppress_tokens are not sampled at the beginning of the generation.

Parameters
  • begin_suppress_tokens (List[int]) – Tokens to not sample.

  • begin_index (int) – Generation index (current sequence length) at which the tokens are suppressed.
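A sketch of begin-of-generation suppression, assuming suppression applies only at the single step where `cur_len == begin_index` (the semantics of the Hugging Face Flax processor of the same name, assumed here for easydel). `suppress_at_begin` is a hypothetical helper.

```python
import jax.numpy as jnp


def suppress_at_begin(scores, cur_len, begin_suppress_tokens, begin_index):
    """Mask begin_suppress_tokens only at the step where cur_len == begin_index."""
    suppressed = scores.at[:, jnp.array(begin_suppress_tokens)].set(-jnp.inf)
    return jnp.where(cur_len == begin_index, suppressed, scores)
```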

class easydel.inference.logits_process.FlaxSuppressTokensLogitsProcessor(suppress_tokens: list)

Bases: FlaxLogitsProcessor

[FlaxLogitsProcessor] suppressing a list of tokens at each decoding step. The processor sets their log probs to -inf so they are not sampled.

Parameters

suppress_tokens (list) – Tokens to not sample.
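Unconditional suppression reduces to a single indexed update. In this sketch (`suppress_tokens` is a hypothetical helper), setting a logit to -inf makes its softmax probability exactly 0 while the remaining tokens share the probability mass.

```python
import jax
import jax.numpy as jnp


def suppress_tokens(scores, tokens_to_suppress):
    """Mask the given token ids at every decoding step."""
    return scores.at[:, jnp.array(tokens_to_suppress)].set(-jnp.inf)


probs = jax.nn.softmax(suppress_tokens(jnp.zeros((1, 4)), [1, 2]), axis=-1)
# Suppressed tokens get probability exactly 0; tokens 0 and 3 share the mass.
```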

class easydel.inference.logits_process.FlaxTemperatureLogitsWarper(temperature: float)

Bases: FlaxLogitsWarper

[FlaxLogitsWarper] for temperature scaling: the logits are divided by temperature, which sharpens (temperature < 1) or flattens (temperature > 1) the output probability distribution.

Parameters

temperature (float) – The value used to modulate the logits distribution.
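Temperature scaling divides the logits by `temperature` before the softmax. The snippet below (with the hypothetical helper `apply_temperature`) shows that a temperature below 1 concentrates probability on the top token, while a temperature above 1 spreads it out.

```python
import jax
import jax.numpy as jnp


def apply_temperature(scores, temperature):
    """Divide logits by the temperature; T < 1 sharpens the softmax, T > 1 flattens it."""
    return scores / temperature


logits = jnp.array([[1.0, 2.0]])
p_plain = jax.nn.softmax(apply_temperature(logits, 1.0), axis=-1)
p_sharp = jax.nn.softmax(apply_temperature(logits, 0.5), axis=-1)
p_flat = jax.nn.softmax(apply_temperature(logits, 2.0), axis=-1)
```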

class easydel.inference.logits_process.FlaxTopKLogitsWarper(top_k: int, filter_value: float = -inf, min_tokens_to_keep: int = 1)

Bases: FlaxLogitsWarper

[FlaxLogitsWarper] that performs top-k, i.e. restricting to the k highest probability elements.

Parameters
  • top_k (int) – The number of highest probability vocabulary tokens to keep for top-k-filtering.

  • filter_value (float, optional, defaults to -inf) – All filtered values will be set to this float value.

  • min_tokens_to_keep (int, optional, defaults to 1) – Minimum number of tokens that cannot be filtered.
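A sketch of top-k filtering with `jax.lax.top_k`, assuming `scores` has shape (batch_size, vocab_size). Following the parameter descriptions above, the effective k is raised to at least `min_tokens_to_keep` and capped at the vocabulary size. `top_k_filter` is a hypothetical name; in this version, tokens tied with the k-th largest score are all kept.

```python
import jax
import jax.numpy as jnp


def top_k_filter(scores, top_k, filter_value=-jnp.inf, min_tokens_to_keep=1):
    """Keep the k highest scores per row and set the rest to filter_value."""
    k = min(max(top_k, min_tokens_to_keep), scores.shape[-1])
    # lax.top_k returns values sorted in descending order along the last axis.
    kth_largest = jax.lax.top_k(scores, k)[0][..., -1:]
    return jnp.where(scores < kth_largest, filter_value, scores)
```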

class easydel.inference.logits_process.FlaxTopPLogitsWarper(top_p: float, filter_value: float = -inf, min_tokens_to_keep: int = 1)

Bases: FlaxLogitsWarper

[FlaxLogitsWarper] that performs top-p (nucleus) filtering, i.e. restricting sampling to the smallest set of most probable tokens whose cumulative probability reaches top_p.

Parameters
  • top_p (float) – If set to < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.

  • filter_value (float, optional, defaults to -inf) – All filtered values will be set to this float value.

  • min_tokens_to_keep (int, optional, defaults to 1) – Minimum number of tokens that cannot be filtered.
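A sketch of top-p (nucleus) filtering under the same shape assumption. Tokens are sorted by score, and a token is kept while the probability mass of the tokens ranked above it is still below `top_p`, so the token that crosses the threshold is kept as well. `top_p_filter` is a hypothetical helper; the inverse permutation (`argsort` of the sort indices) scatters the result back to vocabulary order.

```python
import jax
import jax.numpy as jnp


def top_p_filter(scores, top_p, filter_value=-jnp.inf, min_tokens_to_keep=1):
    """Nucleus filtering: keep the smallest high-probability set whose mass reaches top_p."""
    vocab_size = scores.shape[-1]
    sorted_scores, sorted_idx = jax.lax.top_k(scores, vocab_size)  # descending
    probs = jax.nn.softmax(sorted_scores, axis=-1)
    # Probability mass of all tokens ranked strictly above each token.
    mass_before = jnp.cumsum(probs, axis=-1) - probs
    keep = mass_before < top_p
    keep = keep.at[..., :min_tokens_to_keep].set(True)
    filtered_sorted = jnp.where(keep, sorted_scores, filter_value)
    # argsort of a permutation is its inverse: map sorted positions back to vocab order.
    inverse = jnp.argsort(sorted_idx, axis=-1)
    return jnp.take_along_axis(filtered_sorted, inverse, axis=-1)
```

With probabilities [0.1, 0.5, 0.3, 0.1] and top_p = 0.7, the tokens with probability 0.5 and 0.3 survive: the 0.3 token is kept because only 0.5 of the mass lies above it.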

class easydel.inference.logits_process.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)

Bases: FlaxLogitsProcessor

Whisper-specific processor. This processor can be used to force a list of tokens. The processor sets their log probs to inf so that they are sampled at their corresponding index.

Parameters

generate_config (GenerateConfig) – The generate config used to generate the output. The following parameters are required:
  • eos_token_id (int, optional, defaults to 50257) – The id of the end-of-sequence token.

  • no_timestamps_token_id (int, optional, defaults to 50363) – The id of the "<|notimestamps|>" token.

  • max_initial_timestamp_index (int, optional, defaults to 1) – Used to set the maximum value of the initial timestamp. This prevents the model from predicting timestamps that are too far in the future.

easydel.inference.logits_process.add_start_docstrings(*docstr)