Source code for easydel.__init__.inference.vinference.utilities

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import typing as tp

import jax
import jax.random
from eformer.escale import PartitionAxis
from eformer.pytree import auto_pytree
from flax import nnx as nn
from jax import numpy as jnp
from jax import random, sharding
from jax.sharding import PartitionSpec

from easydel.utils.compiling_utils import get_safe_hash_int

from ..logits_process import LogitsProcessorList, hash_fn
from ..utilities import SamplingParams

# A single "required properties" entry: string keys mapped to dicts holding
# arbitrary values (what the keys mean is defined by the consumer).
_PropDict = tp.Dict[str, tp.Any]
AmapType = tp.Mapping[str, _PropDict]
# One entry, a list of entries (a list is expanded into one pre-compile
# scenario per element by `vInferencePreCompileConfig.get_standalones`), or None.
PropType = tp.Union[AmapType, tp.List[AmapType], None]


@auto_pytree
class vInferencePreCompileConfig:
    """
    Configuration class for pre-compiling vInference functions.

    This class holds parameters that define the shape and properties of inputs
    expected by the vInference engine during pre-compilation. Any field may be
    given as a list to describe several compilation scenarios at once; use
    `get_standalones` to expand such a config into one config per scenario.

    Attributes:
        batch_size: Batch size or list of batch sizes for text generation.
        prefill_length: Prefill sequence length or list of lengths. If None,
            it might be inferred or not used depending on the context.
        vision_included: Whether vision inputs are included in the model.
        vision_batch_size: Batch size for vision inputs. Only relevant if
            `vision_included` is True.
        vision_channels: Number of channels for vision inputs. Only relevant
            if `vision_included` is True.
        vision_height: Height of vision inputs. Only relevant if
            `vision_included` is True.
        vision_width: Width of vision inputs. Only relevant if
            `vision_included` is True.
        required_props: Optional mapping or list of mappings specifying
            required properties for advanced configuration (e.g., specific
            model arguments).
    """

    batch_size: tp.Union[int, tp.List[int]] = 1
    prefill_length: tp.Optional[tp.Union[int, tp.List[int]]] = None
    vision_included: tp.Union[bool, tp.List[bool]] = False
    vision_batch_size: tp.Optional[tp.Union[int, tp.List[int]]] = None
    vision_channels: tp.Optional[tp.Union[int, tp.List[int]]] = None
    vision_height: tp.Optional[tp.Union[int, tp.List[int]]] = None
    vision_width: tp.Optional[tp.Union[int, tp.List[int]]] = None
    required_props: PropType = None

    def _im_standalone(self) -> bool:
        """
        Check whether this configuration describes a single compilation scenario.

        Returns:
            True if no field holds a `list` (one scenario), False otherwise.
        """
        return not any(
            isinstance(getattr(self, rfield.name), list)
            for rfield in dataclasses.fields(self)
        )

    # Public-ish alias kept for backward compatibility with existing callers.
    _is_standalone = _im_standalone

    def extract(self) -> dict:
        """
        Convert the configuration into a plain dictionary.

        Returns:
            The result of `dataclasses.asdict(self)` (a recursive copy of all
            field values keyed by field name).
        """
        return dataclasses.asdict(self)

    def get_default_hash(self) -> int:
        """
        Generate an integer hash representing the configuration.

        The hash is computed from the string representations of all fields
        joined by "-", so configurations whose fields stringify identically
        produce the same hash. This is used as a cache key for compiled
        functions.

        Returns:
            The integer produced by `get_safe_hash_int` for that string.
        """
        hash_str = "-".join(
            str(getattr(self, name))
            for name in (
                "batch_size",
                "prefill_length",
                "vision_included",
                "vision_batch_size",
                "vision_channels",
                "vision_height",
                "vision_width",
                "required_props",
            )
        )
        return get_safe_hash_int(hash_str)

    __hash__ = get_default_hash

    def get_standalones(self) -> tp.List["vInferencePreCompileConfig"]:
        """
        Expand a possibly multi-valued config into single-scenario configs.

        Every field holding a list contributes one value per generated config.
        When a list is shorter than the longest list, its last element is
        repeated for the remaining configs. Non-list fields are copied into
        every generated config unchanged.

        Returns:
            `[self]` if the config is already standalone; otherwise a list of
            new instances of `type(self)`, one per index of the longest list.
            If every list-valued field is empty, the result is empty.

        Raises:
            ValueError: If some list-valued fields are empty while others are
                not, since the empty fields have no value to repeat.
        """
        if self._is_standalone():
            return [self]

        all_fields = dataclasses.fields(self)
        list_fields = {
            rfield.name: getattr(self, rfield.name)
            for rfield in all_fields
            if isinstance(getattr(self, rfield.name), list)
        }
        max_length = max(len(values) for values in list_fields.values())

        if max_length > 0:
            empty_fields = [name for name, values in list_fields.items() if not values]
            if empty_fields:
                # Previously this surfaced as an opaque IndexError on `list_attr[-1]`.
                raise ValueError(
                    f"Cannot expand pre-compile config: field(s) {empty_fields} are "
                    f"empty lists while other fields define {max_length} scenario(s)."
                )

        config_cls = type(self)  # preserve subclasses instead of hard-coding the base
        standalone_configs = []
        for index in range(max_length):
            config_kwargs = {}
            for rfield in all_fields:
                value = getattr(self, rfield.name)
                if rfield.name in list_fields:
                    # Use the value at `index` if present, else repeat the last value.
                    value = value[min(index, len(value) - 1)]
                config_kwargs[rfield.name] = value
            standalone_configs.append(config_cls(**config_kwargs))
        return standalone_configs
