Source code for easydel.__init__.inference.vsurge.engines.oengine.engine

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp
from functools import cached_property

import jax
from eformer import common_types
from eformer import escale as es
from flax import nnx as nn
from jax import numpy as jnp
from jax.sharding import NamedSharding as Ns
from jax.sharding import PartitionSpec as Ps
from transformers import ProcessorMixin

from easydel.layers.caching.paged_attention import (
	ActiveSequenceBatch,
	HBMPageManager,
	InferenceScheduler,
	ModelIOProcessor,
	ModelOutputBatch,
	PagedAttentionCache,
)
from easydel.layers.caching.paged_attention.types import NextIterationPlan
from easydel.utils.compiling_utils import cjit

from .._abstract_engine import AbstractInferenceEngine
from .utilities import execute_forward

if tp.TYPE_CHECKING:
	from easydel.infra import EasyDeLBaseModule
	from easydel.infra.utils import ProcessingClassType
else:
	EasyDeLBaseModule = tp.Any
	ProcessingClassType = tp.Any

GenerationState = tp.Any
ResultTokens = tp.Any

NOT_GIVEN = common_types.NOT_GIVEN
RUNTIME_MODE_TYPES = common_types.RUNTIME_MODE_TYPES
BATCH = common_types.BATCH
QUERY_LENGTH = common_types.QUERY_LENGTH
KV_LENGTH = common_types.KV_LENGTH
HEAD = common_types.HEAD
KV_HEAD = common_types.KV_HEAD
HEAD_DIM = common_types.HEAD_DIM
KV_HEAD_DIM = common_types.KV_HEAD_DIM
BIAS_HEAD_SEQ = common_types.BIAS_HEAD_SEQ
BIAS_KV_SEQ = common_types.BIAS_KV_SEQ
MODE_PREFILL = common_types.MODE_PREFILL


