easydel.infra.mixins.generation


class easydel.infra.mixins.generation.BeamSearchState(cur_len: Union[Array, ndarray, bool, number], running_sequences: Union[Array, ndarray, bool, number], running_scores: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], scores: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])

Bases: object

State for beam search generation.

Attributes
  • cur_len (chex.Array) – Current length of the generated sequence.

  • running_sequences (chex.Array) – Generated sequences being tracked in the beam.

  • running_scores (chex.Array) – Scores of the sequences being tracked in the beam.

  • sequences (chex.Array) – Best generated sequences.

  • scores (chex.Array) – Scores of the best generated sequences.

  • is_sent_finished (chex.Array) – Boolean array indicating if a sequence is finished.

  • model_kwargs (tp.Dict[str, chex.Array]) – Model specific keyword arguments.

cur_len: Union[Array, ndarray, bool, number]
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

is_sent_finished: Union[Array, ndarray, bool, number]
model_kwargs: Dict[str, Union[Array, ndarray, bool, number]]
replace(**kwargs)

Creates a new instance with specified fields replaced.

running_scores: Union[Array, ndarray, bool, number]
running_sequences: Union[Array, ndarray, bool, number]
scores: Union[Array, ndarray, bool, number]
sequences: Union[Array, ndarray, bool, number]
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
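To illustrate how running_scores could be advanced at each beam step, here is a simplified sketch in jax.numpy. The function name select_beams is hypothetical and this is not EasyDeL's beam search: it only adds next-token log-probabilities to each beam's cumulative score and keeps the num_beams best (beam, token) pairs. Moving finished beams into sequences/scores, which a full implementation needs, is deliberately left out.

```python
import jax
import jax.numpy as jnp


def select_beams(running_scores, log_probs, num_beams):
    """Hypothetical single beam-search step.

    running_scores: (batch, beams) cumulative log-probabilities per beam.
    log_probs: (batch, beams, vocab) next-token log-probabilities.
    """
    batch, beams, vocab = log_probs.shape
    # Cumulative score of every possible (beam, token) continuation.
    candidates = running_scores[:, :, None] + log_probs
    flat = candidates.reshape(batch, beams * vocab)
    # Keep the best num_beams continuations per batch row.
    top_scores, top_idx = jax.lax.top_k(flat, num_beams)
    beam_idx = top_idx // vocab  # which existing beam each winner extends
    token_idx = top_idx % vocab  # which token it appends
    return top_scores, beam_idx, token_idx


running = jnp.array([[0.0, -1.0]])
probs = jnp.log(jnp.array([[[0.6, 0.3, 0.1], [0.9, 0.05, 0.05]]]))
scores, beam_idx, token_idx = select_beams(running, probs, num_beams=2)
```

Flattening beams and vocabulary into one axis lets a single top_k compare continuations across beams, so a weaker beam can be dropped entirely when another beam has two strong continuations.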

class easydel.infra.mixins.generation.EasyGenerationMixin

Bases: object

base_model_prefix: str
static compute_prefill_length(array, padding_id) → Union[Array, ndarray, bool, number]

Calculates the number of padding tokens at the beginning of each sequence (its leading, or left, padding).

This is useful for determining the actual starting position in a KV cache when dealing with left-padded inputs.

Parameters
  • array (chex.Array) – The input token ID array, typically of shape (batch_size, sequence_length).

  • padding_id (int) – The token ID used for padding.

Returns

An array of shape (batch_size,) containing the number of leading padding tokens for each sequence in the batch.

Return type

chex.Array
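A minimal sketch of counting leading padding per row with jax.numpy. The helper count_leading_pads is hypothetical and is not claimed to be EasyDeL's implementation of compute_prefill_length; it only shows one way to compute the documented result.

```python
import jax.numpy as jnp


def count_leading_pads(input_ids: jnp.ndarray, padding_id: int) -> jnp.ndarray:
    """Hypothetical helper: number of left-padding tokens in each row."""
    is_pad = (input_ids == padding_id).astype(jnp.int32)
    # The running product stays 1 only while every token seen so far is
    # padding, so its sum is the length of the padding run at the start.
    leading = jnp.cumprod(is_pad, axis=-1)
    return leading.sum(axis=-1)


ids = jnp.array([[0, 0, 5, 6], [7, 8, 9, 0]])
starts = count_leading_pads(ids, padding_id=0)  # 2 for row 0, 0 for row 1
```

Unlike a plain count of padding tokens, the cumulative product ignores trailing padding (as in the second row), which matters when only left padding shifts the cache start position.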

config: EasyDeLBaseConfig
config_class: Type[EasyDeLBaseConfig]
create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None) → TransformerCacheMetaData

Creates the metadata required for initializing a standard (non-paged) KV Cache.

It gathers parameters such as the layer count and head dimensions, determines the appropriate padding token ID, and uses them to instantiate and return a TransformerCacheMetaData object suitable for a standard sequential KV cache.

Parameters
  • batch_size (int) – The batch size for which the cache is being configured.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • pad_token_id (int | None) – The ID of the padding token. If None, the method tries to read it from self.generation_config or self.config, defaulting to 0.

Returns

An initialized metadata object for a standard KV cache.

Return type

TransformerCacheMetaData
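The padding-token fallback described for pad_token_id can be sketched as a small resolution chain. This is an illustration only: resolve_pad_token_id is a hypothetical name, and checking generation_config before config is an assumption about the order, which the description above does not pin down.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_pad_token_id(
    pad_token_id: Optional[int],
    generation_config: Any = None,
    config: Any = None,
) -> int:
    """Hypothetical fallback chain for choosing a padding token ID."""
    if pad_token_id is not None:
        return pad_token_id
    # Assumed order: generation config first, then model config.
    # Missing attributes and explicit None values are both skipped.
    for source in (generation_config, config):
        value = getattr(source, "pad_token_id", None)
        if value is not None:
            return value
    return 0  # documented final default


gen_cfg = SimpleNamespace(pad_token_id=None)
model_cfg = SimpleNamespace(pad_token_id=7)
resolved = resolve_pad_token_id(None, gen_cfg, model_cfg)
```

An explicit argument always wins, so callers can override whatever the configs say.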

create_paged_metadata(page_size: int, batch_size: int, max_sequences: int, dtype: numpy.dtype = jax.numpy.bfloat16, hbm_utilization: float = 0.875) → PagedAttentionCacheMetaData

Creates the static configuration metadata required for initializing a Paged KV Cache.

This method gathers necessary parameters from the model’s configuration (like number of layers, heads, dimensions) and combines them with the provided arguments to instantiate and return a PagedAttentionCacheMetaData object. This metadata object defines the structure and allocation parameters for the paged cache.

Parameters
  • page_size (int) – The number of tokens to store per cache page.

  • batch_size (int) – The maximum number of sequences to handle concurrently during the decode phase.

  • max_sequences (int) – The maximum sequence length the cache should be configured to support.

  • dtype (jnp.dtype) – The data type to assume for cache memory calculation. Defaults to jnp.bfloat16.

  • hbm_utilization (float) – The target fraction of High Bandwidth Memory (HBM) to allocate for the KV cache pages. Defaults to 0.875 (87.5%).

Returns

An initialized metadata object containing the static configuration for the paged cache.

Return type

PagedAttentionCacheMetaData
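How page_size, max_sequences, dtype, and hbm_utilization interact can be seen in a back-of-the-envelope calculation. The sketch below is generic and hypothetical: it assumes each page holds keys and values for page_size tokens across all layers, with no quantization, and EasyDeL's actual accounting in PagedAttentionCacheMetaData may differ.

```python
import math


def paged_cache_budget(
    hbm_bytes: int,
    hbm_utilization: float,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    max_length: int,
    itemsize: int = 2,  # bytes per element, e.g. 2 for bfloat16
) -> dict:
    """Hypothetical estimate of how many KV pages fit in a memory budget."""
    # Assumed layout: one page stores K and V for page_size tokens in every layer.
    bytes_per_page = 2 * num_layers * page_size * num_kv_heads * head_dim * itemsize
    # Only the hbm_utilization fraction of device memory is given to pages.
    num_pages = int(hbm_bytes * hbm_utilization) // bytes_per_page
    # A sequence of max_length tokens needs this many pages (rounded up).
    pages_per_sequence = math.ceil(max_length / page_size)
    return {
        "bytes_per_page": bytes_per_page,
        "num_pages": num_pages,
        "pages_per_sequence": pages_per_sequence,
    }


# Example: 16 GiB device, 32 layers, 8 KV heads of size 128, 128-token pages.
budget = paged_cache_budget(16 * 1024**3, 0.875, 32, 8, 128, 128, 4096)
```

With these illustrative numbers each page takes 16 MiB, 896 pages fit in 87.5% of 16 GiB, and a 4096-token sequence needs 32 pages.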

generate(input_ids: Union[Array, ndarray, bool, number], generation_config: Optional[GenerationConfig] = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: Optional[LogitsProcessorList] = None, **kwargs)

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.

  • generation_config (GenerationConfig, optional) – The generation configuration to be used as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config override them. If generation_config is not provided, the default is used, with the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Unspecified parameters inherit GenerationConfig’s default values, so check its documentation when parameterizing generation.

  • prng_key (chex.Array, optional) – PRNG key for sampling.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from the arguments and the generation config. Passing a processor of a kind that is already created from the arguments or the generation config raises an error. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that are forwarded to the model’s forward function. For encoder-decoder models, encoder-specific kwargs should not be prefixed, while decoder-specific kwargs should be prefixed with decoder_.

Returns

A utils.ModelOutput.
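The override rule for **kwargs (keys matching generation-config attributes replace those attributes, the rest are forwarded to the model) can be sketched with a frozen dataclass. GenerationConfigSketch and split_generate_kwargs are hypothetical stand-ins, not EasyDeL's GenerationConfig or its internal kwargs handling.

```python
import dataclasses
from typing import Any, Dict, Tuple


@dataclasses.dataclass(frozen=True)
class GenerationConfigSketch:
    """Hypothetical stand-in with a few typical generation fields."""
    max_new_tokens: int = 20
    do_sample: bool = False
    temperature: float = 1.0


def split_generate_kwargs(
    base: GenerationConfigSketch, **kwargs: Any
) -> Tuple[GenerationConfigSketch, Dict[str, Any]]:
    field_names = {f.name for f in dataclasses.fields(base)}
    # kwargs naming a config field override it; unspecified fields keep defaults.
    overrides = {k: v for k, v in kwargs.items() if k in field_names}
    # Everything else is treated as a model-specific forward kwarg.
    model_kwargs = {k: v for k, v in kwargs.items() if k not in field_names}
    return dataclasses.replace(base, **overrides), model_kwargs


config, model_kwargs = split_generate_kwargs(
    GenerationConfigSketch(), max_new_tokens=64, do_sample=True, attention_mask="mask"
)
```

Using dataclasses.replace leaves the base configuration untouched, so a default config can be reused safely across calls.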

init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None) → TransformerCache

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, they set the initial cache state. Defaults to None, which usually corresponds to a starting position of 0.

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused in the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it is inferred as in create_cache_metadata.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache
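To estimate how much memory a standard cache for a given batch_size and max_length will take, the usual KV-cache arithmetic applies. This sketch assumes one key and one value tensor per layer, each of shape (batch_size, max_length, num_kv_heads, head_dim), and no quantization; kv_cache_bytes is a hypothetical helper, and EasyDeL's real layout or quantization settings may change the result.

```python
def kv_cache_bytes(
    batch_size: int,
    max_length: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    itemsize: int = 2,  # bytes per element, e.g. 2 for bfloat16
) -> int:
    """Hypothetical size estimate for an unquantized, non-paged KV cache."""
    # Size of a single key (or value) tensor for one layer.
    per_tensor = batch_size * max_length * num_kv_heads * head_dim * itemsize
    # Two tensors (key and value) per layer.
    return 2 * num_layers * per_tensor


# 1 sequence of 4096 tokens, 32 layers, 8 KV heads of size 128, bfloat16.
size = kv_cache_bytes(1, 4096, 32, 8, 128)
```

Because the estimate is linear in batch_size and max_length, doubling either doubles the cache, which is why a standard cache is sized for the maximum length up front.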

init_pages(metadata: Optional[easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheMetaData] = None, page_size: Optional[int] = None, batch_size: Optional[int] = None, max_sequences: Optional[int] = None, dtype: Optional[numpy.dtype] = jax.numpy.bfloat16, hbm_utilization: Optional[float] = None) → PagedAttentionCache

Initializes and returns the actual Paged Attention KV Cache tensors.

This method orchestrates the creation of the PagedAttentionCache. It either uses a pre-existing PagedAttentionCacheMetaData object passed via the metadata argument, or if metadata is None, it first creates the metadata by calling self.create_paged_metadata using the other provided arguments (page_size, batch_size, etc.).

Finally, it calls PagedAttentionCache.init_cache to allocate the necessary paged tensors (key_pages, value_pages for each layer) based on the metadata, model’s mesh, dtype, partition manager, and quantization settings.

Parameters
  • metadata (tp.Optional[PagedAttentionCacheMetaData]) – An optional pre-configured metadata object. If provided, the other arguments (page_size, batch_size, etc.) are ignored for metadata creation.

  • page_size (tp.Optional[int]) – Number of tokens per page. Required if metadata is None.

  • batch_size (tp.Optional[int]) – Max concurrent sequences for decode. Required if metadata is None.

  • max_sequences (tp.Optional[int]) – Max supported sequence length. Required if metadata is None.

  • dtype (tp.Optional[jnp.dtype]) – Data type for memory calculation. Required if metadata is None. Defaults to jnp.bfloat16.

  • hbm_utilization (tp.Optional[float]) – Target HBM usage. Required if metadata is None.

Returns

An initialized PagedAttentionCache object containing the allocated cache tensors (views) for all layers.

Return type

PagedAttentionCache

Raises

AssertionError – If metadata is None and any of the required arguments (page_size, batch_size, max_sequences, dtype, hbm_utilization) are also None.
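The "metadata or all arguments" contract of init_pages can be sketched as follows. PagedMetaSketch and resolve_paged_metadata are hypothetical stand-ins; the real method builds a PagedAttentionCacheMetaData through create_paged_metadata and then allocates tensors, which this sketch does not attempt.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PagedMetaSketch:
    """Hypothetical stand-in for PagedAttentionCacheMetaData."""
    page_size: int
    batch_size: int
    max_sequences: int
    dtype: str  # a dtype name string keeps the sketch dependency-free
    hbm_utilization: float


def resolve_paged_metadata(
    metadata: Optional[PagedMetaSketch] = None,
    page_size: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_sequences: Optional[int] = None,
    dtype: Optional[str] = "bfloat16",
    hbm_utilization: Optional[float] = None,
) -> PagedMetaSketch:
    if metadata is not None:
        # A pre-built metadata object wins; the other arguments are ignored.
        return metadata
    required = {
        "page_size": page_size,
        "batch_size": batch_size,
        "max_sequences": max_sequences,
        "dtype": dtype,
        "hbm_utilization": hbm_utilization,
    }
    for name, value in required.items():
        # Mirrors the documented AssertionError when an argument is missing.
        assert value is not None, f"{name} is required when metadata is None"
    return PagedMetaSketch(page_size, batch_size, max_sequences, dtype, hbm_utilization)


meta = resolve_paged_metadata(page_size=16, batch_size=8, max_sequences=2048, hbm_utilization=0.875)
```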

prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings: int | None = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) → Dict[str, Any]

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.

Parameters
  • input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.

  • shardings (dict | None) – Optional sharding configuration passed to init_cache.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict
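A sketch of how initial position_ids and an extended attention mask can be derived from a left-padded prompt mask, modeled on a pattern common in Flax generation code. It is an assumption that EasyDeL computes something equivalent, and clamping padding positions to 0 is a choice made here; initial_positions_and_mask is a hypothetical name.

```python
import jax.numpy as jnp
from jax import lax


def initial_positions_and_mask(attention_mask: jnp.ndarray, max_length: int):
    """Hypothetical helper: prompt position IDs plus a cache-length mask."""
    attention_mask = attention_mask.astype(jnp.int32)
    batch_size, _ = attention_mask.shape
    # Positions count attended tokens only; left-padding slots are clamped to 0.
    position_ids = jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0)
    # Start with an all-ones mask over the whole cache length, then copy the
    # prompt's mask into its first seq_length columns.
    extended = jnp.ones((batch_size, max_length), dtype=jnp.int32)
    extended = lax.dynamic_update_slice(extended, attention_mask, (0, 0))
    return position_ids, extended


mask = jnp.array([[0, 1, 1], [1, 1, 1]])
position_ids, extended_mask = initial_positions_and_mask(mask, max_length=5)
```

Deriving positions from the mask rather than from a plain range means a left-padded row still starts its real tokens at position 0.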

update_inputs_for_generation(model_outputs, model_kwargs) → Dict[str, Any]

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.

Parameters
  • model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. It may be modified in place or replaced by a new dictionary, so always use the returned value.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict
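A minimal sketch of the per-step update: take the new cache from the model outputs and advance the position IDs by one. Keeping only the last position of each row before incrementing follows a common Flax decoding pattern and is an assumption here; update_kwargs_sketch is hypothetical, and SimpleNamespace stands in for the real model output object.

```python
from types import SimpleNamespace
from typing import Any, Dict

import jax.numpy as jnp


def update_kwargs_sketch(model_outputs: Any, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical per-step update of generation kwargs."""
    updated = dict(model_kwargs)  # shallow copy; the caller's dict is left untouched
    updated["past_key_values"] = model_outputs.past_key_values
    # The next call feeds a single token, so keep the last position and add one.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


outputs = SimpleNamespace(past_key_values="new-cache")
kwargs = {"past_key_values": "old-cache", "position_ids": jnp.array([[0, 1, 2], [0, 0, 1]])}
new_kwargs = update_kwargs_sketch(outputs, kwargs)
```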

class easydel.infra.mixins.generation.GreedyState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])

Bases: object

State for greedy search generation.

Attributes
  • cur_len (chex.Array) – Current length of the generated sequence.

  • sequences (chex.Array) – Generated sequences so far.

  • running_token (chex.Array) – Current token being processed.

  • is_sent_finished (chex.Array) – Boolean array indicating if a sequence is finished.

  • model_kwargs (tp.Dict[str, chex.Array]) – Model specific keyword arguments.

cur_len: Union[Array, ndarray, bool, number]
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

is_sent_finished: Union[Array, ndarray, bool, number]
model_kwargs: Dict[str, Union[Array, ndarray, bool, number]]
replace(**kwargs)

Creates a new instance with specified fields replaced.

running_token: Union[Array, ndarray, bool, number]
sequences: Union[Array, ndarray, bool, number]
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
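To show how the GreedyState fields evolve, here is a sketch of one greedy decoding step over a state with the same fields (minus model_kwargs). GreedyStateSketch is a hypothetical NamedTuple, which JAX already treats as a PyTree, and its _replace plays the role of the documented replace(). The logits are passed in directly instead of coming from a model call.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax import lax


class GreedyStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring GreedyState's array fields."""
    cur_len: jnp.ndarray
    sequences: jnp.ndarray
    running_token: jnp.ndarray
    is_sent_finished: jnp.ndarray


def greedy_step(state, logits, eos_token_id, pad_token_id):
    # logits: (batch, vocab) scores for the next position.
    next_token = jnp.argmax(logits, axis=-1).astype(state.sequences.dtype)
    # Rows that have already finished keep emitting padding.
    next_token = jnp.where(state.is_sent_finished, pad_token_id, next_token)
    finished = state.is_sent_finished | (next_token == eos_token_id)
    # Write the new token into column cur_len of every row.
    sequences = lax.dynamic_update_slice(state.sequences, next_token[:, None], (0, state.cur_len))
    return state._replace(
        cur_len=state.cur_len + 1,
        sequences=sequences,
        running_token=next_token[:, None],
        is_sent_finished=finished,
    )


state = GreedyStateSketch(
    cur_len=jnp.array(1),
    sequences=jnp.array([[5, 0, 0], [6, 0, 0]], dtype=jnp.int32),
    running_token=jnp.array([[5], [6]], dtype=jnp.int32),
    is_sent_finished=jnp.array([False, False]),
)
logits = jnp.array([[0.1, 2.0, 0.3], [0.0, 0.1, 3.0]])
state1 = greedy_step(state, logits, eos_token_id=2, pad_token_id=0)
state2 = greedy_step(state1, logits, eos_token_id=2, pad_token_id=0)
```

Because the state is a PyTree of fixed-shape arrays, a step like this can serve as the body of lax.while_loop, which is what makes traced generation possible.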

class easydel.infra.mixins.generation.SampleState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], prng_key: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])

Bases: object

State for sampling generation.

Attributes
  • cur_len (chex.Array) – Current length of the generated sequence.

  • sequences (chex.Array) – Generated sequences so far.

  • running_token (chex.Array) – Current token being processed.

  • is_sent_finished (chex.Array) – Boolean array indicating if a sequence is finished.

  • prng_key (chex.Array) – PRNG key for sampling.

  • model_kwargs (tp.Dict[str, chex.Array]) – Model specific keyword arguments.

cur_len: Union[Array, ndarray, bool, number]
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

is_sent_finished: Union[Array, ndarray, bool, number]
model_kwargs: Dict[str, Union[Array, ndarray, bool, number]]
prng_key: Union[Array, ndarray, bool, number]
replace(**kwargs)

Creates a new instance with specified fields replaced.

running_token: Union[Array, ndarray, bool, number]
sequences: Union[Array, ndarray, bool, number]
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
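SampleState carries a prng_key because JAX random numbers are explicit: the key must be threaded through the loop state and split at every step so that no key is reused. The sketch below shows one sampling step with temperature scaling; sample_step is a hypothetical helper, and temperature scaling is a common technique included here for illustration rather than a description of EasyDeL's logits processing.

```python
import jax
import jax.numpy as jnp


def sample_step(prng_key, logits, temperature=1.0):
    """Hypothetical sampling step: returns the key to carry and sampled tokens."""
    # Split so the carried key is never the one used to draw samples.
    prng_key, sample_key = jax.random.split(prng_key)
    # Temperature below 1 sharpens the distribution, above 1 flattens it.
    next_token = jax.random.categorical(sample_key, logits / temperature, axis=-1)
    return prng_key, next_token


key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]))
key, tokens = sample_step(key, logits, temperature=0.8)
```

Running the step twice from the same starting key yields the same tokens, which keeps sampled generation reproducible under jit.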