# Re-attach the custom hash after decoration, in case the decorator (a dataclass
# wrapper) replaced the `__hash__` defined in the class body.
vInferencePreCompileConfig.__hash__ = vInferencePreCompileConfig.get_default_hash


@auto_pytree
class SampleState:
    """
    State of the sampling loop inside the vInference engine.

    Carries everything needed to pause and resume generation: the tokens
    generated so far, the current position, per-sequence completion flags, the
    PRNG key, and model-specific inputs (e.g. attention caches).

    Attributes:
        current_length: Current write position into `sequences`.
        sequences: Generated token IDs, `(batch_size, max_sequence_length)`.
        running_token: Most recently generated token per sequence, used as the
            next model input; `(batch_size, 1)`.
        is_sequence_finished: Boolean flags marking sequences that produced an
            EOS token; `(batch_size,)`.
        prng_key: JAX PRNG key for stochastic sampling.
        model_kwargs: Extra keyword arguments passed to the model each step
            (e.g. past key-values). Structure depends on the model.
        generate_func_flops: Estimated FLOPs of the main generation function.
            Defaults to -inf.
        interval_func_flops: Estimated FLOPs of the interval function, if any.
            Defaults to -inf.
        tokens_per_second: Estimated generation speed. Defaults to -inf.
        generated_tokens: Running count of generated tokens across all
            sequences. Defaults to 0.
        padded_length: Target padded length. Defaults to 0.
        _time_spent_computing: Cumulative compute time. Defaults to 0.0.
        _compile_config: The `vInferencePreCompileConfig` used to pre-compile
            the functions for this state. Defaults to None.
    """

    current_length: tp.Union[jax.Array, sharding.NamedSharding]
    sequences: tp.Union[jax.Array, sharding.NamedSharding]
    running_token: tp.Union[jax.Array, sharding.NamedSharding]
    is_sequence_finished: tp.Union[jax.Array, sharding.NamedSharding]
    prng_key: tp.Union[random.PRNGKey, sharding.NamedSharding]
    model_kwargs: tp.Union[tp.Dict[str, jax.Array], sharding.NamedSharding]

    # vInference Ops
    generate_func_flops: tp.Optional[float] = float("-inf")
    interval_func_flops: tp.Optional[float] = float("-inf")
    tokens_per_second: tp.Optional[float] = float("-inf")
    generated_tokens: tp.Optional[int] = 0
    padded_length: tp.Optional[int] = 0
    _time_spent_computing: tp.Optional[float] = 0.0
    _compile_config: tp.Optional[vInferencePreCompileConfig] = None


def create_sampling_step(
    logits_processor: tp.Optional[LogitsProcessorList],
    logits_warper: tp.Optional[LogitsProcessorList],
    eos_token_id: jax.Array,
    pad_token_id: jax.Array,
) -> tp.Callable:
    """
    Build a function that performs one token-generation (sampling) step.

    Args:
        logits_processor: Optional `LogitsProcessorList` applied to the
            last-position logits as `logits_processor(sequences, logits,
            current_length)`. Skipped when None.
        logits_warper: Optional `LogitsProcessorList` applied after the
            processor with the same `(sequences, logits, current_length)`
            call signature. Skipped when None.
        eos_token_id: Token ID(s) marking end-of-sequence. A sequence is
            flagged finished once it samples one of these.
        pad_token_id: Token ID emitted for sequences that are already finished.

    Returns:
        `sampling_step(graphdef, graphstate, graphother, state)`. It merges
        the NNX graph pieces into a model, runs one forward step, samples one
        token per sequence, and returns the updated `SampleState`.
    """

    def sampling_step(graphdef, graphstate, graphother, state: SampleState):
        """
        Run a single sampling step.

        Args:
            graphdef: Definition part of the NNX graph model.
            graphstate: State part of the NNX graph model.
            graphother: Remaining components of the NNX graph model.
            state: Current generation state.

        Returns:
            The updated `SampleState`.
        """
        model = nn.merge(graphdef, graphstate, graphother)
        model_outputs = model(
            input_ids=state.running_token,
            return_dict=True,
            **state.model_kwargs,
        )
        logits = model_outputs.logits[:, -1]

        if logits_processor is not None:
            logits = logits_processor(state.sequences, logits, state.current_length)
        if logits_warper is not None:
            # Warpers use the same (input_ids, scores, cur_len) signature as processors;
            # the generated sequences, not the logits, go in the first slot.
            logits = logits_warper(state.sequences, logits, state.current_length)

        # Split *before* sampling so the key consumed by `categorical` is never
        # reused to derive the key carried into the next step.
        sample_key, next_prng_key = jax.random.split(state.prng_key)
        next_token = jax.random.categorical(sample_key, logits, axis=-1)

        # Finished sequences keep emitting the pad token.
        next_token = jnp.where(state.is_sequence_finished, pad_token_id, next_token)

        # A sequence is finished if it already was, or if it just sampled EOS.
        next_sequence_finished = state.is_sequence_finished | jnp.isin(
            next_token,
            eos_token_id,
        )

        next_token = next_token[:, None]  # (batch,) -> (batch, 1) for the slice update
        next_sequences = jax.lax.dynamic_update_slice(
            state.sequences,
            next_token,
            (0, state.current_length),
        )

        # Let the model advance its own inputs (e.g. past key-values).
        next_model_kwargs = model.update_inputs_for_generation(
            model_outputs, state.model_kwargs
        )

        return state.replace(
            current_length=state.current_length + 1,
            sequences=next_sequences,
            running_token=next_token,
            is_sequence_finished=next_sequence_finished,
            prng_key=next_prng_key,
            model_kwargs=next_model_kwargs,
            # Counts one token per batch row, including already-finished rows.
            generated_tokens=state.generated_tokens + state.sequences.shape[0],
        )

    return sampling_step
