easydel.infra.mixins.generation
- class easydel.infra.mixins.generation.BeamSearchState(cur_len: Union[Array, ndarray, bool, number], running_sequences: Union[Array, ndarray, bool, number], running_scores: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], scores: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])
Bases: object
State for beam search generation.
- cur_len
Current length of the generated sequence.
- Type
chex.Array
- running_sequences
Sequences currently being tracked in the beam.
- Type
chex.Array
- running_scores
Scores of the sequences currently being tracked in the beam.
- Type
chex.Array
- sequences
Best generated sequences.
- Type
chex.Array
- scores
Scores of the best generated sequences.
- Type
chex.Array
- is_sent_finished
Boolean array indicating whether each sequence has finished.
- Type
chex.Array
- model_kwargs
Model-specific keyword arguments.
- Type
tp.Dict[str, chex.Array]
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
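The beam fields above are updated once per decoding step. As a rough illustration, here is a minimal NumPy sketch of a generic beam-search expansion step that produces new running_scores and reordered running_sequences. beam_step and reorder_and_append are hypothetical helpers, not EasyDeL functions, and the sketch leaves out length penalties, EOS handling, and the bookkeeping for sequences, scores, and is_sent_finished.

```python
import numpy as np


def beam_step(running_scores, log_probs, num_beams):
    """One generic beam-search expansion step (illustrative, not EasyDeL's code).

    running_scores: (batch, beams) cumulative log-probabilities per beam.
    log_probs:      (batch, beams, vocab) next-token log-probabilities.
    Returns the new cumulative scores, the parent beam index and the chosen
    token for each of the top `num_beams` candidates, best first.
    """
    batch, beams, vocab = log_probs.shape
    # Score every (beam, token) continuation, then flatten beams and vocab together.
    candidate = running_scores[:, :, None] + log_probs
    flat = candidate.reshape(batch, beams * vocab)
    # Indices of the num_beams highest-scoring continuations, best first.
    top_idx = np.argsort(-flat, axis=-1)[:, :num_beams]
    top_scores = np.take_along_axis(flat, top_idx, axis=-1)
    parent_beam = top_idx // vocab  # which running sequence each candidate extends
    next_token = top_idx % vocab    # which token it appends
    return top_scores, parent_beam, next_token


def reorder_and_append(running_sequences, parent_beam, next_token, cur_len):
    """Gather each surviving beam's history and write its new token at cur_len."""
    batch = running_sequences.shape[0]
    # Row b picks beams parent_beam[b, :] from batch item b; the result is a copy.
    seqs = running_sequences[np.arange(batch)[:, None], parent_beam]
    seqs[:, :, cur_len] = next_token
    return seqs
```

Flattening the (beam, token) grid lets a single top-k select the best continuations across all beams at once, which is why the parent beam and the token are recovered with integer division and modulo.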
- class easydel.infra.mixins.generation.EasyGenerationMixin
Bases: object
- base_model_prefix: str
- static compute_prefill_length(array, padding_id) → Union[Array, ndarray, bool, number]
Calculates the number of padding tokens at the beginning of each sequence.
This is useful for determining the actual starting position in a KV cache when dealing with left-padded inputs.
- Parameters
array (chex.Array) – The input token ID array, typically of shape (batch_size, sequence_length).
padding_id (int) – The token ID used for padding.
- Returns
An array of shape (batch_size,) containing the number of leading padding tokens for each sequence in the batch.
- Return type
chex.Array
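For left-padded prompts, the leading-pad count is the offset at which real tokens begin. A NumPy sketch of that computation follows; count_leading_pads is a hypothetical name, the EasyDeL method works on JAX arrays and may be implemented differently, and jax.numpy provides the same cumprod and sum functions.

```python
import numpy as np


def count_leading_pads(array, padding_id):
    """Number of padding tokens before the first real token in each row."""
    is_pad = np.asarray(array) == padding_id
    # cumprod stays 1 only while every token so far is a pad, so summing it
    # counts the leading run of pads and ignores any trailing (right) padding.
    return np.cumprod(is_pad, axis=-1).sum(axis=-1)
```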
- config: EasyDeLBaseConfig
- config_class: Type[EasyDeLBaseConfig]
- create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None) → TransformerCacheMetaData
Creates the metadata required for initializing a standard (non-paged) KV cache.
This method gathers parameters such as the layer count and head dimensions, determines the appropriate padding token ID, and returns a TransformerCacheMetaData object suitable for a standard sequential KV cache.
- Parameters
batch_size (int) – The batch size for which the cache is being configured.
max_length (int) – The maximum sequence length the cache needs to support.
pad_token_id (int | None) – The ID of the padding token. If None, it is looked up in self.generation_config or self.config, defaulting to 0.
- Returns
An initialized metadata object for a standard KV cache.
- Return type
TransformerCacheMetaData
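The pad_token_id fallback described above can be sketched as a small resolver. resolve_pad_token_id is a hypothetical helper, and checking generation_config before config is an assumption based on the order in which they are listed.

```python
from types import SimpleNamespace


def resolve_pad_token_id(pad_token_id=None, generation_config=None, config=None):
    """Pick a pad token id: explicit argument, then generation_config, then config, then 0.

    Hypothetical helper mirroring the documented fallback, not EasyDeL's code.
    """
    if pad_token_id is not None:
        return pad_token_id
    for source in (generation_config, config):
        # getattr with a default tolerates a missing config or a missing attribute.
        value = getattr(source, "pad_token_id", None)
        if value is not None:
            return value
    return 0


# Example: generation_config leaves it unset, so the model config's value wins.
print(resolve_pad_token_id(None, SimpleNamespace(pad_token_id=None), SimpleNamespace(pad_token_id=2)))  # prints 2
```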
- create_paged_metadata(page_size: int, batch_size: int, max_sequences: int, dtype: numpy.dtype = jax.numpy.bfloat16, hbm_utilization: float = 0.875) → PagedAttentionCacheMetaData
Creates the static configuration metadata required for initializing a paged KV cache.
This method gathers the necessary parameters from the model’s configuration (such as the number of layers, heads, and head dimensions), combines them with the provided arguments, and returns a PagedAttentionCacheMetaData object. This metadata object defines the structure and allocation parameters for the paged cache.
- Parameters
page_size (int) – The number of tokens stored per cache page.
batch_size (int) – The maximum number of sequences to handle concurrently during the decode phase.
max_sequences (int) – The maximum sequence length the cache should be configured to support.
dtype (jnp.dtype) – The data type assumed for cache memory calculation. Defaults to jnp.bfloat16.
hbm_utilization (float) – The target fraction of High Bandwidth Memory (HBM) to allocate for the KV cache pages. Defaults to 0.875 (87.5%).
- Returns
An initialized metadata object containing the static configuration for the paged cache.
- Return type
PagedAttentionCacheMetaData
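To get a feel for how page_size, dtype, and hbm_utilization interact, the following back-of-the-envelope sketch estimates how many pages fit in a memory budget. estimate_num_pages is hypothetical; it assumes every page holds keys and values for page_size tokens in every layer and ignores memory taken by weights and activations, so EasyDeL's actual allocation may differ.

```python
def estimate_num_pages(hbm_bytes, hbm_utilization, num_layers, num_kv_heads,
                       head_dim, page_size, itemsize=2):
    """Rough count of KV-cache pages that fit in a memory budget (hypothetical helper).

    itemsize=2 corresponds to bfloat16. Each page stores keys and values
    (the factor of 2) for page_size tokens across all layers.
    """
    bytes_per_page = 2 * num_layers * page_size * num_kv_heads * head_dim * itemsize
    budget = int(hbm_bytes * hbm_utilization)
    return budget // bytes_per_page
```

For example, with 16 GiB of HBM at the default 0.875 utilization, 32 layers, 8 KV heads, a head dimension of 128, and 128-token pages in bfloat16, each page takes 16 MiB and 896 pages fit, about 114,688 cached tokens in total.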
- generate(input_ids: Union[Array, ndarray, bool, number], generation_config: Optional[GenerationConfig] = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: Optional[LogitsProcessorList] = None, **kwargs)
Generates sequences of token IDs for models with a language modeling head.
- Parameters
input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.
generation_config (GenerationConfig, optional) – The generation configuration to be used as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config override them. If generation_config is not provided, the default is used, with the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Note that unspecified parameters inherit the default values of GenerationConfig, whose documentation should be checked to parameterize generation.
trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and leads to a considerably slower runtime.
logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and the generation config. If a logits processor is passed that has already been created from the arguments or the generation config, an error is raised. This feature is intended for advanced users.
kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that are forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with decoder_.
- Returns
ModelOutput.
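The override rule for **kwargs can be pictured with a small sketch: keyword arguments whose names match generation-config attributes replace those attributes, and everything else is forwarded to the model. ToyGenerationConfig and split_generate_kwargs are hypothetical stand-ins, not EasyDeL's actual code.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyGenerationConfig:
    # Hypothetical stand-in with a few common fields; a real config has many more.
    max_new_tokens: int = 20
    do_sample: bool = False
    temperature: float = 1.0


def split_generate_kwargs(base, **kwargs):
    """Route kwargs: config fields override `base`, the rest go to the model."""
    field_names = {f.name for f in dataclasses.fields(base)}
    overrides = {k: v for k, v in kwargs.items() if k in field_names}
    model_kwargs = {k: v for k, v in kwargs.items() if k not in field_names}
    # dataclasses.replace returns a new config; `base` itself is left untouched.
    return dataclasses.replace(base, **overrides), model_kwargs
```

Returning a new config rather than mutating the default keeps repeated generate calls from leaking one call's overrides into the next.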
- init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None) → TransformerCache
Initializes and returns a standard (non-paged) key-value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and the provided batch size and maximum length.
- Parameters
batch_size (int) – The batch size for the cache.
max_length (int) – The maximum sequence length the cache needs to support.
starts (int | None) – Optional starting positions for the cache sequences; if provided, they set the initial cache state. Defaults to None (usually treated as 0).
shardings (dict | None) – Optional dictionary specifying sharding configurations. (Note: this argument appears to be unused in the current implementation.)
pad_token_id (int | None) – The ID of the padding token. If None, it is inferred.
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
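To make the shapes concrete, here is a NumPy sketch of what allocating a standard (non-paged) cache amounts to. The (batch, max_length, kv_heads, head_dim) layout, the per-layer dictionary structure, and init_toy_kv_cache are assumptions for illustration; TransformerCache.init_cache additionally handles sharding and quantization, which are omitted here.

```python
import numpy as np


def init_toy_kv_cache(num_layers, batch_size, max_length, num_kv_heads, head_dim,
                      dtype=np.float32, starts=None):
    """Allocate zero-filled per-layer key/value buffers plus a write index.

    Hypothetical layout, not EasyDeL's TransformerCache. `starts` may be an int
    or a per-row array; it defaults to 0, matching the documented default.
    """
    shape = (batch_size, max_length, num_kv_heads, head_dim)
    return [
        {
            "key": np.zeros(shape, dtype),
            "value": np.zeros(shape, dtype),
            # np.full broadcasts a scalar or per-row starts into one index per sequence.
            "index": np.full((batch_size,), 0 if starts is None else starts, dtype=np.int32),
        }
        for _ in range(num_layers)
    ]
```

Each key or value buffer holds batch_size × max_length × num_kv_heads × head_dim elements, so the cache grows linearly with both batch size and maximum length, which is the pressure paged caches are designed to relieve.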
- init_pages(metadata: Optional[easydel.layers.caching.paged_attention.paged_attention_cache.PagedAttentionCacheMetaData] = None, page_size: Optional[int] = None, batch_size: Optional[int] = None, max_sequences: Optional[int] = None, dtype: Optional[numpy.dtype] = jax.numpy.bfloat16, hbm_utilization: Optional[float] = None) → PagedAttentionCache
Initializes and returns the actual paged attention KV cache tensors.
This method orchestrates the creation of the PagedAttentionCache. It either uses a pre-existing PagedAttentionCacheMetaData object passed via the metadata argument or, if metadata is None, first creates the metadata by calling self.create_paged_metadata with the other provided arguments (page_size, batch_size, etc.).
Finally, it calls PagedAttentionCache.init_cache to allocate the necessary paged tensors (key_pages and value_pages for each layer) based on the metadata, the model’s mesh, dtype, partition manager, and quantization settings.
- Parameters
metadata (tp.Optional[PagedAttentionCacheMetaData]) – An optional pre-configured metadata object. If provided, the other arguments (page_size, batch_size, etc.) are ignored for metadata creation.
page_size (tp.Optional[int]) – Number of tokens per page. Required if metadata is None.
batch_size (tp.Optional[int]) – Maximum number of concurrent sequences during decode. Required if metadata is None.
max_sequences (tp.Optional[int]) – Maximum supported sequence length. Required if metadata is None.
dtype (tp.Optional[jnp.dtype]) – Data type for memory calculation. Required if metadata is None. Defaults to jnp.bfloat16.
hbm_utilization (tp.Optional[float]) – Target HBM usage fraction. Required if metadata is None.
- Returns
An initialized PagedAttentionCache object containing the allocated cache tensors (views) for all layers.
- Return type
PagedAttentionCache
- Raises
AssertionError – If metadata is None and any of the required arguments (page_size, batch_size, max_sequences, dtype, hbm_utilization) are also None.
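The "metadata or all of the arguments" contract can be sketched as follows. ToyPagedMetadata and resolve_paged_metadata are hypothetical stand-ins that only mirror the documented behaviour (use metadata when given, otherwise assert that every required argument is present); they are not EasyDeL classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyPagedMetadata:
    # Hypothetical stand-in for PagedAttentionCacheMetaData (a few fields only).
    page_size: int
    batch_size: int
    max_sequences: int
    dtype: str
    hbm_utilization: float


def resolve_paged_metadata(metadata: Optional[ToyPagedMetadata] = None,
                           page_size=None, batch_size=None, max_sequences=None,
                           dtype="bfloat16", hbm_utilization=None) -> ToyPagedMetadata:
    """Use `metadata` if given; otherwise require every field and build one."""
    if metadata is not None:
        return metadata  # the individual arguments are ignored in this case
    args = dict(page_size=page_size, batch_size=batch_size, max_sequences=max_sequences,
                dtype=dtype, hbm_utilization=hbm_utilization)
    missing = [name for name, value in args.items() if value is None]
    assert not missing, f"metadata is None and these arguments are missing: {missing}"
    return ToyPagedMetadata(**args)
```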
- prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings: int | None = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) → Dict[str, Any]
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the key-value cache (past_key_values) using init_cache, calculates the initial position_ids from the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
- Returns
A dictionary containing the prepared inputs, typically including:
"past_key_values": The initialized KV cache.
"attention_mask": The extended attention mask for generation.
"position_ids": The calculated initial position IDs.
"token_type_ids": (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
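The position-id and mask preparation described above can be illustrated with a NumPy sketch. toy_prefill_inputs is hypothetical and follows a common Flax-style convention (cumulative sum of the mask minus one, and a max_length-wide mask whose prompt part copies the input mask); it is not EasyDeL's implementation and skips cache creation and sharding.

```python
import numpy as np


def toy_prefill_inputs(input_ids, max_length, attention_mask=None):
    """Initial position ids and a max_length-wide attention mask for a prompt."""
    input_ids = np.asarray(input_ids)
    batch, seq_len = input_ids.shape
    if attention_mask is None:
        # No mask: every token is real, which yields the contiguous range 0..seq_len-1.
        attention_mask = np.ones((batch, seq_len), dtype=np.int32)
    attention_mask = np.asarray(attention_mask, dtype=np.int32)
    # Count real tokens seen so far; left-padding slots would get -1, so clamp them to 0.
    position_ids = np.maximum(attention_mask.cumsum(axis=-1) - 1, 0)
    # Slots for tokens not generated yet start as 1; the prompt part copies the input mask.
    extended_mask = np.ones((batch, max_length), dtype=np.int32)
    extended_mask[:, :seq_len] = attention_mask
    return position_ids, extended_mask
```

Deriving positions from the mask rather than from a plain range means a left-padded row starts counting at its first real token, so both rows in a batch see position 0 on their first real token.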
- update_inputs_for_generation(model_outputs, model_kwargs) → Dict[str, Any]
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary may be modified in place, or a new one may be returned.
- Returns
The updated model_kwargs dictionary, ready for the next generation step.
- Return type
dict
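The per-step update amounts to carrying the cache forward and advancing the position. A minimal sketch, assuming position_ids has shape (batch, n) and that only the last position is kept for the next single-token step (a common convention, not confirmed for EasyDeL); toy_update_inputs is hypothetical and returns a copy instead of mutating its input.

```python
from types import SimpleNamespace

import numpy as np


def toy_update_inputs(model_outputs, model_kwargs):
    """Carry the cache forward and advance positions by one (sketch, not EasyDeL's code)."""
    updated = dict(model_kwargs)  # shallow copy, so the caller's dict is not mutated
    updated["past_key_values"] = model_outputs.past_key_values
    # Only the newest position matters for the next single-token step.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


# Example: after a 3-token prompt, the next token sits at position 3.
outputs = SimpleNamespace(past_key_values="cache-after-step")  # placeholder object
kwargs = {"position_ids": np.array([[0, 1, 2]]), "attention_mask": np.ones((1, 8))}
print(toy_update_inputs(outputs, kwargs)["position_ids"])  # prints [[3]]
```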
- class easydel.infra.mixins.generation.GreedyState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])
Bases: object
State for greedy search generation.
- cur_len
Current length of the generated sequence.
- Type
chex.Array
- sequences
Sequences generated so far.
- Type
chex.Array
- running_token
Current token being processed.
- Type
chex.Array
- is_sent_finished
Boolean array indicating whether each sequence has finished.
- Type
chex.Array
- model_kwargs
Model-specific keyword arguments.
- Type
tp.Dict[str, chex.Array]
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
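GreedyState carries what one greedy decoding step needs. The sketch below is a NumPy illustration of such a step using a hypothetical ToyGreedyState (model_kwargs omitted) and greedy_step, with replace done through dataclasses.replace. In a traced JAX loop the same update would typically sit in the body of jax.lax.while_loop over fixed-size arrays; this is a simplified stand-in, not EasyDeL's implementation.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyGreedyState:
    # Mirrors the GreedyState fields above, minus model_kwargs (hypothetical stand-in).
    cur_len: int
    sequences: np.ndarray         # (batch, max_length)
    running_token: np.ndarray     # (batch,)
    is_sent_finished: np.ndarray  # (batch,) bool


def greedy_step(state, logits, eos_token_id, pad_token_id):
    """Append argmax tokens; rows that already finished keep emitting padding."""
    next_token = np.argmax(logits, axis=-1)
    next_token = np.where(state.is_sent_finished, pad_token_id, next_token)
    finished = state.is_sent_finished | (next_token == eos_token_id)
    sequences = state.sequences.copy()
    sequences[:, state.cur_len] = next_token
    # replace() builds a new immutable state instead of mutating the old one.
    return dataclasses.replace(state, cur_len=state.cur_len + 1, sequences=sequences,
                               running_token=next_token, is_sent_finished=finished)
```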
- class easydel.infra.mixins.generation.SampleState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], prng_key: Union[Array, ndarray, bool, number], model_kwargs: Dict[str, Union[Array, ndarray, bool, number]])
Bases: object
State for sampling generation.
- cur_len
Current length of the generated sequence.
- Type
chex.Array
- sequences
Sequences generated so far.
- Type
chex.Array
- running_token
Current token being processed.
- Type
chex.Array
- is_sent_finished
Boolean array indicating whether each sequence has finished.
- Type
chex.Array
- prng_key
PRNG key for sampling.
- Type
chex.Array
- model_kwargs
Model-specific keyword arguments.
- Type
tp.Dict[str, chex.Array]
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
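SampleState adds a prng_key because each step draws a random token. A minimal NumPy sketch of one temperature-sampling step follows; sample_next_token is hypothetical, and a numpy.random.Generator stands in for the JAX key. In JAX, the analogous step would split the key with jax.random.split, pass one subkey to jax.random.categorical, and store the other back into the state.

```python
import numpy as np


def sample_next_token(logits, rng, temperature=1.0):
    """Draw one token per row from softmax(logits / temperature).

    NumPy stand-in for a JAX sampling step; `rng` plays the role of prng_key.
    """
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling: one uniform draw per row, located in the cumulative probabilities.
    u = rng.random((probs.shape[0], 1))
    tokens = (probs.cumsum(axis=-1) < u).sum(axis=-1)
    return np.minimum(tokens, probs.shape[-1] - 1)  # guard against rounding when u is near 1
```

Threading the key (or generator) explicitly, rather than relying on global random state, is what keeps the JAX version reproducible and traceable.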