[docs]class oEngine(AbstractInferenceEngine): """ Optimized inference engine for EasyDeL models using Paged Attention. This engine manages the inference process, including KV cache management with Paged Attention, scheduling, and execution of model forward passes for both prefill and decode steps. """ def __init__( self, model: EasyDeLBaseModule, processor: ProcessingClassType, storage: PagedAttentionCache, manager: HBMPageManager, max_concurrent_decodes: int | None = None, max_concurrent_prefill: int | None = None, prefill_lengths: int | None = None, max_prefill_length: int | None = None, max_length: int | None = None, batch_size: int | None = None, seed: int = 894, ): """ Initializes the oEngine. Args: model: The EasyDeL model to use for inference. processor: The processor (tokenizer/feature extractor) for the model. storage: The PagedAttentionCache instance for KV cache management. manager: The HBMPageManager instance for managing memory pages. max_concurrent_decodes: Maximum number of sequences to decode concurrently. max_concurrent_prefill: Maximum number of sequences to prefill concurrently. prefill_lengths: List of maximum prefill lengths to bucket by. max_prefill_length: Maximum allowed length for the initial prompt. max_length: Maximum total sequence length (prompt + generation). batch_size: The batch size for inference. seed: The random seed for PRNG key initialization. """ self.model = model self.storage = storage self.manager = manager self.scheduler = InferenceScheduler(manager=manager) self.metadata = manager.metadata # Split the model into graph definition, graph state, and other components self.graphdef, self.graphstate, self.graphothers = model.split_module() self.processor = processor # Initialize the quantizer based on model configuration self._quantizer = model._quant_class( quantization_method=model.config.kv_cache_quantization_method, block_size=model.config.kv_cache_quantization_blocksize, quantization_platform=model.config.platform, ) self._partition_manager = model.config.partition_manager self._max_prefill_lengths = prefill_lengths or [2**s for s in range(9, 24)] self._max_concurrent_decodes = max_concurrent_decodes or jax.device_count() self._max_concurrent_prefill = max_concurrent_prefill or jax.device_count() self._max_prefill_length = max_prefill_length or 4096 self._max_length = max_length or 8192 self._batch_size = batch_size self._prng_key = jax.random.PRNGKey(seed) self._empty_sharding = Ns(model.mesh, Ps()) self._setup_functions() @property def max_concurrent_prefill(self): return self._max_concurrent_prefill @property def max_concurrent_decodes(self): return self._max_concurrent_decodes
[docs] def get_state_shardings(self, is_decode: bool = False): """ Returns the sharding specifications for the engine's state. Args: is_decode: A boolean indicating if the sharding is for the decode state. Returns: A tuple representing the sharding specification. """ return (".*", Ps())
def _setup_functions(self): """ Sets up JIT-compiled functions for the inference engine. This method defines the sharding for model components and the KV cache and compiles the `execute_forward` function using JAX's `jit` and EasyDeL's `cjit` for efficient execution within the model's mesh. """ with self.model.mesh: # Extract sharding specifications for model state, other components, and storage self.graphstate_sharding = es.extract_shardings(self.graphstate) self.graphother_sharding = es.extract_shardings(self.graphothers) self.storage_sharding = es.extract_shardings(self.storage) # Compile the execute_forward function with specified shardings self.continuous_forward = cjit( jax.jit( execute_forward, static_argnums=(0,), # graphdef is static in_shardings=( self.graphstate_sharding, # Sharding for graphstate self.graphother_sharding, # Sharding for graphothers None, # Sharding for ModelIOProcessor input (handled internally) None, # Sharding for eos_token_ids (broadcast) self.storage_sharding, # Sharding for storage (KV cache) None, # Sharding for prng_key (handled internally) ), out_shardings=( None, self.storage_sharding, None, ), # Output shardings for (output, updated_storage, _) donate_argnums=(5,), # Donate the storage argument to avoid copying ), static_argnums=(0,), # graphdef is static for cjit as well ) @cached_property def pad_token_id(self): """ Returns the pad token ID from the processor. """ if isinstance(self.processor, ProcessorMixin): pad_token_id = self.processor.tokenizer.pad_token_id else: pad_token_id = self.processor.pad_token_id return pad_token_id @cached_property def eos_token_ids(self) -> list[int]: """ Returns a list of end-of-sequence token IDs from the processor and model config. """ eos_ids = [] if isinstance(self.processor, ProcessorMixin): proc_eos_token_id = self.processor.tokenizer.eos_token_id else: proc_eos_token_id = self.processor.eos_token_id if isinstance(proc_eos_token_id, int): proc_eos_token_id = [proc_eos_token_id] eos_ids = eos_ids + proc_eos_token_id if hasattr(self.model, "generation_config"): conf_eos = self.model.generation_config.eos_token_id if isinstance(conf_eos, int): conf_eos = [conf_eos] eos_ids = eos_ids + conf_eos return list(set(eos_ids)) @property def samples_per_slot(self) -> int: """Number of samples generated per inference slot. This determines how many independent generation results are produced for each logical slot managed by the engine. It's often 1, but could be higher for techniques like parallel sampling. """ return 1 @property def prng_key(self) -> jax.random.PRNGKey: """Provides a new PRNG key split from the internal state for sampling. Each call to this property consumes the current key and returns a new, unique key, ensuring that subsequent sampling operations use different randomness. Returns: A new JAX PRNGKey. """ self._prng_key, new_key = jax.random.split(self._prng_key, 2) return new_key @property def prefill_lengths(self) -> list[int]: """Returns the configured list of max prefill length buckets for the engine.""" return self._max_prefill_lengths @property def batch_size(self) -> int | None: """Returns the configured batch size for the engine, if specified.""" return self._batch_size @property def max_prefill_length(self) -> int: """Maximum allowed length for the initial prompt (prefill phase). Prompts longer than this will be truncated or handled according to the padding/truncation logic. """ return self._max_prefill_length @property def max_length(self) -> int: """Maximum total sequence length (prompt + generation). This defines the size of the KV cache allocated. """ return self._max_length
[docs] def get_prefix_destination_sharding(self) -> tp.Any: """ Returns the shardings necessary to transfer KV cache data between engines. This method is intended for scenarios involving multiple engines or devices where KV cache data needs to be moved. Returns: The sharding specification for prefix destinations. """ pass # Implementation needed
[docs] def init_decode_state(self, *args, **kwargs) -> ActiveSequenceBatch: """ Initializes the decode state for active sequences. Args: *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. Returns: An initialized ActiveSequenceBatch instance. """ return ActiveSequenceBatch.create(self.metadata, self.model.config.mesh)
[docs] def free_resource(self, slot: int) -> bool: """Frees resources associated with a specific inference slot. (Not Implemented) Args: slot: The index of the slot to free. Returns: Always returns False as it's not implemented. """ return False # Placeholder: Implementation needed
@property def colocated_cpus(self) -> tp.Union[list[jax.Device], None]: """Returns CPU devices colocated with the engine's accelerator devices. This information can be useful for optimizing data transfers between host (CPU) and accelerator (GPU/TPU) memory. Currently returns None as the implementation is pending. Returns: A list of colocated JAX CPU devices, or None if not implemented or available. """ return None # Placeholder: Implementation needed
[docs] def forward( self, graphstate: nn.GraphState, graphothers: nn.GraphState, state: ActiveSequenceBatch, iteration_plan: NextIterationPlan, ) -> ModelOutputBatch: """ Performs a forward pass of the model. This method executes the compiled `continuous_forward` function, processing the input state and iteration plan to produce model outputs and update the KV cache storage. Args: graphstate: The graph state of the model. graphothers: Other graph components of the model. state: The current active sequence batch state. iteration_plan: The plan for the current inference iteration. Returns: The output batch from the model. """ with self.model.mesh: # Execute the compiled forward pass output, self.storage, _ = self.continuous_forward( self.graphdef, graphstate, graphothers, ModelIOProcessor.build_input( # Build the input for the forward pass iteration_plan=iteration_plan, metadata=self.metadata, decodes_state=state, ), jnp.array(self.eos_token_ids).ravel(), # Provide EOS token IDs self.storage, # Pass the current KV cache storage self.prng_key, # Pass the PRNG key for sampling ) return output
[docs] def prefill( self, graphstate: nn.GraphState, graphothers: nn.GraphState, tokens: jax.Array, valids: jax.Array, true_length: int, temperature: jax.Array, top_p: jax.Array, top_k: jax.Array, rngs: jax.random.PRNGKey, ) -> tuple[GenerationState, ResultTokens]: """ Performs the prefill step for a batch of prompts. This involves processing the initial prompt tokens and populating the KV cache. Args: graphstate: The graph state of the model. graphothers: Other graph components of the model. tokens: The input tokens for the prompts. valids: A boolean array indicating valid tokens. true_length: The true length of the sequences. temperature: The temperature for sampling. top_p: The top-p value for sampling. top_k: The top-k value for sampling. rngs: The PRNG key for sampling. Returns: A tuple containing the generation state after prefill and the result tokens. """ pass
[docs] def decode( self, graphstate: nn.GraphState, graphothers: nn.GraphState, state: GenerationState, rngs: jax.random.PRNGKey, ) -> tuple[GenerationState, ResultTokens]: """ Performs a single decode step for active sequences. This involves generating the next token for each sequence based on the current state and KV cache. Args: graphstate: The graph state of the model. graphothers: Other graph components of the model. state: The current generation state. rngs: The PRNG key for sampling. Returns: A tuple containing the updated generation state and the generated result tokens. """ pass
[docs] def insert( self, prefix: GenerationState, decode_state: GenerationState, slot: int, ) -> GenerationState: """ Inserts or updates a generation state for a specific slot. This is typically used to integrate the results of a prefill step into the ongoing decode process for a particular sequence slot. Args: prefix: The generation state from the prefill step. decode_state: The current decode state. slot: The slot index to insert into. Returns: The updated decode state. """ pass
[docs] def bulk_insert( self, prefix: GenerationState, decode_state: GenerationState, slots: list[int], ) -> GenerationState: """ Efficiently inserts multiple prefill results into the decode state. This method is optimized for integrating the results of a batch prefill operation into the decode state for multiple sequence slots. Args: prefix: The generation state from the bulk prefill step. decode_state: The current decode state. slots: A list of slot indices to insert into. Returns: The updated decode state. """ pass