easydel.infra.mixins.generation

Generation mixin for text generation capabilities.

Provides text generation functionality through the EasyGenerationMixin class, which can be combined with EasyDeL models to enable various generation strategies including greedy search, sampling, beam search, and more.

Classes:

  • GreedyState: State container for greedy generation

  • SampleState: State container for sampling generation

  • BeamSearchState: State container for beam search

  • EasyGenerationMixin: Mixin class providing generation methods

Key Features:
  • Multiple generation strategies (greedy, sampling, beam search)

  • Logits processing and warping

  • Support for generation constraints

  • Integration with HuggingFace generation configs

  • Efficient JAX implementations

Example

>>> from easydel.infra.mixins import EasyGenerationMixin
>>> # Model class inherits from EasyGenerationMixin
>>> output = model.generate(
...     input_ids=input_ids,
...     max_length=100,
...     temperature=0.8,
...     top_p=0.95,
...     do_sample=True
... )
class easydel.infra.mixins.generation.BeamSearchState(cur_len: Union[Array, ndarray, bool, number], running_sequences: Union[Array, ndarray, bool, number], running_scores: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], scores: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]])[source]#

Bases: object

State for beam search generation.

cur_len#

Current length of the generated sequence.

Type

chex.Array

running_sequences#

Generated sequences being tracked in the beam.

Type

chex.Array

running_scores#

Scores of the sequences being tracked in the beam.

Type

chex.Array

sequences#

Best generated sequences.

Type

chex.Array

scores#

Scores of the best generated sequences.

Type

chex.Array

is_sent_finished#

Boolean array indicating if a sequence is finished.

Type

chex.Array

model_kwargs#

Model specific keyword arguments.

Type

tp.Dict[str, chex.Array]

cur_len: Union[Array, ndarray, bool, number]#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

is_sent_finished: Union[Array, ndarray, bool, number]#
model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

running_scores: Union[Array, ndarray, bool, number]#
running_sequences: Union[Array, ndarray, bool, number]#
scores: Union[Array, ndarray, bool, number]#
sequences: Union[Array, ndarray, bool, number]#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
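
As a rough illustration of how running_scores and running_sequences typically evolve in beam search, the sketch below performs one candidate-selection step for a single batch element. It uses NumPy rather than JAX so it runs anywhere, and beam_step is a hypothetical helper written for this page, not part of EasyDeL.

```python
import numpy as np


def beam_step(running_scores, next_log_probs):
    """Pick the best `num_beams` continuations for one batch element.

    running_scores: (num_beams,) cumulative log-probabilities of each beam.
    next_log_probs: (num_beams, vocab_size) next-token log-probabilities.
    """
    num_beams, vocab_size = next_log_probs.shape
    # Score of every (beam, token) continuation.
    candidates = running_scores[:, None] + next_log_probs
    flat = candidates.reshape(-1)
    # Flat indices of the top `num_beams` candidates, best first.
    top = np.argsort(-flat, kind="stable")[:num_beams]
    # Recover which beam each winner extends and which token it appends.
    beam_ids, token_ids = np.divmod(top, vocab_size)
    return beam_ids, token_ids, flat[top]


running_scores = np.log(np.array([0.6, 0.4]))
next_log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                                  [0.1, 0.1, 0.8]]))
beam_ids, token_ids, new_scores = beam_step(running_scores, next_log_probs)
```

Here the two surviving candidates extend beam 0 with token 0 and beam 1 with token 2. A full implementation would also move finished hypotheses into sequences and scores and apply a length penalty, which this sketch leaves out.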

class easydel.infra.mixins.generation.EasyGenerationMixin[source]#

Bases: object

base_model_prefix: str#
static compute_prefill_length(array, padding_id) Union[Array, ndarray, bool, number][source]#

Calculates the number of padding tokens at the beginning of each sequence.

This is useful for determining the actual starting position in a KV cache when dealing with left-padded inputs.

Parameters
  • array (chex.Array) – The input token ID array, typically shape (batch_size, sequence_length).

  • padding_id (int) – The token ID used for padding.

Returns

An array of shape (batch_size,) containing the number of leading padding tokens for each sequence in the batch.

Return type

chex.Array

static compute_prefill_length_from_mask(mask) Union[Array, ndarray, bool, number][source]#

Calculates the number of padding tokens at the beginning of each sequence from a 0/1 or boolean mask.
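
The two helpers above count leading padding, which can be sketched with a cumulative sum: the running count of real tokens stays at zero until the first non-padding position. The functions below are hypothetical NumPy stand-ins written for illustration, not the EasyDeL implementations.

```python
import numpy as np


def leading_pad_count(token_ids, padding_id):
    """Number of padding tokens before the first real token in each row."""
    is_token = token_ids != padding_id
    # The cumulative count of real tokens is 0 only inside the leading pad run.
    return (np.cumsum(is_token, axis=-1) == 0).sum(axis=-1)


def leading_pad_count_from_mask(mask):
    """Same count, starting from a 0/1 or boolean attention mask."""
    return (np.cumsum(mask.astype(bool), axis=-1) == 0).sum(axis=-1)


batch = np.array([[0, 0, 5, 6],
                  [0, 7, 0, 8],
                  [1, 2, 3, 4]])
starts = leading_pad_count(batch, padding_id=0)
```

Note that padding appearing after the first real token (second row) is not counted; only the left-padded prefix contributes.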

config: EasyDeLBaseConfig#
config_class: type[easydel.infra.base_config.EasyDeLBaseConfig]#
create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None) TransformerCacheMetaData[source]#

Creates the metadata required for initializing a standard (non-paged) KV Cache.

This method gathers parameters like layer count, head dimensions, and determines the appropriate padding token ID to instantiate and return a TransformerCacheMetaData object suitable for a standard sequential KV cache.

Parameters
  • batch_size (int) – The batch size for which the cache is being configured.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • pad_token_id (int | None) – The ID of the padding token. If None, it attempts to find it from self.generation_config or self.config, defaulting to 0.

Returns

An initialized metadata object for a standard KV cache.

Return type

TransformerCacheMetaData
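
To get a feel for what batch_size and max_length imply, the following back-of-the-envelope sketch estimates the memory of a dense, unquantized KV cache. It assumes one key and one value vector per token, per layer, per KV head; kv_cache_bytes and its parameter names are hypothetical and do not reflect how EasyDeL allocates the cache.

```python
def kv_cache_bytes(batch_size, max_length, num_layers, num_kv_heads,
                   head_dim, bytes_per_element=2):
    """Approximate size of a dense KV cache (assumes 2-byte elements by default)."""
    # Factor 2: one key vector and one value vector are stored per token.
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
    return batch_size * max_length * per_token


size = kv_cache_bytes(batch_size=4, max_length=2048, num_layers=32,
                      num_kv_heads=8, head_dim=128)
print(f"{size / 2**30:.2f} GiB")  # prints "1.00 GiB"
```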

create_paged_metadata(hbm_utilization: float, page_size: int, max_model_length: int) RaggedPagesCacheMetaData[source]#

Creates the static configuration metadata required for initializing a Paged KV Cache.

This method gathers necessary parameters from the model’s configuration (like number of layers, heads, dimensions) and combines them with the provided arguments to instantiate and return a RaggedPagesCacheMetaData object. This metadata object defines the structure and allocation parameters for the paged cache.

Returns

An initialized metadata object containing the static configuration for the paged cache.

Return type

RaggedPagesCacheMetaData
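
The relationship between hbm_utilization, page_size, and max_model_length can be illustrated with simple arithmetic. This is only a sketch under stated assumptions: it ignores the memory taken by model weights and activations, and page_budget and pages_per_sequence are hypothetical helpers rather than the formula EasyDeL uses.

```python
import math


def page_budget(hbm_bytes, hbm_utilization, page_size, bytes_per_token):
    """How many pages fit in the share of HBM reserved for the KV cache."""
    bytes_per_page = page_size * bytes_per_token
    return int(hbm_bytes * hbm_utilization) // bytes_per_page


def pages_per_sequence(max_model_length, page_size):
    """Pages needed to hold one sequence of maximum length."""
    return math.ceil(max_model_length / page_size)


# 16 GiB of HBM, 85% reserved, 128-token pages, 128 KiB of K/V data per token.
num_pages = page_budget(hbm_bytes=16 * 2**30, hbm_utilization=0.85,
                        page_size=128, bytes_per_token=131072)
```

With these numbers each page is 16 MiB, so 870 pages fit, and a sequence of 8192 tokens would occupy 64 of them.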

property esurge_compatible_model#

Returns a model instance compatible with eSurge inference engine.

eSurge requires models to use ragged page attention mechanisms (v2 or v3). If the current model uses a different attention mechanism, this property returns a new model instance with ragged_page_attention_v3 while preserving all parameters and state.

Returns

Model instance with eSurge-compatible attention mechanism.

Return type

Self

Note

If the model already uses a compatible attention mechanism, returns self. Otherwise, builds a new graph definition with ragged_page_attention_v3 and merges the existing parameters/state, leaving the original model unchanged.

Example

>>> # Get eSurge-compatible version of model
>>> esurge_model = model.esurge_compatible_model
>>> # Now safe to use with eSurge inference
>>> outputs = esurge_model.esurge_generate("Hello world")
esurge_generate(prompts: list[dict[str, str]] | list[str] | str, tools: list[dict] | None = None, sampling_params: SamplingParams | None = None, request_id: str | None = None, stream: bool = False, chat_template: str | None = None, *, tokenizer: str | PreTrainedTokenizerBase | None = None, max_model_len: int | None = None, min_input_pad: int | None = None, max_num_seqs: int | None = None, max_num_batched_tokens: int | None = None, hbm_utilization: float | None = None, page_size: int | None = None, enable_prefix_caching: bool | None = None, runner_verbose: bool | None = None, decode_truncated_prompt: bool | None = None, destroy_pages_on_pause: bool | None = None, silent_mode: bool | None = None, use_tqdm: bool = False)[source]#

High-level interface for text generation using eSurge engine.

This method provides a convenient way to generate text using the eSurge inference engine with automatic caching and configuration management. It supports both chat and completion modes, with optional streaming.

All engine configuration parameters are optional. When omitted, the method will:

  1. Try to retrieve values from a cached engine for this model

  2. Fall back to sensible defaults if no cached engine exists

  3. Only require tokenizer on the first call when no cache exists

Parameters
  • prompts – Input prompts. Can be:
    • A single string for simple completion
    • A list of strings for batch completion
    • A list of dicts with ‘role’ and ‘content’ keys for chat mode

  • tools – Optional list of tool/function definitions for function calling in chat mode.

  • sampling_params – Generation parameters (temperature, top_p, max_tokens, etc.). Defaults to SamplingParams(max_tokens=128) if None.

  • request_id – Optional unique identifier for tracking. Auto-generated if None.

  • stream – If True, returns an iterator for streaming generation. If False, returns complete results.

  • chat_template – Optional custom Jinja2 template for chat formatting.

  • tokenizer – Tokenizer path or instance. Required only on first call if no cached engine exists. Subsequent calls can omit this to reuse the cached tokenizer.

  • max_model_len – Maximum sequence length. Defaults to model’s max position embeddings.

  • min_input_pad – Minimum padding for input sequences. Defaults to 16.

  • max_num_seqs – Maximum number of concurrent sequences. Defaults to 32.

  • max_num_batched_tokens – Maximum tokens per batch. Defaults to None (auto-calculate).

  • hbm_utilization – Fraction of HBM to use for KV cache. Defaults to 0.85.

  • page_size – Size of memory pages for paged attention. Defaults to 128.

  • enable_prefix_caching – Enable prefix caching for shared prompts. Defaults to True.

  • runner_verbose – Enable verbose logging in the model runner. Defaults to False.

  • decode_truncated_prompt – Decode and display truncated prompts. Defaults to True.

  • destroy_pages_on_pause – Free memory pages when requests are paused. Defaults to True.

Returns

  • For chat mode (prompts is list[dict]):
    • If stream=True: Iterator[RequestOutput] with delta updates
    • If stream=False: RequestOutput with complete response

  • For completion mode (prompts is str or list[str]):
    • If stream=True: Iterator[RequestOutput] with delta updates
    • If stream=False: list[RequestOutput] with complete responses

Example

>>> # Simple completion
>>> outputs = model.esurge_generate("Tell me about AI")
>>> print(outputs[0].get_text())
>>>
>>> # Streaming completion
>>> for chunk in model.esurge_generate("Tell me a story", stream=True):
...     print(chunk.delta_text, end="", flush=True)
>>>
>>> # Chat mode
>>> messages = [
...     {"role": "system", "content": "You are helpful."},
...     {"role": "user", "content": "What is 2+2?"}
... ]
>>> response = model.esurge_generate(messages)
>>> print(response.get_text())
property esurge_graphdef#

Returns a graph definition compatible with eSurge inference engine.

eSurge requires models to use ragged page attention mechanisms (v2 or v3). If the current model uses a different attention mechanism, this property creates a new graph definition with ragged_page_attention_v3.

Returns

Graph definition with eSurge-compatible attention mechanism.

Return type

nn.GraphDef

Note

This creates only the graph structure, not a complete model. Use esurge_compatible_model if you need a full model instance.

Example

>>> gdef = model.esurge_graphdef
>>> # Use gdef for creating eSurge-compatible model instances
generate(input_ids: Union[Array, ndarray, bool, number], generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: easydel.inference.logits_process.LogitsProcessorList | None = None, **kwargs)[source]#

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.

  • generation_config (GenerationConfig, optional) – The generation configuration used as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config override them. If generation_config is not provided, the default is used, which has the following loading priority: 1) the generation_config.json model file, if it exists; 2) the model configuration. Unspecified parameters inherit the default values of GenerationConfig, whose documentation should be consulted when parameterizing generation.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and the generation config. Passing a logits processor that is already created from the arguments or the generation config raises an error. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that are forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with decoder_.

Returns

A ModelOutput.
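
The way **kwargs are split between generation-config overrides and model-specific arguments can be sketched as follows. ToyGenerationConfig and split_generate_kwargs are hypothetical stand-ins that only mimic the behaviour described above; the real GenerationConfig has many more fields.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ToyGenerationConfig:
    """Hypothetical, heavily reduced stand-in for a generation config."""
    max_length: int = 20
    do_sample: bool = False
    temperature: float = 1.0
    top_p: float = 1.0


def split_generate_kwargs(config, **kwargs):
    """Apply kwargs that name config attributes; pass the rest to the model."""
    names = {f.name for f in fields(config)}
    overrides = {k: v for k, v in kwargs.items() if k in names}
    model_kwargs = {k: v for k, v in kwargs.items() if k not in names}
    return replace(config, **overrides), model_kwargs


config, model_kwargs = split_generate_kwargs(
    ToyGenerationConfig(), max_length=100, temperature=0.8, attention_mask="mask"
)
```

After the call, config carries max_length=100 and temperature=0.8 while keeping the default do_sample=False, and attention_mask is left in model_kwargs for the forward pass.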

get_esurge(tokenizer: str | PreTrainedTokenizerBase | None = None, max_model_len: int | None = None, min_input_pad: int | None = None, max_num_seqs: int | None = None, max_num_batched_tokens: int | None = None, hbm_utilization: float | None = None, page_size: int | None = None, enable_prefix_caching: bool | None = None, runner_verbose: bool | None = None, decode_truncated_prompt: bool | None = None, destroy_pages_on_pause: bool | None = None, silent_mode: bool | None = None)[source]#

Gets or creates an eSurge engine with the specified parameters.

This method retrieves an existing cached engine or creates a new one. For any parameter that is None, it will:

  1. Try to retrieve the value from a cached engine

  2. If no cached engine exists, use sensible defaults

  3. Only require tokenizer if no cached engine is available

Parameters
  • tokenizer – Tokenizer path or instance. If None, retrieves from cached engine. Required only if no cached engine exists.

  • max_model_len – Maximum sequence length. Defaults to model’s max position embeddings.

  • min_input_pad – Minimum padding for input sequences. Defaults to 16.

  • max_num_seqs – Maximum number of concurrent sequences. Defaults to 32.

  • max_num_batched_tokens – Maximum tokens per batch. Defaults to None.

  • hbm_utilization – Fraction of HBM to use for KV cache. Defaults to 0.85.

  • page_size – Size of memory pages for paged attention. Defaults to 128.

  • enable_prefix_caching – Enable prefix caching. Defaults to True.

  • runner_verbose – Enable verbose logging. Defaults to False.

  • decode_truncated_prompt – Decode truncated prompts. Defaults to True.

  • destroy_pages_on_pause – Free memory on pause. Defaults to True.

Returns

eSurge engine instance, either from cache or newly created.

Raises

ValueError – If tokenizer is required but not provided and no cached engine exists.
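
The resolution order described for this method (explicit argument, then cached engine, then default) can be sketched in plain Python. The default values are the ones listed in the parameter descriptions; resolve_engine_options and the dict-based cache are hypothetical simplifications, not EasyDeL's internal caching code.

```python
# Defaults taken from the parameter descriptions above.
DEFAULTS = {
    "min_input_pad": 16,
    "max_num_seqs": 32,
    "hbm_utilization": 0.85,
    "page_size": 128,
    "enable_prefix_caching": True,
}


def resolve_engine_options(requested, cached=None):
    """Resolve each option: explicit value, else cached value, else default."""
    resolved = {}
    for name, default in DEFAULTS.items():
        value = requested.get(name)
        if value is None and cached is not None:
            value = cached.get(name)
        if value is None:
            value = default
        resolved[name] = value
    # The tokenizer has no default: it must come from the call or the cache.
    tokenizer = requested.get("tokenizer") or (cached or {}).get("tokenizer")
    if tokenizer is None:
        raise ValueError("tokenizer is required when no cached engine exists")
    resolved["tokenizer"] = tokenizer
    return resolved


first = resolve_engine_options({"tokenizer": "gpt2"})
second = resolve_engine_options({"max_num_seqs": 128}, cached=first)
```

The second call reuses the cached tokenizer and overrides only max_num_seqs, mirroring the three calls shown in the example that follows.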

Example

>>> # First call with tokenizer (creates new engine)
>>> engine = model.get_esurge(tokenizer="meta-llama/Llama-2-7b-hf")
>>>
>>> # Subsequent calls without parameters (reuses cached engine)
>>> engine = model.get_esurge()
>>>
>>> # Override specific parameters
>>> engine = model.get_esurge(max_num_seqs=128)
get_relevant_esurge(tokenizer: str | PreTrainedTokenizerBase | None = None, max_num_seqs: int | None = None)[source]#

Retrieves a relevant eSurge engine instance from the cache.

This method searches for an existing eSurge engine in the cache that matches the current model. If tokenizer or max_num_seqs are None, it returns the most recently created engine for this model. If no engine exists and parameters are missing, it uses sensible defaults.

Parameters
  • tokenizer – Optional tokenizer path or instance. If None, retrieves from the most recent cached engine for this model.

  • max_num_seqs – Optional maximum number of concurrent sequences. If None, uses value from cached engine or defaults to 32 if no cache exists.

Returns

eSurge engine instance if found in cache, None otherwise.

Example

>>> # Try to get existing engine with default params
>>> engine = model.get_relevant_esurge()
>>> if engine:
...     outputs = engine.generate("Hello world")
>>>
>>> # Get engine with specific tokenizer
>>> engine = model.get_relevant_esurge(tokenizer="gpt2")
init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None) TransformerCache[source]#

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, influences the initial state. Defaults to None (usually 0).

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. (Note: This argument appears unused in the current implementation shown).

  • pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache

init_ragged_pages(metadata: easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData | None = None, page_size: int | None = None, hbm_utilization: float | None = None, max_model_length: int | None = None) RaggedPagesCache[source]#

Initializes and returns the actual Paged Attention KV Cache tensors.

This method orchestrates the creation of the RaggedPagesCache. It either uses a pre-existing RaggedPagesCacheMetaData object passed via the metadata argument, or, if metadata is None, it first creates the metadata by calling self.create_paged_metadata with the other provided arguments (page_size, hbm_utilization, max_model_length).

Finally, it calls RaggedPagesCache.init_cache to allocate the necessary paged tensors (key_pages, value_pages for each layer) based on the metadata, model’s mesh, dtype, partition manager, and quantization settings.

Parameters
  • metadata (tp.Optional[RaggedPagesCacheMetaData]) – An optional pre-configured metadata object. If provided, the other arguments (page_size, hbm_utilization, max_model_length) are ignored for metadata creation.

  • page_size (tp.Optional[int]) – Number of tokens per page. Required if metadata is None.

  • hbm_utilization (tp.Optional[float]) – Target HBM usage. Required if metadata is None.

Returns

An initialized RaggedPagesCache object containing the allocated cache tensors (views) for all layers.

Return type

RaggedPagesCache

Raises

AssertionError – If metadata is None and any of the arguments required to build it (such as page_size or hbm_utilization) is also None.

list_esurge_engines() list[dict][source]#

List all cached eSurge engines for this model.

Returns a list of dictionaries containing information about each cached engine, including its status (running/paused), number of requests, and configuration hash.

Returns

List of dicts with engine information. Each dict contains:

  • cache_key: The cache key for this engine

  • paused: Whether the engine is paused

  • running_requests: Number of currently running requests

  • pending_requests: Number of pending requests

  • max_num_seqs: Maximum concurrent sequences

Return type

list[dict]

Example

>>> engines = model.list_esurge_engines()
>>> for engine in engines:
...     print(f"Engine {engine['cache_key']}: "
...           f"Paused={engine['paused']}, "
...           f"Running={engine['running_requests']}")
pause_esurge(engine_id: str | None = None) None[source]#

Pause eSurge engine(s) for this model.

Pauses the background scheduler of eSurge engines without clearing queued state. This is useful for temporarily freeing resources while keeping the engine ready for quick resumption.

Parameters

engine_id – Optional specific engine cache key to pause. If None, pauses all engines associated with this model.

Example

>>> # Pause all engines for this model
>>> model.pause_esurge()
>>>
>>> # Later, generate will auto-resume
>>> outputs = model.esurge_generate("prompt")  # Auto-resumes!
prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings: int | None = None, attention_mask: jax.Array | None = None, token_type_ids: jax.Array | None = None, mask_info: ejkernel.types.mask.MaskInfo | None = None) dict[str, Any][source]#

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.

Parameters
  • input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.

  • shardings (dict | None) – Optional sharding configuration passed to init_cache.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict
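
The mask and position handling described for prepare_inputs_for_generation follows a pattern common in Flax-style generation code, sketched here in NumPy. prepare_masks is a hypothetical helper, and the clipping of padded positions to 0 is an assumption; the actual EasyDeL method may treat padded positions differently.

```python
import numpy as np


def prepare_masks(attention_mask, max_length):
    """Extend the mask to the cache length and derive initial position IDs."""
    batch_size, seq_length = attention_mask.shape
    # Future cache slots are marked attendable; the prompt part keeps its mask.
    extended = np.ones((batch_size, max_length), dtype=attention_mask.dtype)
    extended[:, :seq_length] = attention_mask
    # Positions count real tokens only; padded slots are clipped to 0 (assumption).
    position_ids = np.clip(np.cumsum(attention_mask, axis=-1) - 1, 0, None)
    return extended, position_ids


mask = np.array([[0, 0, 1, 1],
                 [1, 1, 1, 1]])
extended, position_ids = prepare_masks(mask, max_length=6)
```

For the left-padded first row, the two real tokens receive positions 0 and 1, so generation continues from position 2 regardless of how much padding precedes the prompt.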

resume_esurge(engine_id: str | None = None) None[source]#

Resume paused eSurge engine(s) for this model.

Resumes the background scheduler of paused eSurge engines, making them ready to process generation requests again.

Parameters

engine_id – Optional specific engine cache key to resume. If None, resumes all engines associated with this model.

Example

>>> # Pause engines to free resources
>>> model.pause_esurge()
>>>
>>> # Manually resume when needed
>>> model.resume_esurge()
>>> outputs = model.esurge_generate("prompt")
update_inputs_for_generation(model_outputs, model_kwargs) dict[str, Any][source]#

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.

Parameters
  • model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary will be modified in-place or a new one returned.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict
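
A minimal sketch of this per-step update, assuming the next step only needs the position of the newest token (so the last column is kept and incremented). update_inputs is a hypothetical stand-in, and a SimpleNamespace plays the role of the model output object.

```python
from types import SimpleNamespace

import numpy as np


def update_inputs(model_outputs, model_kwargs):
    """Carry the cache forward and advance positions by one."""
    updated = dict(model_kwargs)  # shallow copy: the caller's dict is untouched
    updated["past_key_values"] = model_outputs.past_key_values
    # Keep only the latest position per sequence and move it one step ahead.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


outputs = SimpleNamespace(past_key_values="cache-after-step")
kwargs = {
    "past_key_values": None,
    "position_ids": np.array([[0, 0, 0, 1], [0, 1, 2, 3]]),
}
next_kwargs = update_inputs(outputs, kwargs)
```

Returning a new dictionary rather than mutating the input keeps the function free of side effects, which tends to suit JAX-traced loops; whether the real method copies or mutates is not specified here.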

class easydel.infra.mixins.generation.GreedyState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]])[source]#

Bases: object

State container for greedy search generation.

Tracks the current state during greedy decoding, where the most probable token is selected at each step.

cur_len#

Current length of generated sequences.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

sequences#

Generated token sequences.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

running_token#

Currently processed token.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

is_sent_finished#

Boolean flags for finished sequences.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

model_kwargs#

Additional model-specific arguments.

Type

dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]

cur_len: Union[Array, ndarray, bool, number]#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

is_sent_finished: Union[Array, ndarray, bool, number]#
model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

running_token: Union[Array, ndarray, bool, number]#
sequences: Union[Array, ndarray, bool, number]#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
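
To show how these fields change over one decoding step, here is a sketch using a frozen dataclass as a stand-in for GreedyState, with dataclasses.replace playing the role of the replace method. ToyGreedyState and greedy_step are hypothetical and use NumPy; model_kwargs is omitted for brevity.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyGreedyState:
    """Hypothetical stand-in mirroring the GreedyState fields."""
    cur_len: int
    sequences: np.ndarray
    running_token: np.ndarray
    is_sent_finished: np.ndarray


def greedy_step(state, logits, eos_token_id, pad_token_id):
    """Append the most probable token to every unfinished sequence."""
    next_token = logits.argmax(axis=-1)
    # Sequences that already finished keep emitting the padding token.
    next_token = np.where(state.is_sent_finished, pad_token_id, next_token)
    sequences = state.sequences.copy()
    sequences[:, state.cur_len] = next_token
    return replace(
        state,
        cur_len=state.cur_len + 1,
        sequences=sequences,
        running_token=next_token[:, None],
        is_sent_finished=state.is_sent_finished | (next_token == eos_token_id),
    )


state = ToyGreedyState(
    cur_len=2,
    sequences=np.array([[5, 6, 0, 0], [7, 8, 0, 0]]),
    running_token=np.array([[6], [8]]),
    is_sent_finished=np.array([False, False]),
)
logits = np.array([[0.1, 0.2, 3.0, 0.0],
                   [0.0, 2.5, 0.1, 0.3]])
new_state = greedy_step(state, logits, eos_token_id=1, pad_token_id=0)
```

The second sequence picks token 1, which is the end-of-sequence ID in this toy setup, so its is_sent_finished flag becomes True while the original state object stays unchanged.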

class easydel.infra.mixins.generation.SampleState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], prng_key: Union[Array, ndarray, bool, number], model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]])[source]#

Bases: object

State container for sampling-based generation.

Tracks the current state during sampling generation, where tokens are sampled from the probability distribution.

cur_len#

Current length of generated sequences.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

sequences#

Generated token sequences.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

running_token#

Currently processed token.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

is_sent_finished#

Boolean flags for finished sequences.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

prng_key#

JAX PRNG key for random sampling.

Type

Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]

model_kwargs#

Additional model-specific arguments.

Type

dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]

cur_len: Union[Array, ndarray, bool, number]#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

is_sent_finished: Union[Array, ndarray, bool, number]#
model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]#
prng_key: Union[Array, ndarray, bool, number]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

running_token: Union[Array, ndarray, bool, number]#
sequences: Union[Array, ndarray, bool, number]#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
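
Sampling-based generation usually warps the logits before drawing a token. The sketch below applies temperature scaling and nucleus (top-p) filtering in NumPy; top_p_probs is a hypothetical helper, and a seeded NumPy generator stands in for the explicit prng_key that the JAX implementation carries in SampleState.

```python
import numpy as np


def top_p_probs(logits, temperature=1.0, top_p=1.0):
    """Temperature-scale logits, then keep the smallest set with mass >= top_p."""
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(-probs, kind="stable")
    sorted_probs = probs[order]
    cumulative = np.cumsum(sorted_probs)
    # Keep a token if the mass before it is still below top_p
    # (this always keeps the most probable token).
    keep_sorted = (cumulative - sorted_probs) < top_p
    filtered = np.zeros_like(probs)
    filtered[order[keep_sorted]] = probs[order[keep_sorted]]
    return filtered / filtered.sum()


rng = np.random.default_rng(0)  # stand-in for a JAX PRNG key
logits = np.log(np.array([0.5, 0.3, 0.15, 0.05]))
probs = top_p_probs(logits, temperature=1.0, top_p=0.7)
token = rng.choice(len(probs), p=probs)
```

With top_p=0.7 only the two most probable tokens survive and are renormalized to 0.625 and 0.375, so the sampled token is always 0 or 1.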