easydel.inference.logits_process
- class easydel.inference.logits_process.FlaxForceTokensLogitsProcessor(force_token_map)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that takes a list of integer pairs mapping generation indices to the token ids that will be forced before sampling. At each listed index, the processor sets the log prob of the forced token to 0 and the log probs of all other tokens to -inf, so the forced token is sampled at that index.
- Parameters
force_token_map (list) – Pairs of (generation index, token id); at each listed generation index, the paired token id is forced to be sampled.
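The forcing rule can be sketched in plain NumPy. This is not easydel's JAX implementation: the function name `force_tokens` is hypothetical, `scores` is assumed to have shape `(batch_size, vocab_size)`, and `cur_len` is assumed to be the generation index currently being filled. A jit-compiled version could not use a Python dict lookup on a traced `cur_len` and would need an array lookup instead.

```python
import numpy as np


def force_tokens(scores: np.ndarray, cur_len: int, force_token_map) -> np.ndarray:
    """Hypothetical NumPy sketch of token forcing (not easydel's code).

    scores: (batch_size, vocab_size) logits or log-probs.
    force_token_map: iterable of (generation_index, token_id) pairs.
    """
    forced = dict(force_token_map)  # generation index -> token id
    if cur_len not in forced:
        return scores  # nothing is forced at this step
    out = np.full_like(scores, -np.inf)  # every token gets -inf ...
    out[:, forced[cur_len]] = 0.0  # ... except the forced one, which gets log prob 0
    return out


scores = np.random.default_rng(0).normal(size=(2, 5)).astype(np.float32)
print(force_tokens(scores, cur_len=1, force_token_map=[(1, 3), (2, 4)]).argmax(-1))  # [3 3]
```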
- class easydel.inference.logits_process.FlaxForcedBOSTokenLogitsProcessor(bos_token_id: int)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that enforces the specified token as the first generated token.
- Parameters
bos_token_id (int) – The id of the token to force as the first generated token.
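A minimal NumPy sketch of the forced-BOS rule. It assumes that `cur_len == 1` marks the step producing the first generated token (position 0 holding the decoder start token); easydel's exact step convention may differ, and the function name `force_bos` is hypothetical.

```python
import numpy as np


def force_bos(scores: np.ndarray, cur_len: int, bos_token_id: int) -> np.ndarray:
    """Hypothetical sketch: force `bos_token_id` at the first generated position."""
    # Assumption: cur_len == 1 is the first generated position.
    if cur_len != 1:
        return scores
    out = np.full_like(scores, -np.inf)
    out[:, bos_token_id] = 0.0  # only the BOS token keeps a finite log prob
    return out


scores = np.zeros((1, 4), dtype=np.float32)
print(force_bos(scores, cur_len=1, bos_token_id=2).argmax(-1))  # [2]
```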
- class easydel.inference.logits_process.FlaxForcedEOSTokenLogitsProcessor(max_length: int, eos_token_id: int)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that enforces the specified token as the last generated token when max_length is reached.
- Parameters
max_length (int) – The maximum length of the sequence to be generated.
eos_token_id (int) – The id of the token to force as the last generated token when max_length is reached.
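The same masking idea applies to the last position. In this hypothetical NumPy sketch, the step that produces the final token of a `max_length` sequence is assumed to be `cur_len == max_length - 1`; the exact off-by-one convention in easydel is an assumption here.

```python
import numpy as np


def force_eos(scores: np.ndarray, cur_len: int, max_length: int, eos_token_id: int) -> np.ndarray:
    """Hypothetical sketch: force EOS as the final token of a max_length sequence."""
    # Assumption: the token chosen at cur_len == max_length - 1 is the last one.
    if cur_len != max_length - 1:
        return scores
    out = np.full_like(scores, -np.inf)
    out[:, eos_token_id] = 0.0  # EOS is the only token left with a finite log prob
    return out


scores = np.zeros((2, 6), dtype=np.float32)
print(force_eos(scores, cur_len=9, max_length=10, eos_token_id=5).argmax(-1))  # [5 5]
```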
- class easydel.inference.logits_process.FlaxLogitsProcessor
Bases: object
Abstract base class for all logit processors that can be applied during generation.
- class easydel.inference.logits_process.FlaxLogitsProcessorList(iterable=(), /)
Bases: list
A list of FlaxLogitsProcessor or FlaxLogitsWarper objects used to process a scores input tensor. The class inherits from list and adds a __call__ method that applies each FlaxLogitsProcessor or FlaxLogitsWarper to the inputs in order.
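The chaining behaviour can be illustrated with a self-contained stand-in. `ProcessorList`, the two toy processors, and the `(input_ids, scores, cur_len)` call signature are assumptions made for this sketch, not easydel's API; the point is only that each element receives the previous element's output.

```python
import numpy as np


class ProcessorList(list):
    """Hypothetical stand-in: a list whose __call__ runs every processor in order."""

    def __call__(self, input_ids, scores, cur_len):
        for processor in self:
            scores = processor(input_ids, scores, cur_len)
        return scores


class Temperature:
    def __init__(self, temperature):
        self.temperature = temperature

    def __call__(self, input_ids, scores, cur_len):
        return scores / self.temperature


class SuppressTokens:
    def __init__(self, suppress_tokens):
        self.suppress_tokens = list(suppress_tokens)

    def __call__(self, input_ids, scores, cur_len):
        scores = scores.copy()  # do not mutate the caller's array
        scores[:, self.suppress_tokens] = -np.inf
        return scores


processors = ProcessorList([SuppressTokens([0]), Temperature(2.0)])
scores = np.array([[4.0, 2.0, 1.0]])
print(processors(None, scores, cur_len=1))  # column 0 becomes -inf, the rest are halved
```

Because the container is a plain list subclass, processors can be appended or reordered with ordinary list operations before generation starts.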
- class easydel.inference.logits_process.FlaxLogitsWarper
Bases: object
Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.
- class easydel.inference.logits_process.FlaxMinLengthLogitsProcessor(min_length: int, eos_token_id: int)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that enforces a minimum length by setting the EOS probability to 0.
- Parameters
min_length (int) – The minimum length below which the score of eos_token_id is set to -float("inf").
eos_token_id (int) – The id of the end-of-sequence token.
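A NumPy sketch of the documented rule: while the current length is below `min_length`, the EOS score is replaced by -inf. The function name is hypothetical, and the exact boundary (strictly below versus at `min_length`) used by easydel's processor is an assumption here.

```python
import numpy as np


def min_length_mask(scores: np.ndarray, cur_len: int, min_length: int, eos_token_id: int) -> np.ndarray:
    """Hypothetical sketch: ban EOS while the sequence is shorter than min_length."""
    # Follows the documented "below which" rule; the real boundary may differ by one.
    if cur_len >= min_length:
        return scores
    out = scores.copy()
    out[:, eos_token_id] = -np.inf  # EOS can no longer be sampled
    return out


scores = np.zeros((1, 4), dtype=np.float32)
print(min_length_mask(scores, cur_len=3, min_length=5, eos_token_id=0))
```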
- class easydel.inference.logits_process.FlaxNoRepeatNGramLogitsProcessor(ngram_size: int)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that enforces no repetition of n-grams. See Fairseq (pytorch/fairseq).
- Parameters
ngram_size (int) – Each n-gram of size ngram_size can occur at most once.
- get_banned_tokens_mask(latest_tokens: Array, previous_ngrams) → Array
Determines which tokens must be banned, given the latest tokens and the previously seen n-grams.
- get_previous_ngrams(input_ids: Array, vocab_size: int, cur_len: int)
Returns a matrix of shape (batch_size,) + (vocab_size,) * n (for n-grams) that records the n-grams that occurred previously. The sparse BCOO representation stores only the few non-zero entries instead of the full (huge) dense matrix.
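The banning rule itself can be shown with a dense, pure-Python sketch. It deliberately skips the sparse BCOO representation used by get_previous_ngrams and simply scans each sequence; the helper names `banned_next_tokens` and `no_repeat_ngram` are hypothetical. A token is banned if appending it would recreate an n-gram that already occurs in the sequence.

```python
import numpy as np


def banned_next_tokens(tokens, ngram_size):
    """Tokens that would complete an n-gram already present in `tokens`."""
    tokens = [int(t) for t in tokens]
    if len(tokens) < ngram_size:
        return set()  # no complete n-gram has been generated yet
    # The latest (ngram_size - 1) tokens form the prefix of the next n-gram.
    prefix = tuple(tokens[len(tokens) - ngram_size + 1:])
    banned = set()
    for start in range(len(tokens) - ngram_size + 1):
        ngram = tuple(tokens[start:start + ngram_size])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned


def no_repeat_ngram(scores, input_ids, ngram_size):
    """Hypothetical dense sketch: set banned tokens to -inf, row by row."""
    out = scores.copy()
    for row, seq in enumerate(input_ids):
        for token in banned_next_tokens(seq, ngram_size):
            out[row, token] = -np.inf
    return out


input_ids = np.array([[1, 2, 3, 1, 2]])
print(banned_next_tokens(input_ids[0], ngram_size=3))  # {3}: "1 2 3" already occurred
```

The dense scan costs time proportional to the sequence length per step, which is why a JAX implementation would prefer a precomputed (sparse) n-gram table.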
- class easydel.inference.logits_process.FlaxStaticForceTokensLogitsProcessor(force_token_map)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that takes a list of integer pairs mapping generation indices to the token ids that will be forced before sampling. At each listed index, the processor sets the log prob of the forced token to 0 and the log probs of all other tokens to -inf, so the forced token is sampled at that index. This is a static version of the transformers logits processor FlaxForceTokensLogitsProcessor that is compatible with sharded forced tokens.
- Parameters
force_token_map (list) – Pairs of (generation index, token id); at each listed generation index, the paired token id is forced to be sampled.
- class easydel.inference.logits_process.FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that suppresses a list of tokens at the step where the generate function begins producing new tokens, as indicated by begin_index. This ensures that the tokens in begin_suppress_tokens are not sampled at the beginning of the generation.
- Parameters
begin_suppress_tokens (List[int]) – Tokens to not sample.
begin_index (int) – Index where the tokens are suppressed.
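A NumPy sketch of begin-suppression, assuming (from the begin_index description) that suppression applies only at the step where `cur_len == begin_index`; the function name is hypothetical. Dropping the `cur_len` check gives the every-step behaviour of FlaxSuppressTokensLogitsProcessor.

```python
import numpy as np


def suppress_at_begin(scores, cur_len, begin_suppress_tokens, begin_index):
    """Hypothetical sketch: ban `begin_suppress_tokens` only at `begin_index`."""
    if cur_len != begin_index:
        return scores  # later steps are left untouched
    out = scores.copy()
    out[:, list(begin_suppress_tokens)] = -np.inf
    return out


scores = np.zeros((1, 6), dtype=np.float32)
first_step = suppress_at_begin(scores, cur_len=4, begin_suppress_tokens=[0, 5], begin_index=4)
later_step = suppress_at_begin(scores, cur_len=5, begin_suppress_tokens=[0, 5], begin_index=4)
print(first_step, later_step)
```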
- class easydel.inference.logits_process.FlaxSuppressTokensLogitsProcessor(suppress_tokens: list)
Bases: FlaxLogitsProcessor
A FlaxLogitsProcessor that suppresses a list of tokens at each decoding step by setting their log probs to -inf, so they are never sampled.
- Parameters
suppress_tokens (list) – Tokens to not sample.
- class easydel.inference.logits_process.FlaxTemperatureLogitsWarper(temperature: float)
Bases: FlaxLogitsWarper
A FlaxLogitsWarper for temperature scaling of the output probability distribution: the logits are divided by temperature before sampling.
- Parameters
temperature (float) – The value used to modulate the logits distribution.
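A short NumPy sketch of what temperature scaling does to the sampling distribution. The helper names are hypothetical; the logits are divided by `temperature` before the softmax, so values below 1 sharpen the distribution and values above 1 flatten it.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)


def apply_temperature(scores, temperature):
    # T < 1 sharpens the distribution, T > 1 flattens it.
    return scores / temperature


logits = np.array([[2.0, 1.0, 0.0]])
sharp = softmax(apply_temperature(logits, 0.5))
flat = softmax(apply_temperature(logits, 2.0))
print(sharp, flat)
```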
- class easydel.inference.logits_process.FlaxTopKLogitsWarper(top_k: int, filter_value: float = -inf, min_tokens_to_keep: int = 1)
Bases: FlaxLogitsWarper
A FlaxLogitsWarper that performs top-k filtering, i.e. restricting sampling to the k highest-probability elements.
- Parameters
top_k (int) – The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (float, optional, defaults to -inf) – All filtered values will be set to this float value.
min_tokens_to_keep (int, optional, defaults to 1) – Minimum number of tokens that cannot be filtered.
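A NumPy sketch of top-k filtering with the same parameters. The function name is hypothetical, and unlike an implementation built on an exact top-k selection, this sketch keeps every entry tied with the k-th largest score.

```python
import numpy as np


def top_k_filter(scores, top_k, filter_value=-np.inf, min_tokens_to_keep=1):
    """Hypothetical sketch: keep the k highest scores per row, filter the rest."""
    k = min(max(top_k, min_tokens_to_keep), scores.shape[-1])
    # Score of the k-th largest entry in each row, shape (batch_size, 1).
    kth_largest = np.sort(scores, axis=-1)[:, -k][:, None]
    # Entries strictly below the k-th largest are filtered; ties with it survive.
    return np.where(scores < kth_largest, filter_value, scores)


scores = np.array([[1.0, 4.0, 3.0, 2.0]])
print(top_k_filter(scores, top_k=2))  # keeps 4.0 and 3.0
```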
- class easydel.inference.logits_process.FlaxTopPLogitsWarper(top_p: float, filter_value: float = -inf, min_tokens_to_keep: int = 1)
Bases: FlaxLogitsWarper
A FlaxLogitsWarper that performs top-p (nucleus) filtering, i.e. restricting sampling to the smallest set of most probable tokens whose cumulative probability reaches top_p.
- Parameters
top_p (float) – If set to < 1, only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation.
filter_value (float, optional, defaults to -inf) – All filtered values will be set to this float value.
min_tokens_to_keep (int, optional, defaults to 1) – Minimum number of tokens that cannot be filtered.
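A NumPy sketch of nucleus filtering. The function name is hypothetical; it sorts tokens by score, keeps each token whose preceding cumulative probability is still below `top_p` (so the token that crosses the threshold is kept too), always keeps at least `min_tokens_to_keep`, and scatters the mask back to vocabulary order.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def top_p_filter(scores, top_p, filter_value=-np.inf, min_tokens_to_keep=1):
    """Hypothetical sketch of top-p (nucleus) filtering."""
    order = np.argsort(-scores, axis=-1)  # token indices, most probable first
    sorted_scores = np.take_along_axis(scores, order, axis=-1)
    probs = softmax(sorted_scores)
    cumulative = np.cumsum(probs, axis=-1)
    # Keep a token if the mass before it is still below top_p.
    keep_sorted = (cumulative - probs) < top_p
    keep_sorted[:, :min_tokens_to_keep] = True
    # Scatter the keep-mask back to vocabulary order.
    keep = np.zeros_like(keep_sorted)
    np.put_along_axis(keep, order, keep_sorted, axis=-1)
    return np.where(keep, scores, filter_value)


probs = np.array([[0.05, 0.5, 0.15, 0.3]])
print(top_p_filter(np.log(probs), top_p=0.7))  # keeps the tokens with p=0.5 and p=0.3
```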
- class easydel.inference.logits_process.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)
Bases: FlaxLogitsProcessor
Whisper-specific processor. This processor can be used to force a list of tokens. The processor will set their log probs to inf so that they are sampled at their corresponding index.
- Parameters
generate_config (GenerateConfig) – The generate config used to generate the output. The following parameters are required:
  - eos_token_id (int, optional, defaults to 50257): The id of the end-of-sequence token.
  - no_timestamps_token_id (int, optional, defaults to 50363): The id of the "<|notimestamps|>" token.
  - max_initial_timestamp_index (int, optional, defaults to 1): Used to set the maximum value of the initial timestamp. This prevents the model from predicting timestamps that are too far in the future.