easydel.infra.mixins.__init__

Contents

easydel.infra.mixins.__init__#

class easydel.infra.mixins.__init__.BaseModuleProtocol[source]#

Bases: object

Protocol defining the common interface for EasyDeL modules.

abstract apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF[source]#

Apply LoRA (Low-Rank Adaptation) to specified linear layers within a model.

base_model_prefix: str#
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[CausalLMOutput, LossMetrics]#
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[SequenceClassifierOutput, LossMetrics]
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[MoeModelOutput, LossMetrics]
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[MoeCausalLMOutput, LossMetrics]
compute_loss(*, labels: tp.Optional[chex.Array] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None, **batch) tp.Tuple[tp.Any, LossMetrics]

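Computes the model outputs and the loss for a batch. Returns a tuple of the model output (CausalLMOutput, SequenceClassifierOutput, MoeModelOutput, or MoeCausalLMOutput, depending on the overload) and a LossMetrics object. Targets are passed via labels, and the loss is configured through loss_config and loss_kwargs.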

config: EasyDeLBaseConfig#
config_class: Type[EasyDeLBaseConfig]#
abstract float(change_runtime_dtype: bool = True)[source]#

Converts the model parameters to float32.

abstract gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)[source]#

Gathers the model's parameters based on the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for gathering.

  • mesh (jax.sharding.Mesh, optional) – The mesh to gather from.

Returns

The gathered model.

Return type

nn.Module

abstract get_static_arguments() Tuple[source]#

Returns the static arguments to pass to jax.jit.

abstract property graphdef: Union[NodeDef[Node], NodeRef[Node]]#
abstract property graphother: State[Key, VariableState[Any]]#
abstract property graphstate: State[Key, VariableState[Any]]#
abstract half(change_runtime_dtype: bool = True)[source]#

Converts the model parameters to float16.

abstract classmethod lazy_init(*args, **kwargs) SELF[source]#

Initializes the class lazily using nnx.eval_shape.

abstract merge_lora_params(pytree: Dict) SELF[source]#

Merge LoRA (Low-Rank Adaptation) parameters into the base model parameters.

abstract merge_params(tree)[source]#

Merges the given state into the current model.

abstract merge_params_dict(params_dict: Dict)[source]#

Merges the model parameters from a dictionary into the current model.

abstract property params_sharding: Dict#

return the sharding of the model parameters

abstract prepare_inputs_for_call(**kwargs)[source]#

Updates the inputs for calling the model.

abstract property pure_transform_fn: Callable#

Generates a pure transform function for converting a PyTorch model to an EasyDeL module.

abstract quantize(method: EasyDeLQuantizationMethods, skip_modules: list[str] | None = None, verbose: bool = False, **kwargs)[source]#

Quantizes the model's linear layers.

Parameters
  • method (EasyDeLQuantizationMethods) – The quantization method to use.

  • skip_modules (list[str] | None, optional) – List of module names to skip.

  • verbose (bool, optional) – Whether to print verbose output.

  • **kwargs – Additional keyword arguments.

Returns

The quantized model.

Return type

nn.Module

abstract shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)[source]#

Shards the model's parameters using the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for sharding.

  • mesh (jax.sharding.Mesh, optional) – The mesh to shard across.

Returns

The sharded model.

Return type

nn.Module

abstract split_lora_params() dict[source]#

Split LoRA (Low-Rank Adaptation) parameters from the base model parameters.

abstract split_params()[source]#

Splits the model parameters.

abstract split_params_dict(params_dict: Dict) Dict[source]#

Splits the model parameters from a dictionary into separate state components.

abstract to_dtype(dtype) SELF[source]#

Converts the model parameters to the given dtype.

abstract to_state() Any[source]#

Converts the current model to an EasyDeLState.

abstract to_torch() Any[source]#

Converts the current model to a Hugging Face PyTorch model.

abstract property transform_fn: Callable#

Generates a transform function for converting a PyTorch model to an EasyDeL module.

abstract unwrap_lora_to_layers(verbose: bool = False)[source]#

Unwraps LoRA (Low-Rank Adaptation) from the specified linear layers within a model.

class easydel.infra.mixins.__init__.EasyBridgeMixin[source]#

Bases: PushToHubMixin

Mixin class for adding bridging functionalities like saving, loading, and pushing models to Hugging Face Hub.

base_model_prefix: Optional[str] = None#
classmethod can_generate() bool[source]#

Checks if the model can generate sequences with .generate().

Returns

True if the model can generate, False otherwise.

Return type

bool

config: EasyDeLBaseConfig#
config_class: Optional[Type[EasyDeLBaseConfig]] = None#
classmethod from_pretrained(pretrained_model_name_or_path: ~typing.Optional[~typing.Union[str, ~os.PathLike]], sharding_axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, sharding_axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: ~eformer.escale.partition.manager.PartitionAxis = PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis=None, key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis=None, decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, config_kwargs: ~typing.Optional[dict[str, typing.Any]] = None, partition_rules: ~typing.Optional[~typing.Tuple[~typing.Tuple[str, ~jax._src.partition_spec.PartitionSpec]]] = None, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = 'jax', shard_fns: ~typing.Optional[dict[typing.Callable]] = None, auto_shard_model: bool = False, verbose: bool = True, mismatch_allowed: bool = True, *model_args, config: 
~typing.Optional[~typing.Union[~easydel.infra.base_config.EasyDeLBaseConfig, str, ~os.PathLike]] = None, cache_dir: ~typing.Optional[~typing.Union[str, ~os.PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: ~typing.Optional[~typing.Union[str, bool]] = None, revision: str = 'main', vebose: bool = True, quantization_platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = None, quantization_method: ~typing.Optional[~easydel.infra.etils.EasyDeLQuantizationMethods] = None, quantization_block_size: int = 128, quantization_pattern: ~typing.Optional[str] = None, quantize_tensors: bool = True, **kwargs)[source]#

Loads an EasyDeL model from a pretrained model or path.

Parameters
  • pretrained_model_name_or_path (str, optional) – The name or path of the pretrained model.

  • sharding_axis_dims (Sequence[int], optional) – The dimensions of the sharding axes.

  • sharding_axis_names (Sequence[str], optional) – The names of the sharding axes.

  • partition_axis (PartitionAxis, optional) – The partition axis configuration.

  • dtype (dtype, optional) – The data type of the model.

  • param_dtype (dtype, optional) – The data type of the parameters.

  • precision (PrecisionLike, optional) – The computation precision.

  • config_kwargs (dict[str, Any], optional) – Additional configuration parameters.

  • partition_rules (tuple, optional) – Custom partitioning rules for sharding.

  • backend (EasyDeLBackends, optional) – The backend to use.

  • platform (EasyDeLPlatforms, optional) – The platform to use.

  • shard_fns (dict[Callable], optional) – Custom shard functions for loading the checkpoint.

  • auto_shard_model (bool, optional) – Whether to automatically shard the model.

  • verbose (bool, optional) – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed (bool, optional) – If True, allows mismatches in parameters while loading. Defaults to True.

  • *model_args – Additional arguments for the model.

  • config (EasyDeLBaseConfig, str, or PathLike, optional) – The configuration for the model, given either as a config object or as a name/path.

  • cache_dir (str, optional) – The cache directory for the pretrained model.

  • force_download (bool, optional) – Whether to force download the model.

  • local_files_only (bool, optional) – Whether to use only local files.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • revision (str, optional) – The revision of the model to load.

  • **kwargs – Additional keyword arguments.

Returns

The loaded EasyDeL model.

classmethod get_torch_loader()[source]#
hf_torch_auto_loader: Optional[Any] = None#
push_to_hub(repo_id: str, use_temp_dir: Optional[bool] = None, commit_message: Optional[str] = None, private: Optional[bool] = None, token: Optional[Union[str, bool]] = None, create_pr: bool = False, gather_fns: Optional[dict[Callable]] = None, float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, revision: Optional[str] = None, commit_description: Optional[str] = None) str[source]#

Pushes the model to the Hugging Face Hub.

Parameters
  • repo_id (str) – The repository ID on Hugging Face Hub.

  • params (any) – Model parameters. Note: the push_to_hub signature has no params argument, so this entry appears to be stale.

  • use_temp_dir (bool, optional) – If True, uses a temporary directory. Defaults to None.

  • commit_message (str, optional) – The commit message for the push.

  • private (bool, optional) – If True, creates a private repository.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • create_pr (bool, optional) – If True, creates a pull request.

  • gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.

  • float_dtype (dtype, optional) – Data type for saving weights.

  • verbose (bool, optional) – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed (bool, optional) – If True, allows mismatches in parameters while loading. Defaults to True.

  • revision (str, optional) – The revision to push to.

  • commit_description (str, optional) – The commit description for the push.

Returns

The URL of the created repository.

Return type

str

save_pretrained(save_directory: Union[str, PathLike], push_to_hub: bool = False, token: Optional[Union[str, bool]] = None, gather_fns: Optional[dict[Callable]] = None, float_dtype=None, verbose: bool = True, mismatch_allowed: bool = True, enable: Optional[bool] = None, **kwargs)[source]#

Saves the model, its configuration, and optionally pushes it to the Hugging Face Hub.

Parameters
  • save_directory (str or PathLike) – The directory in which to save the model.

  • push_to_hub (bool, optional) – If True, pushes the model to the Hugging Face Hub.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.

  • float_dtype (dtype, optional) – Data type for saving weights.

  • verbose (bool, optional) – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed (bool, optional) – If True, allows mismatches in parameters while loading. Defaults to True.

  • enable (bool, optional) – If True, allows the file to be saved (used when saving models across multiple hosts).

  • **kwargs – Additional keyword arguments for Hugging Face Hub.

class easydel.infra.mixins.__init__.EasyGenerationMixin[source]#

Bases: object

base_model_prefix: str#
static compute_prefill_length(array, padding_id) Union[Array, ndarray, bool, number][source]#

Calculates the number of padding tokens at the beginning of each sequence.

This is useful for determining the actual starting position in a KV cache when dealing with left-padded inputs.

Parameters
  • array (chex.Array) – The input token ID array, typically of shape (batch_size, sequence_length).

  • padding_id (int) – The token ID used for padding.

Returns

An array of shape (batch_size,) containing the number of leading padding tokens for each sequence in the batch.

Return type

chex.Array

config: EasyDeLBaseConfig#
config_class: Type[EasyDeLBaseConfig]#
create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None) TransformerCacheMetaData[source]#

Creates the metadata required for initializing a standard (non-paged) KV Cache.

This method gathers parameters such as the layer count and head dimensions, and determines the appropriate padding token ID. It then instantiates and returns a TransformerCacheMetaData object suitable for a standard sequential KV cache.

Parameters
  • batch_size (int) – The batch size for which the cache is being configured.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • pad_token_id (int | None) – The ID of the padding token. If None, it attempts to find it from self.generation_config or self.config, defaulting to 0.

Returns

An initialized metadata object for a standard KV cache.

Return type

TransformerCacheMetaData

create_paged_metadata(page_size: int, batch_size: int, max_sequences: int, dtype: ~numpy.dtype = <class 'jax.numpy.bfloat16'>, hbm_utilization: float = 0.875) PagedAttentionCacheMetaData[source]#

Creates the static configuration metadata required for initializing a Paged KV Cache.

This method gathers the necessary parameters from the model's configuration (such as the number of layers, heads, and dimensions). It combines them with the provided arguments to instantiate and return a PagedAttentionCacheMetaData object. This metadata object defines the structure and allocation parameters for the paged cache.

Parameters
  • page_size (int) – The number of tokens to store per cache page.

  • batch_size (int) – The maximum number of sequences to handle concurrently during the decode phase.

  • max_sequences (int) – The maximum sequence length the cache should be configured to support.

  • dtype (jnp.dtype) – The data type to assume for cache memory calculation. Defaults to jnp.bfloat16.

  • hbm_utilization (float) – The target fraction of High Bandwidth Memory (HBM) to allocate for the KV cache pages. Defaults to 0.875 (87.5%).

Returns

An initialized metadata object containing the static configuration for the paged cache.

Return type

PagedAttentionCacheMetaData

generate(input_ids: Union[Array, ndarray, bool, number], generation_config: Optional[GenerationConfig] = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: Optional[LogitsProcessorList] = None, **kwargs)[source]#

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.

  • generation_config (~generation.GenerationConfig, optional) – The generation configuration to be used as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config will override them. If generation_config is not provided, the default is used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Unspecified parameters inherit the default values of GenerationConfig; consult its documentation to parameterize generation.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and the generation config. If a logits processor is passed that was already created from the arguments or the generation config, an error is thrown. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with decoder_.

Returns

A ~utils.ModelOutput instance.

init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None) TransformerCache[source]#

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata. It then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model's configuration, dtype, sharding, quantization settings, and the provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, influences the initial state. Defaults to None (usually 0).

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. (Note: this argument appears to be unused by the current implementation.)

  • pad_token_id (int | None) – The ID of the padding token. If None, it is inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache

init_pages(metadata: ~typing.Optional[~easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheMetaData] = None, page_size: ~typing.Optional[int] = None, batch_size: ~typing.Optional[int] = None, max_sequences: ~typing.Optional[int] = None, dtype: ~typing.Optional[~numpy.dtype] = <class 'jax.numpy.bfloat16'>, hbm_utilization: ~typing.Optional[float] = None) PagedAttentionCache[source]#

Initializes and returns the actual Paged Attention KV Cache tensors.

This method orchestrates the creation of the PagedAttentionCache. It either uses a pre-existing PagedAttentionCacheMetaData object passed via the metadata argument, or if metadata is None, it first creates the metadata by calling self.create_paged_metadata using the other provided arguments (page_size, batch_size, etc.).

Finally, it calls PagedAttentionCache.init_cache to allocate the necessary paged tensors (key_pages and value_pages for each layer). The allocation is based on the metadata, the model's mesh, dtype, partition manager, and quantization settings.

Parameters
  • metadata (tp.Optional[PagedAttentionCacheMetaData]) – An optional pre-configured metadata object. If provided, other arguments such as page_size and batch_size are ignored for metadata creation.

  • page_size (tp.Optional[int]) – Number of tokens per page. Required if metadata is None.

  • batch_size (tp.Optional[int]) – Max concurrent sequences for decode. Required if metadata is None.

  • max_sequences (tp.Optional[int]) – Max supported sequence length. Required if metadata is None.

  • dtype (tp.Optional[jnp.dtype]) – Data type for memory calculation. Required if metadata is None. Defaults to jnp.bfloat16.

  • hbm_utilization (tp.Optional[float]) – Target HBM usage. Required if metadata is None.

Returns

An initialized PagedAttentionCache object containing the allocated cache tensors (views) for all layers.

Return type

PagedAttentionCache

Raises

AssertionError – If metadata is None and any of the required arguments (page_size, batch_size, max_sequences, dtype, hbm_utilization) is also None.

prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings: int | None = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) Dict[str, Any][source]#

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.

Parameters
  • input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if it is not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, they are calculated using compute_prefill_length.

  • shardings (dict | None) – Optional sharding configuration passed to init_cache.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict

update_inputs_for_generation(model_outputs, model_kwargs) Dict[str, Any][source]#

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.

Parameters
  • model_outputs – The output object from the model's forward pass in the previous step (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary may be modified in place, or a new one may be returned.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict