Source code for easydel.__init__.inference.vsurge.engines.vengine.engine

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp
from functools import cached_property

import jax
from eformer import common_types
from eformer import escale as es
from flax import nnx as nn
from jax import numpy as jnp
from jax.sharding import NamedSharding as Ns
from jax.sharding import PartitionSpec as Ps
from transformers import ProcessorMixin

from easydel.utils.compiling_utils import cjit

from .._abstract_engine import AbstractInferenceEngine
from .utilities import (
	GenerationState,
	ResultTokens,
	continuous_bulk_insert,
	continuous_decode,
	continuous_insert,
	continuous_prefill,
)

if tp.TYPE_CHECKING:
	from easydel.infra import EasyDeLBaseModule
	from easydel.infra.utils import ProcessingClassType
else:
	EasyDeLBaseModule = tp.Any
	ProcessingClassType = tp.Any


NOT_GIVEN = common_types.NOT_GIVEN
RUNTIME_MODE_TYPES = common_types.RUNTIME_MODE_TYPES
BATCH = common_types.BATCH
QUERY_LENGTH = common_types.QUERY_LENGTH
KV_LENGTH = common_types.KV_LENGTH
HEAD = common_types.HEAD
KV_HEAD = common_types.KV_HEAD
HEAD_DIM = common_types.HEAD_DIM
KV_HEAD_DIM = common_types.KV_HEAD_DIM
BIAS_HEAD_SEQ = common_types.BIAS_HEAD_SEQ
BIAS_KV_SEQ = common_types.BIAS_KV_SEQ
MODE_PREFILL = common_types.MODE_PREFILL