@auto_pytree
class vInferenceConfig:
    """
    Configuration for the vInference engine's generation loop.

    Holds length limits, special token IDs, sharding rules, and sampling settings.

    Attributes:
        max_new_tokens: Maximum number of new tokens to generate, excluding the
            prompt. Defaults to 64.
        streaming_chunks: Number of generation steps executed together as one
            unit per loop iteration. Defaults to 16.
        num_return_sequences: Either an int applied to all inputs, or a dict
            mapping a precompile hash (from `vInferencePreCompileConfig`) to an
            int. Defaults to 1.
        pad_token_id: Token ID used for padding, or None.
        bos_token_id: Beginning-of-sequence token ID, or None.
        eos_token_id: End-of-sequence token ID(s), or None.
        partition_rules: Custom `(regex, PartitionSpec)` sharding rules. When
            None, default rules are built from `partition_axis`.
        partition_axis: `PartitionAxis` naming the sharding axes. Required when
            `partition_rules` is None.
        _loop_rows: (Internal) Number of loop iterations, i.e.
            ceil(max_new_tokens / streaming_chunks). Computed in
            `__post_init__` when `max_new_tokens` is an int.
        sampling_params: `SamplingParams` for the sampler. When None, one is
            created in `__post_init__` with `max_tokens=max_new_tokens`.
    """

    max_new_tokens: int = 64
    streaming_chunks: int = 16
    num_return_sequences: tp.Optional[tp.Union[int, tp.Dict[int, int]]] = 1
    pad_token_id: tp.Optional[int] = None
    bos_token_id: tp.Optional[int] = None
    eos_token_id: tp.Optional[tp.Union[int, tp.List[int]]] = None
    partition_rules: tp.Optional[tp.Tuple[tp.Tuple[str, tp.Any]]] = None
    partition_axis: tp.Optional[PartitionAxis] = None
    _loop_rows: tp.Optional[int] = None
    sampling_params: tp.Optional[SamplingParams] = None

    def get_partition_rules(
        self,
        # in case that someone needs to customize this
        runtime_config: tp.Optional[vInferencePreCompileConfig] = None,
    ) -> tp.Tuple[tp.Tuple[str, tp.Any], ...]:
        """
        Return the sharding partition rules for generation state.

        If `self.partition_rules` is set, it is returned unchanged. Otherwise
        default rules are built from `self.partition_axis`:
          - `sequences` / `running_token`, and `attention_mask` /
            `position_ids` under `model_kwargs`: (batch, sequence) axes.
          - Past key/value cache entries (plain, `scale`/`weight`, and
            `packed`/`absmax` leaves): (batch, key_sequence, head,
            attention_dim) axes.
          - Everything else (`.*`): replicated (`PartitionSpec()`).

        Args:
            runtime_config: Currently unused; available for customization.

        Returns:
            A tuple of `(regex_pattern, PartitionSpec)` rules, in priority order.

        Raises:
            AssertionError: If both `partition_rules` and `partition_axis` are None.
        """
        if self.partition_rules is not None:
            return self.partition_rules
        if self.partition_axis is None:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # AssertionError is kept for compatibility with existing callers.
            raise AssertionError(
                "partition axis is required for state sharding if partition_rules is not provided"
            )
        paxis = self.partition_axis

        # Key/Value ProjectionSharding Spec
        kvps = PartitionSpec(
            paxis.batch_axis,
            paxis.key_sequence_axis,
            paxis.head_axis,
            paxis.attention_dim_axis,
        )
        # Input ID ProjectionSharding Spec
        idps = PartitionSpec(paxis.batch_axis, paxis.sequence_axis)

        # Default Rules
        return (
            # Input sequences and running token
            ("(sequences|running_token)", idps),
            # Attention mask and position IDs
            ("model_kwargs/(attention_mask|position_ids)", idps),
            # Past Key/Value states (8-bit quantization)
            ("model_kwargs/past_key_values/views/[0-9]+/(key|value)/(scale|weight)", kvps),
            # Past Key/Value states (NF4 quantization)
            ("model_kwargs/past_key_values/views/[0-9]+/(key|value)/(packed|absmax)", kvps),
            # Past Key/Value states (standard float types)
            ("model_kwargs/past_key_values/views/[0-9]+/(key|value)", kvps),
            # Default rule for anything else (replicated)
            (".*", PartitionSpec()),
        )

    def __post_init__(self):
        """
        Fill in derived fields after dataclass initialization.

        1. When `max_new_tokens` is an int, sets `_loop_rows` to
           ceil(max_new_tokens / streaming_chunks).
        2. When `sampling_params` is None, creates a default `SamplingParams`
           with `max_tokens=self.max_new_tokens`.
        """
        if isinstance(self.max_new_tokens, int):
            # Ceiling division: number of streaming loops needed
            self._loop_rows = (
                self.max_new_tokens + self.streaming_chunks - 1
            ) // self.streaming_chunks
        if self.sampling_params is None:
            # Initialize default sampling parameters if none provided
            self.sampling_params = SamplingParams(max_tokens=self.max_new_tokens)

    __hash__ = hash_fn