easydel.infra.mixins.generation


class easydel.infra.mixins.generation.BeamSearchState(cur_len: Union[Array, ndarray, bool, number], running_sequences: Union[Array, ndarray, bool, number], running_scores: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], scores: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])[source]#

Bases: Mapping

State for beam search generation.

cur_len#

Current length of the generated sequence.

Type

chex.Array

running_sequences#

Generated sequences being tracked in the beam.

Type

chex.Array

running_scores#

Scores of the sequences being tracked in the beam.

Type

chex.Array

sequences#

Best generated sequences.

Type

chex.Array

scores#

Scores of the best generated sequences.

Type

chex.Array

is_sent_finished#

Boolean array indicating if a sequence is finished.

Type

chex.Array

model_kwargs#

Model specific keyword arguments.

Type

tp.Dict[str, chex.Array]

cur_len: Union[Array, ndarray, bool, number]#
from_tuple()#
is_sent_finished: Union[Array, ndarray, bool, number]#
items() → a set-like object providing a view on the mapping's items#
keys() → a set-like object providing a view on the mapping's keys#
model_kwargs: Dict[str, Union[Array, ndarray, bool, number]]#
replace(**kwargs)#
running_scores: Union[Array, ndarray, bool, number]#
running_sequences: Union[Array, ndarray, bool, number]#
scores: Union[Array, ndarray, bool, number]#
sequences: Union[Array, ndarray, bool, number]#
to_tuple()#
values() → an object providing a view on the mapping's values#
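The state classes on this page all derive from Mapping and list replace, to_tuple, and from_tuple among their members. As a rough illustration of that pattern, and not EasyDeL's actual implementation, the sketch below builds a frozen dataclass that also behaves as a read-only mapping. ToyState and its two fields are hypothetical, and from_tuple is written here as a classmethod taking a tuple, which is an assumption because the page does not show its arguments.

```python
import dataclasses
from collections.abc import Mapping


@dataclasses.dataclass(frozen=True)
class ToyState(Mapping):
    """Hypothetical stand-in for a generation state; not the EasyDeL class."""

    cur_len: int
    is_sent_finished: tuple

    # Mapping protocol: keys(), items() and values() are derived from these three.
    def __getitem__(self, key):
        if key not in {f.name for f in dataclasses.fields(self)}:
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self):
        return (f.name for f in dataclasses.fields(self))

    def __len__(self):
        return len(dataclasses.fields(self))

    def replace(self, **kwargs):
        # Return a new state with the given fields swapped out.
        return dataclasses.replace(self, **kwargs)

    def to_tuple(self):
        # Field values in declaration order.
        return tuple(getattr(self, f.name) for f in dataclasses.fields(self))

    @classmethod
    def from_tuple(cls, values):
        # Assumed signature: rebuild a state from values in field order.
        return cls(*values)
```

Because the state is immutable, each generation step produces a new state through replace instead of mutating the old one, which is the style a traced loop generally needs.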
class easydel.infra.mixins.generation.EasyGenerationMixin[source]#

Bases: object

base_model_prefix: str#
config: EasyDeLBaseConfig#
config_class: Type[EasyDeLBaseConfig]#
generate(input_ids: Union[Array, ndarray, bool, number], generation_config: Optional[GenerationConfig] = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: Optional[FlaxLogitsProcessorList] = None, **kwargs)[source]#

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.

  • generation_config (~generation.GenerationConfig, optional) – The generation configuration used as the base parametrization for the generation call. Any **kwargs passed to generate that match attributes of generation_config override them. If generation_config is not provided, a default is used, loaded with the following priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Unspecified parameters inherit the default values of ~generation.GenerationConfig, whose documentation should be consulted when parameterizing generation.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (FlaxLogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from the arguments and the generation config. Passing a logits processor that is already created from the arguments or the generation config raises an error. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that are forwarded to the model's forward function. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed, and decoder-specific kwargs should be prefixed with decoder_.

Returns

~utils.ModelOutput.
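The override rule for generation_config and **kwargs can be illustrated without the library. The sketch below is only an approximation of the documented behavior: ToyGenerationConfig, its fields, and split_generate_kwargs are hypothetical names, and the real generate method may resolve arguments differently.

```python
import copy
from dataclasses import dataclass, fields


@dataclass
class ToyGenerationConfig:
    """Hypothetical config holding a few generation settings."""

    max_length: int = 20
    do_sample: bool = False
    temperature: float = 1.0


def split_generate_kwargs(generation_config, **kwargs):
    """Apply matching kwargs as config overrides; return the rest as model kwargs."""
    config = copy.deepcopy(generation_config)  # leave the caller's config untouched
    known = {f.name for f in fields(config)}
    model_kwargs = {}
    for name, value in kwargs.items():
        if name in known:
            setattr(config, name, value)  # kwarg matches a config attribute
        else:
            model_kwargs[name] = value  # assumed to be forwarded to the model
    return config, model_kwargs


base = ToyGenerationConfig()
config, model_kwargs = split_generate_kwargs(
    base, max_length=64, attention_mask=[[1, 1, 1]]
)
```

Here max_length overrides the config value for this call only, while attention_mask does not match any config attribute and is kept as a model-specific argument.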

init_cache(batch_size: int, max_length: int)[source]#
prepare_inputs_for_generation(input_ids, max_length, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Prepares the model inputs for a generation call.

Parameters
  • self – The model instance.

  • input_ids – The input token ids used as the prompt.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask (tp.Optional[chex.Array]) – Mask applied to the attention weights.

Returns

A dictionary containing the past_key_values, the attention mask, and the position ids.
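One common way to build the attention mask and position ids for generation is to extend the prompt mask to max_length and derive positions from a running count of attended tokens. The framework-free sketch below shows that idea under stated assumptions: inputs are nested Python lists instead of arrays, the cache (past_key_values) is omitted, and toy_prepare_inputs is a hypothetical name. The actual EasyDeL method may compute these values differently.

```python
def toy_prepare_inputs(input_ids, max_length, attention_mask=None):
    """Hypothetical helper: pad the mask to max_length and derive position ids."""
    batch_size = len(input_ids)
    seq_length = len(input_ids[0])
    if attention_mask is None:
        # No mask given: every prompt token is attended to.
        attention_mask = [[1] * seq_length for _ in range(batch_size)]

    # Position id = (attended tokens up to and including this one) - 1,
    # so leading padding gets -1 and the first real token gets 0.
    position_ids = []
    for row in attention_mask:
        count, positions = 0, []
        for m in row:
            count += m
            positions.append(count - 1)
        position_ids.append(positions)

    # Extend the mask with ones so it also covers tokens generated later.
    extended_mask = [row + [1] * (max_length - seq_length) for row in attention_mask]
    return {"attention_mask": extended_mask, "position_ids": position_ids}


inputs = toy_prepare_inputs([[0, 7, 8]], max_length=5, attention_mask=[[0, 1, 1]])
```

With a left-padded prompt, the first real token still receives position 0, so padding does not shift the positions the model sees.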

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.infra.mixins.generation.GreedyState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])[source]#

Bases: Mapping

State for greedy search generation.

cur_len#

Current length of the generated sequence.

Type

chex.Array

sequences#

Generated sequences so far.

Type

chex.Array

running_token#

Current token being processed.

Type

chex.Array

is_sent_finished#

Boolean array indicating if a sequence is finished.

Type

chex.Array

model_kwargs#

Model specific keyword arguments.

Type

tp.Dict[str, chex.Array]

cur_len: Union[Array, ndarray, bool, number]#
from_tuple()#
is_sent_finished: Union[Array, ndarray, bool, number]#
items() → a set-like object providing a view on the mapping's items#
keys() → a set-like object providing a view on the mapping's keys#
model_kwargs: Dict[str, Union[Array, ndarray, bool, number]]#
replace(**kwargs)#
running_token: Union[Array, ndarray, bool, number]#
sequences: Union[Array, ndarray, bool, number]#
to_tuple()#
values() → an object providing a view on the mapping's values#
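To show how the GreedyState fields could fit together, here is a minimal pure-Python sketch of a greedy decoding loop. It is an illustration under assumptions, not the library's implementation: ToyGreedyState, toy_logits, greedy_step, and greedy_generate are hypothetical, tuples stand in for arrays, model_kwargs is left out, and the "model" simply prefers the next token id.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyGreedyState:
    cur_len: int
    sequences: tuple         # one tuple of token ids per batch row
    running_token: tuple     # most recent token per row
    is_sent_finished: tuple  # one bool per row


def toy_logits(token, vocab_size=5):
    """Hypothetical 'model': always prefers token + 1, wrapping around."""
    return [1.0 if i == (token + 1) % vocab_size else 0.0 for i in range(vocab_size)]


def greedy_step(state, eos_token_id, pad_token_id):
    tokens, finished, sequences = [], [], []
    for seq, tok, done in zip(
        state.sequences, state.running_token, state.is_sent_finished
    ):
        logits = toy_logits(tok)
        nxt = max(range(len(logits)), key=logits.__getitem__)  # argmax
        nxt = pad_token_id if done else nxt  # finished rows only emit padding
        tokens.append(nxt)
        finished.append(done or nxt == eos_token_id)
        sequences.append(seq + (nxt,))
    return replace(
        state,
        cur_len=state.cur_len + 1,
        sequences=tuple(sequences),
        running_token=tuple(tokens),
        is_sent_finished=tuple(finished),
    )


def greedy_generate(prompts, max_length, eos_token_id=4, pad_token_id=0):
    state = ToyGreedyState(
        cur_len=len(prompts[0]),
        sequences=tuple(tuple(p) for p in prompts),
        running_token=tuple(p[-1] for p in prompts),
        is_sent_finished=tuple(False for _ in prompts),
    )
    # Stop once every row is finished or the length budget is used up.
    while state.cur_len < max_length and not all(state.is_sent_finished):
        state = greedy_step(state, eos_token_id, pad_token_id)
    return state


final = greedy_generate([[1, 2], [3, 3]], max_length=6)
```

The second row reaches the end-of-sequence token one step earlier than the first, so it is padded while the first row finishes.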
class easydel.infra.mixins.generation.SampleState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], prng_key: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])[source]#

Bases: Mapping

State for sampling generation.

cur_len#

Current length of the generated sequence.

Type

chex.Array

sequences#

Generated sequences so far.

Type

chex.Array

running_token#

Current token being processed.

Type

chex.Array

is_sent_finished#

Boolean array indicating if a sequence is finished.

Type

chex.Array

prng_key#

PRNG key for sampling.

Type

chex.Array

model_kwargs#

Model specific keyword arguments.

Type

tp.Dict[str, chex.Array]

cur_len: Union[Array, ndarray, bool, number]#
from_tuple()#
is_sent_finished: Union[Array, ndarray, bool, number]#
items() → a set-like object providing a view on the mapping's items#
keys() → a set-like object providing a view on the mapping's keys#
model_kwargs: Dict[str, Union[Array, ndarray, bool, number]]#
prng_key: Union[Array, ndarray, bool, number]#
replace(**kwargs)#
running_token: Union[Array, ndarray, bool, number]#
sequences: Union[Array, ndarray, bool, number]#
to_tuple()#
values() → an object providing a view on the mapping's values#
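SampleState differs from GreedyState by carrying a prng_key, which suggests that randomness is threaded explicitly through the loop state. The sketch below imitates that idea with the standard library only. It is an analogy, not JAX or EasyDeL code: split_key, sample_token, and sample_sequence are hypothetical, and integer seeds stand in for real PRNG keys.

```python
import math
import random


def split_key(key):
    """Hypothetical stand-in for a PRNG key split: derive two new integer keys."""
    rng = random.Random(key)
    return rng.getrandbits(32), rng.getrandbits(32)


def sample_token(key, logits, temperature=1.0):
    """Draw one token id from softmax(logits / temperature) using the given key."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]  # unnormalised softmax
    return random.Random(key).choices(range(len(logits)), weights=weights, k=1)[0]


def sample_sequence(prng_key, logits, steps):
    tokens = []
    for _ in range(steps):
        # Carry one key forward in the state and spend the other on this step.
        prng_key, step_key = split_key(prng_key)
        tokens.append(sample_token(step_key, logits))
    return tokens, prng_key


tokens_a, key_a = sample_sequence(0, [0.1, 2.0, 0.5], steps=5)
tokens_b, key_b = sample_sequence(0, [0.1, 2.0, 0.5], steps=5)
```

Because the key is part of the state and never hidden in global state, two runs that start from the same key produce the same tokens.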