[docs]class vEngine(AbstractInferenceEngine): """ Core inference engine for EasyDeL models using NNX graphs. This engine manages the model state (split into graph definition, state, and other parameters) and provides JIT-compiled functions for the prefill and decode steps of autoregressive generation. It handles KV caching and sampling. """ def __init__( self, model: EasyDeLBaseModule, processor: ProcessingClassType, max_concurrent_decodes: int | None = None, max_concurrent_prefill: int | None = None, prefill_lengths: int | None = None, max_prefill_length: int | None = None, max_length: int | None = None, batch_size: int | None = None, seed: int = 894, ): """Initializes the vEngine. Args: model: The EasyDeLBaseModule (NNX) to use for inference. processor: The tokenizer/processor associated with the model. max_concurrent_decodes: The maximum number of sequences that can be decoded concurrently. Defaults to the number of available JAX devices. max_concurrent_prefill: The maximum number of prefill requests that can be processed concurrently. Defaults to the number of available JAX devices. prefill_lengths: A list of integer bucket lengths to choose from for prefill. Defaults to `DEFAULT_PREFILL_BUCKETS`. max_prefill_length: The maximum allowed length for the initial prompt (prefill phase). Defaults to 4096. max_length: The maximum total sequence length (prompt + generation). Defaults to 8192. batch_size: The batch size for the engine. Defaults to None. seed: The random seed for initializing the PRNG key used in sampling. Defaults to 894. """ from ...utils import DEFAULT_PREFILL_BUCKETS self.model = model self.graphdef, self.graphstate, self.graphothers = model.split_module() self.processor = processor self._quantizer = model._quant_class( quantization_method=model.config.kv_cache_quantization_method, block_size=model.config.kv_cache_quantization_blocksize, quantization_platform=model.config.platform, ) self._partition_manager = model.config.partition_manager self._max_prefill_lengths = prefill_lengths or DEFAULT_PREFILL_BUCKETS self._max_concurrent_decodes = max_concurrent_decodes or jax.device_count() self._max_concurrent_prefill = max_concurrent_prefill or jax.device_count() self._max_prefill_length = max_prefill_length or 4096 self._max_length = max_length or 8192 self._batch_size = batch_size self._prng_key = jax.random.PRNGKey(seed) self._empty_sharding = Ns(model.mesh, Ps()) self._setup_functions()
[docs] def get_state_shardings(self, is_decode: bool = False): pmang = self.model.config.partition_manager kvps = pmang.resolve( axes=[BATCH if is_decode else "_", KV_LENGTH, KV_HEAD, KV_HEAD_DIM], mode=MODE_PREFILL, ) bsharding = pmang.resolve(axes=[BATCH if is_decode else "_"], mode=MODE_PREFILL) return ( ("index", bsharding), ("logits", bsharding), ("tokens", bsharding), ("valids", bsharding), ("next_position_ids", bsharding), ("temperature", bsharding), ("top_p", bsharding), ("generated_tokens", bsharding), ("cache/views/[0-9]+/(key|value)/(scale|weight)", kvps), ("cache/views/[0-9]+/(key|value)/(packed|absmax)", kvps), ("cache/views/[0-9]+/(key|value)", kvps), ("cache/views/[0-9]+/index", bsharding), ("cache/views/[0-9]+/starts", bsharding), (".*", Ps()), )
def _setup_functions(self): with self.model.mesh: self._prefill_state_sharding = jax.tree_util.tree_map( lambda x: Ns(self.model.mesh, x), es.match_partition_rules( self.get_state_shardings(False), jax.eval_shape(self.init_decode_state), min_size=0, ), ) self._decodes_state_sharding = jax.tree_util.tree_map( lambda x: Ns(self.model.mesh, x), es.match_partition_rules( self.get_state_shardings(True), jax.eval_shape(self.init_decode_state), min_size=0, ), ) self.continuous_prefill = cjit( jax.jit( continuous_prefill, static_argnums=(0, 9), in_shardings=( es.extract_shardings(self.graphstate), es.extract_shardings(self.graphothers), self._empty_sharding, self._empty_sharding, None, None, self._empty_sharding, self._empty_sharding, None, None, ), out_shardings=(self._prefill_state_sharding, None), ), static_argnums=(0, 9), ) self.continuous_decode = cjit( jax.jit( continuous_decode, static_argnums=(0,), donate_argnums=(3,), in_shardings=( es.extract_shardings(self.graphstate), es.extract_shardings(self.graphothers), self._decodes_state_sharding, None, None, ), out_shardings=(self._decodes_state_sharding, None), ), static_argnums=(0,), ) self.continuous_insert = cjit( jax.jit( continuous_insert, donate_argnums=(0, 1), static_argnums=(3, 4), in_shardings=( self._prefill_state_sharding, self._decodes_state_sharding, None, ), out_shardings=self._decodes_state_sharding, ), static_argnums=(3, 4), ) @cached_property def pad_token_id(self): if isinstance(self.processor, ProcessorMixin): pad_token_id = self.processor.tokenizer.pad_token_id else: pad_token_id = self.processor.pad_token_id return pad_token_id @cached_property def eos_token_ids(self) -> list[int]: eos_ids = [] if isinstance(self.processor, ProcessorMixin): proc_eos_token_id = self.processor.tokenizer.eos_token_id else: proc_eos_token_id = self.processor.eos_token_id if isinstance(proc_eos_token_id, int): proc_eos_token_id = [proc_eos_token_id] eos_ids = eos_ids + proc_eos_token_id if hasattr(self.model, "generation_config"): conf_eos = self.model.generation_config.eos_token_id if isinstance(conf_eos, int): conf_eos = [conf_eos] eos_ids = eos_ids + conf_eos return list(set(eos_ids)) @property def samples_per_slot(self) -> int: """Number of samples generated per inference slot. This determines how many independent generation results are produced for each logical slot managed by the engine. It's often 1, but could be higher for techniques like parallel sampling. """ return 1 # self._max_concurrent_decodes // self._max_concurrent_prefill @property def prng_key(self) -> jax.random.PRNGKey: """Provides a new PRNG key split from the internal state for sampling. Each call to this property consumes the current key and returns a new, unique key, ensuring that subsequent sampling operations use different randomness. Returns: A new JAX PRNGKey. """ self._prng_key, new_key = jax.random.split(self._prng_key, 2) return new_key @property def prefill_lengths(self) -> list[int]: """Returns the configured list of max prefill length buckets for the engine.""" return self._max_prefill_lengths @property def batch_size(self) -> int | None: """Returns the configured batch size for the engine, if specified.""" return self._batch_size @property def max_concurrent_decodes(self) -> int: """Maximum number of sequences that can be decoded concurrently. This determines the batch size used during the decode phase. """ return self._max_concurrent_decodes @property def max_prefill_length(self) -> int: """Maximum allowed length for the initial prompt (prefill phase). Prompts longer than this will be truncated or handled according to the padding/truncation logic. """ return self._max_prefill_length @property def max_length(self) -> int: """Maximum total sequence length (prompt + generation). This defines the size of the KV cache allocated. """ return self._max_length
[docs] def get_prefix_destination_sharding(self) -> tp.Any: """Returns the shardings necessary to transfer KV cache data between engines. Currently returns None, indicating default or no specific sharding. """ return self._prefill_state_sharding
[docs] def init_decode_state(self, *args, **kwargs) -> GenerationState: """Initializes the GenerationState for a new sequence""" concurrent_decode = self.max_concurrent_decodes pmang = self.model.config.partition_manager sharding = Ns(self.model.mesh, pmang.resolve(axes=[BATCH], mode=MODE_PREFILL)) vocab_size = getattr(self.model.config, "vocab_size", None) if vocab_size is None and hasattr(self.model.config, "text_config"): vocab_size = getattr(self.model.config.text_config, "vocab_size", None) assert vocab_size is not None, "couldn't find `vocab_size` in model config" dt = self.model.dtype with self.model.mesh: return GenerationState( logits=jnp.zeros((concurrent_decode, vocab_size), dt, device=sharding), cache=self.model.init_cache(concurrent_decode, self.max_length), index=jnp.zeros((concurrent_decode, 1), "i4", device=sharding), tokens=jnp.zeros((concurrent_decode, 1), "i4", device=sharding), valids=jnp.zeros((concurrent_decode, self.max_length), "b1", device=sharding), next_position_ids=jnp.zeros((concurrent_decode, 1), "i4", device=sharding), temperature=jnp.zeros((concurrent_decode,), "f4", device=sharding), top_p=jnp.zeros((concurrent_decode,), "f4", device=sharding), generated_tokens=jnp.zeros((concurrent_decode, 1), "i4", device=sharding), )
[docs] def free_resource(self, slot: int) -> bool: """Frees resources associated with a specific inference slot. (Not Implemented) Args: slot: The index of the slot to free. Returns: Always returns False as it's not implemented. """ return False # Placeholder: Implementation needed
@property def colocated_cpus(self) -> tp.Union[list[jax.Device], None]: """Returns CPU devices colocated with the engine's accelerator devices. This information can be useful for optimizing data transfers between host (CPU) and accelerator (GPU/TPU) memory. Currently returns None as the implementation is pending. Returns: A list of colocated JAX CPU devices, or None if not implemented or available. """ return None # Placeholder: Implementation needed
[docs] def prefill( self, graphstate: nn.GraphState, graphothers: nn.GraphState, tokens: jax.Array, valids: jax.Array, true_length: int, temperature: jax.Array, top_p: jax.Array, rngs: jax.random.PRNGKey, ) -> tuple[GenerationState, ResultTokens]: """Performs the prefill step for initializing the generation process. Processes the initial prompt tokens, initializes the KV cache, and generates the *first* token of the sequence. This function is JIT-compiled. Args: graphstate: The NNX GraphState (parameters) of the model. graphothers: Other NNX state variables of the model. tokens: The input prompt token IDs (batch_size, sequence_length). valids: A boolean array indicating valid token positions in the input (batch_size, sequence_length or batch_size, max_length). rngs: A JAX PRNG key for sampling the first token. Returns: A tuple containing: - generation_state: The initial GenerationState for the decode loop. - result: A ResultTokens object containing the *first* generated token. """ return self.continuous_prefill( self.graphdef, graphstate, graphothers, tokens, valids, true_length, self.pad_token_id, temperature, top_p, self.max_length, self.samples_per_slot, rngs, )
[docs] def decode( self, graphstate: nn.GraphState, graphothers: nn.GraphState, state: GenerationState, rngs: jax.random.PRNGKey, ) -> tuple[GenerationState, ResultTokens]: """Performs a single decode step in the autoregressive generation loop. Takes the previous GenerationState, generates the next token using the model and KV cache, and updates the state. This function is JIT-compiled and allows the input state's cache to be modified in-place (donated). Args: graphstate: The NNX GraphState (parameters) of the model. graphothers: Other NNX state variables of the model. state: The current GenerationState from the previous step. `state.cache` is marked for donation. rngs: A JAX PRNG key for sampling the next token. Returns: A tuple containing: - next_generation_state: The updated GenerationState for the next iteration. - result: A ResultTokens object containing the newly generated token. """ return self.continuous_decode( self.graphdef, graphstate, graphothers, state, self.samples_per_slot, rngs, )
[docs] def insert( self, prefix: GenerationState, decode_state: GenerationState, slot: int, ) -> GenerationState: """Inserts or updates a generation state, potentially for managing batches. (JIT-compiled) This function seems designed to merge or update parts of the generation state, possibly inserting a 'prefix' state (e.g., from a completed prefill) into a larger batch state ('decode_state') at a specific 'slot'. The exact mechanism for insertion isn't fully clear from the current implementation, as it primarily focuses on broadcasting the prefix cache and returning the prefix state. Both input states' caches are donated. Args: prefix: The GenerationState to potentially insert (e.g., from prefill). Its cache is marked for donation. decode_state: The target GenerationState to update (e.g., the main decode loop state). Its cache is marked for donation. slot: The index within the batch where the insertion/update should occur. Returns: An updated GenerationState. In the current implementation, it returns the prefix state with its cache potentially broadcasted. Needs clarification on the intended merging logic with `decode_state` and `slot`. """ with self.model.mesh: return self.continuous_insert( prefix, decode_state, slot, self._quantizer, self._partition_manager, )
[docs] def bulk_insert( self, prefix: GenerationState, decode_state: GenerationState, slots: list[int], ) -> GenerationState: """Efficiently inserts multiple prefill results into the decode state. This function takes a `GenerationState` (`prefix`) typically resulting from a batch prefill operation and inserts its relevant components (logits, cache, index, tokens, valids, position IDs, generated tokens) into the main `decode_state` at multiple specified `slots`. This is useful for initializing the decode state after processing a batch of prompts simultaneously. Both input states' caches are donated. Args: prefix: The `GenerationState` containing the results from a prefill operation (or similar initialization). Its cache is marked for donation. decode_state: The target `GenerationState` (e.g., the main decode loop state) to be updated. Its cache is marked for donation. slots: A list of integer indices indicating the slots within the `decode_state`'s batch dimension where the corresponding data from the `prefix` state should be inserted. Returns: An updated `GenerationState` (`decode_state`) with the prefill results inserted at the specified slots. """ with self.model.mesh: return continuous_bulk_insert( prefix=prefix, decode_state=decode_state, slots=slots, quantizer=self._quantizer, partition_manager=self._partition_manager, )