easydel.infra.mixins.generation#
Generation mixin for text generation capabilities.
Provides text generation functionality through the EasyGenerationMixin class, which can be combined with EasyDeL models to enable various generation strategies including greedy search, sampling, beam search, and more.
- Classes:
GreedyState: State container for greedy generation
SampleState: State container for sampling generation
BeamSearchState: State container for beam search
EasyGenerationMixin: Mixin class providing generation methods
- Key Features:
Multiple generation strategies (greedy, sampling, beam search)
Logits processing and warping
Support for generation constraints
Integration with HuggingFace generation configs
Efficient JAX implementations
Example
>>> from easydel.infra.mixins import EasyGenerationMixin
>>> # Model class inherits from EasyGenerationMixin
>>> output = model.generate(
... input_ids=input_ids,
... max_length=100,
... temperature=0.8,
... top_p=0.95,
... do_sample=True
... )
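To make the temperature and top_p arguments in the example above concrete, the following is a minimal NumPy sketch of temperature scaling followed by nucleus (top-p) filtering. It illustrates the general technique only and is not EasyDeL's implementation; warp_logits is a hypothetical name introduced for this example.

```python
import numpy as np


def warp_logits(logits, temperature=1.0, top_p=1.0):
    """Hypothetical helper: temperature scaling followed by top-p filtering."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    order = np.argsort(-scaled)  # token ids, most likely first
    sorted_logits = scaled[order]
    probs = np.exp(sorted_logits - sorted_logits.max())
    probs /= probs.sum()
    # Probability mass that precedes each token in the sorted order.
    mass_before = np.cumsum(probs) - probs
    keep = mass_before < top_p  # the top token is kept whenever top_p > 0
    filtered = np.full_like(scaled, -np.inf)
    filtered[order[keep]] = scaled[order[keep]]
    return filtered


logits = np.log(np.array([0.5, 0.3, 0.15, 0.05]))
filtered = warp_logits(logits, temperature=1.0, top_p=0.7)
```

Tokens set to -inf receive zero probability once a softmax is applied, so only the retained tokens can be sampled.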
- class easydel.infra.mixins.generation.BeamSearchState(cur_len: Union[Array, ndarray, bool, number], running_sequences: Union[Array, ndarray, bool, number], running_scores: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], scores: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]])[source]#
Bases: object
State for beam search generation.
- cur_len#
Current length of the generated sequence.
- Type
chex.Array
- running_sequences#
Generated sequences being tracked in the beam.
- Type
chex.Array
- running_scores#
Scores of the sequences being tracked in the beam.
- Type
chex.Array
- sequences#
Best generated sequences.
- Type
chex.Array
- scores#
Scores of the best generated sequences.
- Type
chex.Array
- is_sent_finished#
Boolean array indicating if a sequence is finished.
- Type
chex.Array
- model_kwargs#
Model specific keyword arguments.
- Type
tp.Dict[str, chex.Array]
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
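The running_scores and running_sequences fields above suggest the usual beam search bookkeeping. As a rough sketch of one expansion step (standard beam search, not necessarily how EasyDeL implements it), each running beam's score is added to the log-probabilities of every candidate token and the best num_beams continuations are kept. The function name beam_step and the toy numbers are assumptions made for illustration.

```python
import numpy as np


def beam_step(running_scores, log_probs, num_beams):
    """Pick the best `num_beams` (beam, token) continuations.

    running_scores: shape (num_beams,), cumulative log-probability per beam.
    log_probs: shape (num_beams, vocab_size), next-token log-probabilities.
    """
    vocab_size = log_probs.shape[-1]
    # Score of every possible continuation, flattened to one axis.
    candidates = (running_scores[:, None] + log_probs).reshape(-1)
    top = np.argsort(-candidates, kind="stable")[:num_beams]
    beam_indices = top // vocab_size  # which beam each winner extends
    token_ids = top % vocab_size      # which token it appends
    return beam_indices, token_ids, candidates[top]


running_scores = np.array([0.0, -1.0])
log_probs = np.array([[-0.5, -1.5, -3.0],
                      [-3.0, -3.0, -0.2]])
beam_indices, token_ids, new_scores = beam_step(running_scores, log_probs, num_beams=2)
```

Here the first beam is extended with token 0 and the second with token 2, because those two continuations have the highest combined scores.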
- class easydel.infra.mixins.generation.EasyGenerationMixin[source]#
Bases: object
- base_model_prefix: str#
- static compute_prefill_length(array, padding_id) Union[Array, ndarray, bool, number][source]#
Calculates the number of padding tokens at the beginning of each sequence.
This is useful for determining the actual starting position in a KV cache when dealing with left-padded inputs.
- Parameters
array (chex.Array) – The input token ID array, typically shape (batch_size, sequence_length).
padding_id (int) – The token ID used for padding.
- Returns
An array of shape (batch_size,) containing the number of leading padding tokens for each sequence in the batch.
- Return type
chex.Array
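A minimal NumPy sketch of the computation described above, counting the padding tokens at the start of each left-padded row. It is one possible way to obtain the documented result and is not taken from the EasyDeL source; count_leading_padding is a hypothetical name.

```python
import numpy as np


def count_leading_padding(array, padding_id):
    """Return, per row, how many padding tokens appear before the first real token."""
    is_pad = np.asarray(array) == padding_id
    # The cumulative product stays 1 only while every position so far is padding.
    still_leading = np.cumprod(is_pad, axis=-1)
    return still_leading.sum(axis=-1)


tokens = np.array([
    [0, 0, 5, 6],  # two leading pads
    [7, 0, 8, 9],  # a pad in the middle does not count
    [0, 0, 0, 0],  # fully padded row
])
prefill = count_leading_padding(tokens, padding_id=0)
```

Only padding that precedes the first real token is counted, which matches the left-padding use case mentioned above.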
- static compute_prefill_length_from_mask(mask) Union[Array, ndarray, bool, number][source]#
Calculates the number of padding tokens at the beginning of each sequence from a 0/1 or boolean mask.
- config: EasyDeLBaseConfig#
- config_class: type[easydel.infra.base_config.EasyDeLBaseConfig]#
- create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None) TransformerCacheMetaData[source]#
Creates the metadata required for initializing a standard (non-paged) KV Cache.
This method gathers parameters like layer count, head dimensions, and determines the appropriate padding token ID to instantiate and return a TransformerCacheMetaData object suitable for a standard sequential KV cache.
- Parameters
batch_size (int) – The batch size for which the cache is being configured.
max_length (int) – The maximum sequence length the cache needs to support.
pad_token_id (int | None) – The ID of the padding token. If None, it attempts to find it from self.generation_config or self.config, defaulting to 0.
- Returns
An initialized metadata object for a standard KV cache.
- Return type
TransformerCacheMetaData
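The pad_token_id fallback described for create_cache_metadata can be pictured as a simple lookup chain. The sketch below assumes that both configuration objects expose a pad_token_id attribute; resolve_pad_token_id is a hypothetical helper, not part of EasyDeL.

```python
from types import SimpleNamespace


def resolve_pad_token_id(pad_token_id, generation_config=None, config=None):
    """Hypothetical lookup: explicit value, then generation_config, then config, then 0."""
    if pad_token_id is not None:
        return pad_token_id
    for source in (generation_config, config):
        candidate = getattr(source, "pad_token_id", None)
        if candidate is not None:
            return candidate
    return 0


generation_config = SimpleNamespace(pad_token_id=None)
model_config = SimpleNamespace(pad_token_id=2)

explicit = resolve_pad_token_id(7, generation_config, model_config)
from_config = resolve_pad_token_id(None, generation_config, model_config)
fallback = resolve_pad_token_id(None)
```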
- create_paged_metadata(hbm_utilization: float, page_size: int, max_model_length: int) RaggedPagesCacheMetaData[source]#
Creates the static configuration metadata required for initializing a Paged KV Cache.
This method gathers necessary parameters from the model’s configuration (like number of layers, heads, dimensions) and combines them with the provided arguments to instantiate and return a RaggedPagesCacheMetaData object. This metadata object defines the structure and allocation parameters for the paged cache.
- Returns
An initialized metadata object containing the static configuration for the paged cache.
- Return type
RaggedPagesCacheMetaData
- property esurge_compatible_model#
Returns a model instance compatible with eSurge inference engine.
eSurge requires models to use ragged page attention mechanisms (v2 or v3). If the current model uses a different attention mechanism, this property returns a new model instance with ragged_page_attention_v3 while preserving all parameters and state.
- Returns
Model instance with eSurge-compatible attention mechanism.
- Return type
Self
Note
If the model already uses a compatible attention mechanism, returns self. Otherwise, builds a new graph definition with ragged_page_attention_v3 and merges the existing parameters/state, leaving the original model unchanged.
Example
>>> # Get eSurge-compatible version of model
>>> esurge_model = model.esurge_compatible_model
>>> # Now safe to use with eSurge inference
>>> outputs = esurge_model.esurge_generate("Hello world")
- esurge_generate(prompts: list[dict[str, str]] | list[str] | str, tools: list[dict] | None = None, sampling_params: SamplingParams | None = None, request_id: str | None = None, stream: bool = False, chat_template: str | None = None, *, tokenizer: str | PreTrainedTokenizerBase | None = None, max_model_len: int | None = None, min_input_pad: int | None = None, max_num_seqs: int | None = None, max_num_batched_tokens: int | None = None, hbm_utilization: float | None = None, page_size: int | None = None, enable_prefix_caching: bool | None = None, runner_verbose: bool | None = None, decode_truncated_prompt: bool | None = None, destroy_pages_on_pause: bool | None = None, silent_mode: bool | None = None, use_tqdm: bool = False)[source]#
High-level interface for text generation using eSurge engine.
This method provides a convenient way to generate text using the eSurge inference engine with automatic caching and configuration management. It supports both chat and completion modes, with optional streaming.
All engine configuration parameters are optional. When omitted, the method will:
1. Try to retrieve values from a cached engine for this model
2. Fall back to sensible defaults if no cached engine exists
3. Only require tokenizer on the first call when no cache exists
- Parameters
prompts – Input prompts. Can be a single string for simple completion, a list of strings for batch completion, or a list of dicts with ‘role’ and ‘content’ keys for chat mode.
tools – Optional list of tool/function definitions for function calling in chat mode.
sampling_params – Generation parameters (temperature, top_p, max_tokens, etc.). Defaults to SamplingParams(max_tokens=128) if None.
request_id – Optional unique identifier for tracking. Auto-generated if None.
stream – If True, returns an iterator for streaming generation. If False, returns complete results.
chat_template – Optional custom Jinja2 template for chat formatting.
tokenizer – Tokenizer path or instance. Required only on first call if no cached engine exists. Subsequent calls can omit this to reuse the cached tokenizer.
max_model_len – Maximum sequence length. Defaults to model’s max position embeddings.
min_input_pad – Minimum padding for input sequences. Defaults to 16.
max_num_seqs – Maximum number of concurrent sequences. Defaults to 32.
max_num_batched_tokens – Maximum tokens per batch. Defaults to None (auto-calculate).
hbm_utilization – Fraction of HBM to use for KV cache. Defaults to 0.85.
page_size – Size of memory pages for paged attention. Defaults to 128.
enable_prefix_caching – Enable prefix caching for shared prompts. Defaults to True.
runner_verbose – Enable verbose logging in the model runner. Defaults to False.
decode_truncated_prompt – Decode and display truncated prompts. Defaults to True.
destroy_pages_on_pause – Free memory pages when requests are paused. Defaults to True.
- Returns
For chat mode (prompts is list[dict]):
If stream=True: Iterator[RequestOutput] with delta updates
If stream=False: RequestOutput with complete response
For completion mode (prompts is str or list[str]):
If stream=True: Iterator[RequestOutput] with delta updates
If stream=False: list[RequestOutput] with complete responses
Example
>>> # Simple completion
>>> outputs = model.esurge_generate("Tell me about AI")
>>> print(outputs[0].get_text())
>>>
>>> # Streaming completion
>>> for chunk in model.esurge_generate("Tell me a story", stream=True):
...     print(chunk.delta_text, end="", flush=True)
>>>
>>> # Chat mode
>>> messages = [
...     {"role": "system", "content": "You are helpful."},
...     {"role": "user", "content": "What is 2+2?"}
... ]
>>> response = model.esurge_generate(messages)
>>> print(response.get_text())
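The three accepted shapes of prompts determine whether a call is treated as completion or chat. The sketch below shows one way such a dispatch could be written; detect_prompt_mode and the returned labels are hypothetical and are not taken from the eSurge source.

```python
def detect_prompt_mode(prompts):
    """Hypothetical classifier for the documented `prompts` shapes."""
    if isinstance(prompts, str):
        return "completion"
    if isinstance(prompts, list) and prompts and all(isinstance(p, dict) for p in prompts):
        return "chat"
    if isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
        return "batch_completion"
    raise TypeError("prompts must be a str, a list of str, or a list of dicts")


single = detect_prompt_mode("Tell me about AI")
batch = detect_prompt_mode(["First prompt", "Second prompt"])
chat = detect_prompt_mode([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "What is 2+2?"},
])
```

A list that mixes strings and dicts matches none of the documented shapes, so this sketch rejects it with a TypeError.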
- property esurge_graphdef#
Returns a graph definition compatible with eSurge inference engine.
eSurge requires models to use ragged page attention mechanisms (v2 or v3). If the current model uses a different attention mechanism, this property creates a new graph definition with ragged_page_attention_v3.
- Returns
Graph definition with eSurge-compatible attention mechanism.
- Return type
nn.GraphDef
Note
This creates only the graph structure, not a complete model. Use esurge_compatible_model if you need a full model instance.
Example
>>> gdef = model.esurge_graphdef
>>> # Use gdef for creating eSurge-compatible model instances
- generate(input_ids: Union[Array, ndarray, bool, number], generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: easydel.inference.logits_process.LogitsProcessorList | None = None, **kwargs)[source]#
Generates sequences of token ids for models with a language modeling head.
- Parameters
input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.
generation_config (GenerationConfig, optional) – The generation configuration to be used as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config will override them. If generation_config is not provided, the default is used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Unspecified parameters inherit GenerationConfig’s default values; consult its documentation when parameterizing generation.
trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.
logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and the generation config. If a logits processor is passed that has already been created from the arguments or a generation config, an error is thrown. This feature is intended for advanced users.
kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with decoder_.
- Returns
A ModelOutput.
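The override rule for **kwargs (keys that match generation config attributes replace them, everything else is forwarded to the model) can be sketched as follows. ToyGenerationConfig and split_generate_kwargs are hypothetical stand-ins used only for illustration; the real GenerationConfig has many more fields.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ToyGenerationConfig:
    """Hypothetical, heavily reduced stand-in for a generation config."""
    max_length: int = 20
    temperature: float = 1.0
    top_p: float = 1.0
    do_sample: bool = False


def split_generate_kwargs(generation_config, **kwargs):
    """Apply matching kwargs as config overrides; return the rest as model kwargs."""
    config_names = {f.name for f in fields(generation_config)}
    overrides = {k: v for k, v in kwargs.items() if k in config_names}
    model_kwargs = {k: v for k, v in kwargs.items() if k not in config_names}
    return replace(generation_config, **overrides), model_kwargs


base = ToyGenerationConfig()
config, model_kwargs = split_generate_kwargs(
    base, max_length=100, temperature=0.8, do_sample=True, attention_mask="mask"
)
```

The base configuration is left untouched, so the same defaults can be reused across calls with different overrides.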
- get_esurge(tokenizer: str | PreTrainedTokenizerBase | None = None, max_model_len: int | None = None, min_input_pad: int | None = None, max_num_seqs: int | None = None, max_num_batched_tokens: int | None = None, hbm_utilization: float | None = None, page_size: int | None = None, enable_prefix_caching: bool | None = None, runner_verbose: bool | None = None, decode_truncated_prompt: bool | None = None, destroy_pages_on_pause: bool | None = None, silent_mode: bool | None = None)[source]#
Gets or creates an eSurge engine with the specified parameters.
This method intelligently retrieves an existing cached engine or creates a new one. For any parameter that is None, it will:
1. Try to retrieve the value from a cached engine
2. If no cached engine exists, use sensible defaults
3. Only require tokenizer if no cached engine is available
- Parameters
tokenizer – Tokenizer path or instance. If None, retrieves from cached engine. Required only if no cached engine exists.
max_model_len – Maximum sequence length. Defaults to model’s max position embeddings.
min_input_pad – Minimum padding for input sequences. Defaults to 16.
max_num_seqs – Maximum number of concurrent sequences. Defaults to 32.
max_num_batched_tokens – Maximum tokens per batch. Defaults to None.
hbm_utilization – Fraction of HBM to use for KV cache. Defaults to 0.85.
page_size – Size of memory pages for paged attention. Defaults to 128.
enable_prefix_caching – Enable prefix caching. Defaults to True.
runner_verbose – Enable verbose logging. Defaults to False.
decode_truncated_prompt – Decode truncated prompts. Defaults to True.
destroy_pages_on_pause – Free memory on pause. Defaults to True.
- Returns
eSurge engine instance, either from cache or newly created.
- Raises
ValueError – If tokenizer is required but not provided and no cached engine exists.
Example
>>> # First call with tokenizer (creates new engine)
>>> engine = model.get_esurge(tokenizer="meta-llama/Llama-2-7b-hf")
>>>
>>> # Subsequent calls without parameters (reuses cached engine)
>>> engine = model.get_esurge()
>>>
>>> # Override specific parameters
>>> engine = model.get_esurge(max_num_seqs=128)
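The resolution order described for get_esurge (explicit argument, then cached engine value, then default) can be sketched with plain dictionaries. The default values below are the ones listed in the parameter descriptions; resolve_engine_params and the dictionary layout are assumptions for illustration, not the actual caching code.

```python
# Defaults as listed in the parameter descriptions above.
DEFAULTS = {
    "min_input_pad": 16,
    "max_num_seqs": 32,
    "max_num_batched_tokens": None,
    "hbm_utilization": 0.85,
    "page_size": 128,
    "enable_prefix_caching": True,
}


def resolve_engine_params(requested, cached=None, defaults=DEFAULTS):
    """Hypothetical resolver: explicit value, then cached engine value, then default."""
    resolved = {}
    for name, default in defaults.items():
        value = requested.get(name)
        if value is None and cached is not None:
            value = cached.get(name)
        if value is None:
            value = default
        resolved[name] = value
    return resolved


first_call = resolve_engine_params({})
cached_engine = {"max_num_seqs": 64, "page_size": 256}
later_call = resolve_engine_params({"max_num_seqs": 128}, cached=cached_engine)
```

In the second call max_num_seqs comes from the explicit argument, page_size from the cached engine, and the remaining values from the defaults.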
- get_relevant_esurge(tokenizer: str | PreTrainedTokenizerBase | None = None, max_num_seqs: int | None = None)[source]#
Retrieves a relevant eSurge engine instance from the cache.
This method searches for an existing eSurge engine in the cache that matches the current model. If tokenizer or max_num_seqs are None, it returns the most recently created engine for this model. If no engine exists and parameters are missing, it uses sensible defaults.
- Parameters
tokenizer – Optional tokenizer path or instance. If None, retrieves from the most recent cached engine for this model.
max_num_seqs – Optional maximum number of concurrent sequences. If None, uses value from cached engine or defaults to 32 if no cache exists.
- Returns
eSurge engine instance if found in cache, None otherwise.
Example
>>> # Try to get existing engine with default params
>>> engine = model.get_relevant_esurge()
>>> if engine:
...     outputs = engine.generate("Hello world")
>>>
>>> # Get engine with specific tokenizer
>>> engine = model.get_relevant_esurge(tokenizer="gpt2")
- init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None) TransformerCache[source]#
Initializes and returns a standard (non-paged) Key-Value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.
- Parameters
batch_size (int) – The batch size for the cache.
max_length (int) – The maximum sequence length the cache needs to support.
starts (int | None) – Optional starting positions for the cache sequences. If provided, influences the initial state. Defaults to None (usually 0).
shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused in the current implementation.
pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
- init_ragged_pages(metadata: easydel.layers.caching.ragged_page.cache.RaggedPagesCacheMetaData | None = None, page_size: int | None = None, hbm_utilization: float | None = None, max_model_length: int | None = None) RaggedPagesCache[source]#
Initializes and returns the actual Paged Attention KV Cache tensors.
This method orchestrates the creation of the RaggedPagesCache. It either uses a pre-existing RaggedPagesCacheMetaData object passed via the metadata argument, or, if metadata is None, it first creates the metadata by calling self.create_paged_metadata with the other provided arguments (page_size, hbm_utilization, max_model_length).
Finally, it calls RaggedPagesCache.init_cache to allocate the necessary paged tensors (key_pages, value_pages for each layer) based on the metadata, model’s mesh, dtype, partition manager, and quantization settings.
- Parameters
metadata (tp.Optional[RaggedPagesCacheMetaData]) – An optional pre-configured metadata object. If provided, the other arguments (page_size, hbm_utilization, max_model_length) are ignored for metadata creation.
page_size (tp.Optional[int]) – Number of tokens per page. Required if metadata is None.
hbm_utilization (tp.Optional[float]) – Target HBM usage. Required if metadata is None.
- Returns
An initialized RaggedPagesCache object containing the allocated cache tensors (views) for all layers.
- Return type
RaggedPagesCache
- Raises
AssertionError – If metadata is None and a required argument (such as page_size or hbm_utilization) is also None.
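Since page_size is the number of tokens per page, the number of pages a single sequence occupies follows from a ceiling division. The sketch below shows only that arithmetic; how eSurge sizes the total page pool from hbm_utilization is not described here, and pages_for_tokens is a hypothetical helper.

```python
def pages_for_tokens(num_tokens, page_size):
    """Number of fixed-size pages needed to hold `num_tokens` tokens."""
    if page_size <= 0:
        raise ValueError("page_size must be positive")
    # Ceiling division without floating point.
    return -(-num_tokens // page_size)


pages_short = pages_for_tokens(128, page_size=128)
pages_long = pages_for_tokens(300, page_size=128)
```

With the documented default page_size of 128, a 300-token sequence spans three pages, the last of which is only partly filled.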
- list_esurge_engines() list[dict][source]#
List all cached eSurge engines for this model.
Returns a list of dictionaries containing information about each cached engine, including its status (running/paused), number of requests, and configuration hash.
- Returns
List of dicts with engine information. Each dict contains:
cache_key: The cache key for this engine
paused: Whether the engine is paused
running_requests: Number of currently running requests
pending_requests: Number of pending requests
max_num_seqs: Maximum concurrent sequences
Example
>>> engines = model.list_esurge_engines()
>>> for engine in engines:
...     print(f"Engine {engine['cache_key']}: "
...           f"Paused={engine['paused']}, "
...           f"Running={engine['running_requests']}")
- pause_esurge(engine_id: str | None = None) None[source]#
Pause eSurge engine(s) for this model.
Pauses the background scheduler of eSurge engines without clearing queued state. This is useful for temporarily freeing resources while keeping the engine ready for quick resumption.
- Parameters
engine_id – Optional specific engine cache key to pause. If None, pauses all engines associated with this model.
Example
>>> # Pause all engines for this model
>>> model.pause_esurge()
>>>
>>> # Later, generate will auto-resume
>>> outputs = model.esurge_generate("prompt")  # Auto-resumes!
- prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings: int | None = None, attention_mask: jax.Array | None = None, token_type_ids: jax.Array | None = None, mask_info: ejkernel.types.mask.MaskInfo | None = None) dict[str, Any][source]#
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
- Returns
A dictionary containing the prepared inputs, typically including:
"past_key_values": The initialized KV cache.
"attention_mask": The extended attention mask for generation.
"position_ids": The calculated initial position IDs.
"token_type_ids": (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
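As a rough illustration of the position_ids and extended attention_mask described for prepare_inputs_for_generation, the NumPy sketch below follows one common convention: positions are the running count of attended tokens, and the mask is extended to max_length so later decoding steps can attend to newly generated tokens. Whether EasyDeL uses exactly this convention is an assumption; initial_positions_and_mask is a hypothetical name.

```python
import numpy as np


def initial_positions_and_mask(attention_mask, max_length):
    """Derive position ids and a cache-length attention mask from a prompt mask."""
    attention_mask = np.asarray(attention_mask, dtype=np.int32)
    batch_size, seq_length = attention_mask.shape
    # Position of each token among the attended tokens; padded slots clamp to 0.
    position_ids = np.clip(np.cumsum(attention_mask, axis=-1) - 1, 0, None)
    # Slots beyond the prompt are marked attendable for future generated tokens.
    extended_mask = np.ones((batch_size, max_length), dtype=np.int32)
    extended_mask[:, :seq_length] = attention_mask
    return position_ids, extended_mask


mask = np.array([[0, 0, 1, 1],   # left-padded prompt
                 [1, 1, 1, 1]])  # full-length prompt
position_ids, extended_mask = initial_positions_and_mask(mask, max_length=6)
```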
- resume_esurge(engine_id: str | None = None) None[source]#
Resume paused eSurge engine(s) for this model.
Resumes the background scheduler of paused eSurge engines, making them ready to process generation requests again.
- Parameters
engine_id – Optional specific engine cache key to resume. If None, resumes all engines associated with this model.
Example
>>> # Pause engines to free resources
>>> model.pause_esurge()
>>>
>>> # Manually resume when needed
>>> model.resume_esurge()
>>> outputs = model.esurge_generate("prompt")
- update_inputs_for_generation(model_outputs, model_kwargs) dict[str, Any][source]#
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary will be modified in-place or a new one returned.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
- Return type
dict
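The per-step update described for update_inputs_for_generation (carry the cache forward and advance the position by one) can be sketched as below. This version returns a new dictionary and keeps only the last position per sequence, which is an assumption about the decoding loop rather than a statement about the EasyDeL source; update_kwargs_for_next_step is a hypothetical name.

```python
from types import SimpleNamespace

import numpy as np


def update_kwargs_for_next_step(model_outputs, model_kwargs):
    """Return kwargs for the next decoding step without mutating the input dict."""
    updated = dict(model_kwargs)
    updated["past_key_values"] = model_outputs.past_key_values
    # Keep the last position of each sequence and advance it by one.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


outputs = SimpleNamespace(past_key_values="cache-after-step")
kwargs = {
    "past_key_values": "cache-before-step",
    "position_ids": np.array([[0, 1, 2], [0, 0, 1]]),
}
next_kwargs = update_kwargs_for_next_step(outputs, kwargs)
```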
- class easydel.infra.mixins.generation.GreedyState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]])[source]#
Bases: object
State container for greedy search generation.
Tracks the current state during greedy decoding, where the most probable token is selected at each step.
- cur_len#
Current length of generated sequences.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- sequences#
Generated token sequences.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- running_token#
Currently processed token.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- is_sent_finished#
Boolean flags for finished sequences.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- model_kwargs#
Additional model-specific arguments.
- Type
dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
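To show how the GreedyState fields interact, here is a small NumPy sketch of a greedy decoding loop. It uses a NamedTuple in place of the PyTree class and a toy scoring function in place of a real model, so every name in it (ToyGreedyState, toy_logits, greedy_decode) is hypothetical.

```python
from typing import NamedTuple

import numpy as np


class ToyGreedyState(NamedTuple):
    """Stand-in for GreedyState, without model_kwargs."""
    cur_len: int
    sequences: np.ndarray
    running_token: np.ndarray
    is_sent_finished: np.ndarray


def toy_logits(tokens, vocab_size=5):
    """Toy 'model' that always prefers (token + 1) % vocab_size."""
    logits = np.zeros((tokens.shape[0], vocab_size))
    logits[np.arange(tokens.shape[0]), (tokens + 1) % vocab_size] = 1.0
    return logits


def greedy_decode(prompt, max_length, eos_id, pad_id):
    batch, prompt_len = prompt.shape
    sequences = np.full((batch, max_length), pad_id, dtype=prompt.dtype)
    sequences[:, :prompt_len] = prompt
    state = ToyGreedyState(prompt_len, sequences, prompt[:, -1],
                           np.zeros(batch, dtype=bool))
    while state.cur_len < max_length and not state.is_sent_finished.all():
        # Greedy choice: the highest-scoring token for each sequence.
        next_token = toy_logits(state.running_token).argmax(axis=-1)
        # Finished sequences only receive padding.
        next_token = np.where(state.is_sent_finished, pad_id, next_token)
        sequences = state.sequences.copy()
        sequences[:, state.cur_len] = next_token
        state = state._replace(
            cur_len=state.cur_len + 1,
            sequences=sequences,
            running_token=next_token,
            is_sent_finished=state.is_sent_finished | (next_token == eos_id),
        )
    return state


final_state = greedy_decode(np.array([[1], [3]]), max_length=5, eos_id=4, pad_id=0)
```

The loop stops early once every sequence has produced the end-of-sequence token, and rows that finish first are filled with padding.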
- class easydel.infra.mixins.generation.SampleState(cur_len: Union[Array, ndarray, bool, number], sequences: Union[Array, ndarray, bool, number], running_token: Union[Array, ndarray, bool, number], is_sent_finished: Union[Array, ndarray, bool, number], prng_key: Union[Array, ndarray, bool, number], model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]])[source]#
Bases: object
State container for sampling-based generation.
Tracks the current state during sampling generation, where tokens are sampled from the probability distribution.
- cur_len#
Current length of generated sequences.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- sequences#
Generated token sequences.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- running_token#
Currently processed token.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- is_sent_finished#
Boolean flags for finished sequences.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- prng_key#
JAX PRNG key for random sampling.
- Type
Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]
- model_kwargs#
Additional model-specific arguments.
- Type
dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- model_kwargs: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
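The difference between SampleState and GreedyState is the random draw, which is why the state carries a prng_key. The sketch below uses a seeded NumPy Generator as a stand-in for a JAX PRNG key, which is an assumption made to keep the example dependency-light; sample_next_token is a hypothetical name.

```python
import numpy as np


def sample_next_token(logits, rng, temperature=1.0):
    """Draw one token per row from the temperature-scaled softmax distribution."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(probs.shape[-1], p=row) for row in probs])


logits = np.array([[2.0, 0.5, 0.1],
                   [0.0, 3.0, 0.0]])

# The same seed reproduces the same draw, just as reusing a PRNG key would.
first = sample_next_token(logits, np.random.default_rng(0), temperature=0.8)
second = sample_next_token(logits, np.random.default_rng(0), temperature=0.8)

# A very low temperature concentrates the distribution on the most likely token.
near_greedy = sample_next_token(logits, np.random.default_rng(1), temperature=1e-3)
```

As the temperature approaches zero, sampling converges to the greedy choice; higher temperatures spread probability across more tokens.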