# Source code for easydel.infra.base_module

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base module implementation for EasyDeL models.

This module provides the core foundation for all EasyDeL neural network models,
implementing essential functionality for model initialization, parameter management,
sharding, quantization, and integration with the broader EasyDeL ecosystem.

The EasyDeLBaseModule class serves as the base class that all EasyDeL models inherit from,
providing:
- Parameter management and state handling
- Model sharding and gathering for distributed training
- Quantization and LoRA support
- Loss computation framework
- Integration with HuggingFace models
- Generation capabilities through mixins

Key Classes:
    EasyDeLBaseModule: The base class for all EasyDeL models, providing common
        functionality for parameter handling, sharding, and model operations.
    ParameterTransformRule: Data class defining rules for transforming parameter
        names and tensors, particularly useful for MoE models.

Example:
    >>> from easydel.infra import EasyDeLBaseModule, EasyDeLBaseConfig
    >>> import flax.nnx as nn
    >>> import jax.numpy as jnp
    >>>
    >>> class MyModel(EasyDeLBaseModule):
    ...     def __init__(self, config, dtype, param_dtype, precision, rngs):
    ...         super().__init__(config, dtype, param_dtype, precision, rngs)
    ...         # Initialize model layers
    ...         self.layer = nn.Linear(config.hidden_size, config.hidden_size, rngs=rngs)
    ...
    ...     def __call__(self, inputs):
    ...         return self.layer(inputs)
    >>>
    >>> # Create and use the model
    >>> config = EasyDeLBaseConfig(hidden_size=768)
    >>> model = MyModel(
    ...     config=config,
    ...     dtype=jnp.float32,
    ...     param_dtype=jnp.float32,
    ...     precision='highest',
    ...     rngs=nn.Rngs(0)
    ... )

The module integrates with JAX's sharding system for distributed training,
supports various quantization methods, and provides utilities for converting
between EasyDeL and HuggingFace model formats.
"""

from __future__ import annotations

import hashlib
import re
import typing as tp
import warnings
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property, partial
from re import Pattern
from typing import Self, Unpack

import chex
import flax
import flax.nnx
import flax.struct
import jax
import jax.extend
import jax.tree_util
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
from eformer.loggings import get_logger
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from easydel.infra.utils import ArrayParam
from easydel.utils import traversals
from easydel.utils.traversals import flatten_dict, is_flatten, unflatten_dict

from .base_config import EasyDeLBaseConfig, EasyDeLBaseConfigDict
from .etils import EasyDeLGradientCheckPointers
from .loss_utils import LOSS_MAPPING, ForCausalLMLoss, ForSequenceClassificationLoss, LossConfig, LossMetrics
from .mixins import BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin

if tp.TYPE_CHECKING:
    from easydel.infra.base_state import EasyDeLState
    from easydel.layers.linear import ParallelLinear
    from easydel.layers.quantization import EasyDeLQuantizationConfig


PartitionLike = tp.Mapping[str, tp.Callable] | tp.Mapping[tuple, tp.Callable] | None


logger = get_logger(__name__)

BaseConf = EasyDeLBaseConfig


[docs]@dataclass class ParameterTransformRule: """Rule for transforming MoE parameter names and tensors. This dataclass defines transformation rules that can be applied to parameter names and their associated tensor values during model conversion or loading. It's particularly useful for handling Mixture of Experts (MoE) models where parameter naming conventions may differ between frameworks. Attributes: pattern: Regular expression pattern or string to match parameter names. Can be a compiled Pattern object or a string that will be used for matching. replacement: String to replace matched patterns in parameter names. Supports regex replacement syntax (e.g., r'\1' for capture groups). tensor_transform: Optional callable to transform the tensor values. Should take a tensor and return a transformed tensor. If None, no transformation is applied to the tensor values. consolidate_experts: Whether to consolidate multiple expert parameters into a single tensor. Useful for MoE models where experts may be stored separately but need to be combined. Example: >>> # Rule to rename expert layers >>> rule = ParameterTransformRule( ... pattern=r"expert_(\\d+)\\.weight", ... replacement=r"experts.\1.weight", ... tensor_transform=lambda x: x.transpose(), ... consolidate_experts=True ... ) """ pattern: str | Pattern replacement: str tensor_transform: Callable | None = None consolidate_experts: bool = False
class EasyDeLBaseModule(nn.Module, EasyBridgeMixin, EasyGenerationMixin, BaseModuleProtocol):
    """Common base for EasyDeL modules.

    Bundles model initialization, parameter handling and the hooks that tie a
    model into the rest of the EasyDeL ecosystem.
    """

    config_class: type[BaseConf]
    base_model_prefix: str
    config: BaseConf | None = None
    _model_task: str | None = None
    _model_type: str | None = None
    _parameter_transform_rules: tp.ClassVar[list[ParameterTransformRule]] = []

    def __init__(
        self,
        config: BaseConf,
        dtype: jnp.dtype,
        param_dtype: jnp.dtype,
        precision: lax.PrecisionLike,
        rngs: nn.Rngs,
    ):
        """Store the configuration, dtypes, precision and RNGs on the module.

        Subclasses are expected to call this from their own ``__init__``.

        Args:
            config: Model configuration with architecture parameters.
            dtype: Data type for computations (e.g., float32, bfloat16).
            param_dtype: Data type for model parameters.
            precision: Precision setting for JAX operations.
            rngs: Random number generators for initialization.
        """
        self.config: BaseConf = config
        self.dtype: jnp.dtype = dtype
        self.param_dtype: jnp.dtype = param_dtype
        self.precision: lax.PrecisionLike = precision
        self.rngs: nn.Rngs = rngs

        # Evaluate these properties once, in this order; the results are
        # discarded (presumably to surface errors early — TODO confirm intent).
        for _eager_attr in (
            "graphtree_shape",
            "graphtree_params_shape",
            "mesh",
            "model_task",
            "model_type",
        ):
            getattr(self, _eager_attr)

    @property
    def parameters(self: Self) -> dict:
        """Collect the values of every ``nn.Param`` found in the module.

        The module is walked with ``iter_module_search`` filtered on
        ``nn.Param``; each yielded key is mapped to the ``.value`` of the
        variable found there.

        Returns:
            dict: Mapping from the search key to the parameter value.
        """
        from easydel.utils.traversals import iter_module_search

        return {key: variable.value for key, variable in iter_module_search(self, nn.Param)}

    @property
    def graphstate_type(self: Self):
        """Variable type treated as "the parameters" when splitting the module.

        Returns:
            ``nn.LoRAParam`` when ``self.lora_is_enabled`` is truthy, otherwise
            ``nn.Param``.
        """
        if self.lora_is_enabled:
            return nn.LoRAParam
        return nn.Param
def split_module(self: Self):
    """Split the module with ``nn.split`` into its three components.

    The first state holds variables matching ``self.graphstate_type``; the
    second holds everything else (the ``...`` filter).

    Returns:
        Tuple ``(graphdef, graphstate, graphother)``.
    """
    graphdef, graphstate, graphother = nn.split(self, self.graphstate_type, ...)
    return graphdef, graphstate, graphother
@staticmethod
def merge_module(graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState):
    """Rebuild a module from the pieces produced by ``split_module``.

    Args:
        graphdef: The module's graph definition (structure).
        graphstate: The module's parameter state.
        graphother: The module's non-parameter state.

    Returns:
        The module returned by ``nn.merge`` for these three components.
    """
    module = nn.merge(graphdef, graphstate, graphother)
    return module
@property
def graphdef(self: Self) -> nn.GraphDef:
    """Graph definition (structure without state) of the module.

    Returns:
        nn.GraphDef: First element of ``nn.split(self, self.graphstate_type, ...)``.
    """
    return nn.split(self, self.graphstate_type, ...)[0]

@property
def graphstate(self: Self) -> nn.GraphState:
    """State holding the variables that match ``self.graphstate_type``.

    Returns:
        nn.GraphState: Second element of ``nn.split(self, self.graphstate_type, ...)``.
    """
    return nn.split(self, self.graphstate_type, ...)[1]

@property
def graphother(self: Self) -> nn.GraphState:
    """State holding every variable not matched by ``self.graphstate_type``.

    Returns:
        nn.GraphState: Last element of ``nn.split(self, self.graphstate_type, ...)``.
    """
    return nn.split(self, self.graphstate_type, ...)[-1]

@property
def graphtree_params_shape(self: Self) -> dict:
    """Abstract (shape-only) view of the parameter state as a nested dict.

    The parameter state is evaluated under ``nn.eval_shape`` so nothing is
    actually computed, flattened, reduced to each variable's ``.value`` and
    unflattened again.

    Returns:
        dict: Nested dictionary mirroring the parameter structure.
    """
    graphtree = nn.eval_shape(lambda: nn.split(self, self.graphstate_type, ...)[1])
    flattened_tree = flatten_dict(graphtree)
    param_shapes = {key: val.value for key, val in flattened_tree.items()}
    return unflatten_dict(param_shapes)

@property
def graphtree_shape(self: Self) -> dict:
    """Abstract (shape-only) view of the whole module state as a nested dict.

    Same procedure as ``graphtree_params_shape`` but on the unfiltered state.
    Entries without a ``.value`` attribute are kept as they are.

    Returns:
        dict: Nested dictionary mirroring the module's state structure.
    """
    graphtree = nn.eval_shape(lambda: nn.split(self)[1])
    flattened_tree = flatten_dict(graphtree)
    param_shapes = {key: getattr(val, "value", val) for key, val in flattened_tree.items()}
    return unflatten_dict(param_shapes)

@property
def mesh(self: Self) -> jax.sharding.Mesh:
    """JAX device mesh taken from ``self.config.mesh``.

    Returns:
        jax.sharding.Mesh: The mesh stored on the configuration.
    """
    return self.config.mesh

@property
def model_task(self: Self) -> str | None:
    """Task identifier of this model (e.g., 'causal-language-model').

    Returns:
        str | None: The value of ``self._model_task``.
    """
    return self._model_task

@property
def model_type(self: Self) -> str | None:
    """Type identifier of this model (e.g., 'llama', 'mistral').

    Returns:
        str | None: The value of ``self._model_type``.
    """
    return self._model_type

@property
def params(self: Self) -> dict:
    """Complete state of the module (parameters and everything else).

    Returns:
        dict: Last element of the unfiltered ``nn.split(self)``.
    """
    return nn.split(self)[-1]

@cached_property
def causal_mask(self: Self) -> jnp.ndarray:
    """Basic causal attention mask, cached after the first access.

    Returns:
        jnp.ndarray: Result of ``self.config.get_basic_causal_mask()``.
    """
    return self.config.get_basic_causal_mask()

@cached_property
def frequencies(self: Self) -> jnp.ndarray:
    """Frequency components (e.g., for RoPE), cached after the first access.

    Returns:
        jnp.ndarray: Result of ``self.config.get_basic_frequencies()``.
    """
    return self.config.get_basic_frequencies()

@cached_property
def inv_frequencies(self: Self) -> jnp.ndarray:
    """Inverse-frequency components (e.g., for RoPE), cached after first access.

    Returns:
        jnp.ndarray: Result of ``self.config.get_basic_inv_frequencies()``.
    """
    return self.config.get_basic_inv_frequencies()

@cached_property
def static_arguments(self: Self) -> tuple:
    """Static arguments for ``__call__``, cached after the first access.

    Returns:
        tuple: Result of ``self.get_static_arguments()``.
    """
    return self.get_static_arguments()

@cached_property
def lossfn_type(self: Self):
    """Resolve the key into ``LOSS_MAPPING`` that this model should use.

    Resolution order:
        1. ``config.loss_type`` when it is not None.
        2. ``self.loss_type`` when it is not None.
        3. The class name itself, or the first ``LOSS_MAPPING`` key found
           inside the class name.

    Whatever the source, a value that is missing or not a key of
    ``LOSS_MAPPING`` triggers a warning and falls back to ``"ForCausalLM"``,
    so the returned key is always usable by ``loss_function``.

    Returns:
        String identifier for the loss function type.
    """
    if getattr(self.config, "loss_type", None) is not None:
        loss_type = self.config.loss_type
    elif getattr(self, "loss_type", None) is not None:
        loss_type = self.loss_type
    else:
        loss_type = self.__class__.__name__
        if loss_type not in LOSS_MAPPING:
            # Look for a known loss name embedded in the class name; names are
            # escaped so they are always matched literally.
            loss_groups = f"({'|'.join(re.escape(name) for name in LOSS_MAPPING)})"
            found = re.search(loss_groups, self.__class__.__name__)
            loss_type = found.group(1) if found is not None else None

    # Previously an unrecognised ``self.loss_type`` slipped through here and
    # surfaced later as a KeyError in ``loss_function``.
    if loss_type is None or loss_type not in LOSS_MAPPING:
        warnings.warn(
            f"`loss_type={loss_type}` is not a recognised loss type. "
            f"Using the default loss: `ForCausalLMLoss`.",
            stacklevel=1,
        )
        loss_type = "ForCausalLM"
    return loss_type

@cached_property
def loss_function(self: Self):
    """Loss function selected for this model, cached after the first access.

    Returns:
        tp.Callable: ``LOSS_MAPPING[self.lossfn_type]``.
    """
    return LOSS_MAPPING[self.lossfn_type]

@property
def module_dtype(self: Self) -> jnp.dtype:
    """Data type of the first leaf found in the flattened parameter state.

    Returns:
        jnp.dtype: The dtype of that first leaf.
    """
    params_state = nn.split(self, self.graphstate_type, ...)[1].flat_state()
    return jax.tree_util.tree_leaves(params_state)[0].dtype
def compute_complex_rotary(self, position_ids: jax.Array) -> jnp.ndarray:
    """Compute complex-valued rotary position embeddings.

    ``self.inv_frequencies`` is multiplied against ``position_ids`` (cast to
    float32), the last two axes of the product are swapped, and the result is
    returned as ``exp(1j * frequencies)``.

    Args:
        position_ids: Position indices to compute embeddings for.

    Returns:
        Complex exponential of the frequencies for rotary embeddings.
    """
    inv_freq = self.inv_frequencies[None, :, None]
    positions = position_ids[:, None, :].astype("f4")
    outer = inv_freq @ positions
    frequencies = jnp.transpose(outer, (0, 2, 1))
    return jnp.exp(1j * frequencies)
def to_dtype(self: Self, dtype: jnp.dtype) -> Self:
    """Cast the module's parameters to ``dtype``.

    Every variable in the parameter state that has a value and whose last path
    component does not start with ``"quant_"`` is cast. Afterwards every
    (sub)module exposing a ``param_dtype`` attribute has it set to ``dtype``.

    Args:
        dtype (jnp.dtype): The target data type for the parameters.

    Returns:
        Self: The re-merged module with converted parameters.
    """
    from easydel.utils.traversals import iter_module_search

    graphdef, param_state, other_state = nn.split(self, self.graphstate_type, ...)

    def _cast(path, variable: nn.VariableState):
        # Quantization variables ("quant_*") keep their storage dtype.
        if variable.value is not None and not path[-1].startswith("quant_"):
            variable.value = variable.value.astype(dtype)
        return variable

    param_state.update(param_state.map(_cast))
    self = nn.merge(graphdef, param_state, other_state)

    for _path, submodule in iter_module_search(self):
        if hasattr(submodule, "param_dtype"):
            submodule.param_dtype = dtype
    return self
def half(self: Self, change_runtime_dtype: bool = True) -> Self:
    """Convert the module's parameters to half precision (float16).

    Args:
        change_runtime_dtype (bool): When True, the runtime computation dtype
            is switched to ``jnp.float16`` as well, before the parameters are
            converted. Defaults to True.

    Returns:
        Self: The module with float16 parameters (and, optionally, runtime dtype).
    """
    target = jnp.float16
    module = self._reformat_runtime_dtype(target) if change_runtime_dtype else self
    return module._reformat_dtype(target)
def float(self: Self, change_runtime_dtype: bool = True) -> Self:
    """Convert the module's parameters to single precision (float32).

    Args:
        change_runtime_dtype (bool): When True, the runtime computation dtype
            is switched to ``jnp.float32`` as well, before the parameters are
            converted. Defaults to True.

    Returns:
        Self: The module with float32 parameters (and, optionally, runtime dtype).
    """
    target = jnp.float32
    module = self._reformat_runtime_dtype(target) if change_runtime_dtype else self
    return module._reformat_dtype(target)
def _reformat_runtime_dtype(self: Self, dtype) -> Self:
    """Change the runtime computation dtype of the module and its submodules.

    Args:
        dtype (jnp.dtype): The target runtime data type.

    Returns:
        Self: The same module instance with ``dtype`` updated.
    """
    from easydel.utils.traversals import iter_module_search

    for _path, module in iter_module_search(self):
        if hasattr(module, "dtype"):
            # Only overwrite dtypes whose type is JAX's scalar metaclass;
            # numpy-based dtypes are left as they are.
            if str(type(module.dtype)).endswith("lax_numpy._ScalarMeta'>"):
                module.dtype = dtype
    self.dtype = dtype
    return self

def _reformat_dtype(self: Self, dtype) -> Self:
    """Cast floating-point leaves of the module state to ``dtype``.

    Both the parameter state and the remaining state are mapped. Leaves whose
    dtype is not one of the listed floating types are returned unchanged.

    Args:
        dtype (jnp.dtype): The target parameter data type.

    Returns:
        Self: The re-merged module with ``param_dtype`` updated.
    """
    from easydel.utils.traversals import iter_module_search

    gdef, gtree, others = nn.split(self, self.graphstate_type, ...)

    def _map(array):
        if array.dtype in [
            jnp.bfloat16,
            jnp.float16,
            jnp.float32,
            jnp.float64,
            jnp.float_,
        ]:
            array = array.astype(dtype)
        return array

    gtree = jax.tree_util.tree_map(_map, gtree)
    others = jax.tree_util.tree_map(_map, others)
    self = nn.merge(gdef, gtree, others)

    for _path, module in iter_module_search(self):
        if hasattr(module, "param_dtype"):
            if isinstance(module.param_dtype, jnp.dtype):
                module.param_dtype = dtype

    self.param_dtype = dtype
    return self

def _match_partition_rules(self, partition_rules: tp.Any = None):
    """Match partition rules against the module's parameter shapes.

    Args:
        partition_rules (tp.Any, optional): The partition rules to use. If None,
            the rules are obtained through ``_get_partition_rules``.

    Returns:
        tp.Any: The result of ``match_partition_rules`` on
        ``self.graphtree_params_shape``.
    """
    return match_partition_rules(
        rules=self._get_partition_rules(partition_rules),
        tree=self.graphtree_params_shape,
    )

@property
def _specs_sharding(self: Self):
    """PartitionSpec of every leaf in the module tree.

    Leaves whose ``sharding`` is a ``NamedSharding`` contribute its ``spec``;
    every other leaf contributes an empty ``PartitionSpec()``.

    Returns:
        The tree rebuilt with ``nn.from_tree`` from the mapped leaves.
    """

    def _map(array):
        if hasattr(array, "sharding"):
            sharding = array.sharding
            if isinstance(sharding, NamedSharding):
                return sharding.spec
        return PartitionSpec()

    return nn.from_tree(jax.tree_util.tree_map(_map, nn.to_tree(self)))

@property
def _shardings(self: Self):
    """Sharding of every leaf, with ``PartitionSpec()`` where none exists.

    Returns:
        The tree rebuilt with ``nn.from_tree`` from the mapped leaves.
    """
    return nn.from_tree(
        jax.tree_util.tree_map(
            lambda x: x.sharding if hasattr(x, "sharding") else PartitionSpec(),
            nn.to_tree(self),
        )
    )

@property
def _named_shardings(self: Self):
    """Sharding of every leaf, with ``None`` where none exists.

    Returns:
        The tree rebuilt with ``nn.from_tree`` from the mapped leaves.
    """
    return nn.from_tree(
        jax.tree_util.tree_map(
            lambda x: x.sharding if hasattr(x, "sharding") else None,
            nn.to_tree(self),
        )
    )

def _get_mesh(self, mesh: Mesh | None = None) -> Mesh:
    """Resolve the device mesh, preferring the argument over the config.

    Args:
        mesh (tp.Optional[Mesh]): A potential JAX device mesh.

    Returns:
        Mesh: ``mesh`` when given, otherwise ``self.config.mesh``.

    Raises:
        ValueError: If no mesh is provided and none is found in the configuration.
    """
    if mesh is None:
        if not hasattr(self, "config") or not hasattr(self.config, "mesh") or self.config.mesh is None:
            raise ValueError("A mesh must be provided, either as an argument or through the model config.")
        return self.config.mesh
    return mesh

def _get_partition_rules(self, partition_rules: PartitionLike) -> PartitionLike:
    """Resolve the partition rules, preferring the argument over the config.

    Args:
        partition_rules (PartitionLike): Potential partitioning rules.

    Returns:
        PartitionLike: ``partition_rules`` when given, otherwise
        ``self.config.get_partition_rules(fully_sharded_data_parallel=True)``.

    Raises:
        ValueError: If no rules are provided and the module has no ``config``.
    """
    if partition_rules is None:
        if not hasattr(self, "config"):
            raise ValueError("Partition rules must be provided either as an argument or through the model config.")
        return self.config.get_partition_rules(fully_sharded_data_parallel=True)
    return partition_rules

def _apply_sharding_fns(
    self: Self,
    sharding_fns: tp.Mapping[str, tp.Callable],
) -> Self:
    """Apply per-path sharding or gathering functions to the module state.

    ``sharding_fns`` is flattened first. Each variable (in the parameter state
    and in the remaining state) that has a value and whose path is a key of the
    flattened mapping gets its value replaced by the function's result.

    Args:
        sharding_fns (tp.Mapping[str, tp.Callable]): Mapping from parameter
            paths to sharding or gathering functions.

    Returns:
        Self: The re-merged module with the functions applied.
    """
    gdef, state, others = nn.split(self, self.graphstate_type, ...)
    sharding_fns = flatten_dict(sharding_fns)

    def _map(path, val: nn.VariableState):
        # Membership is tested on the dict itself (hash lookup) instead of on a
        # list of its keys, which was a linear scan for every variable.
        if val.value is not None and path in sharding_fns:
            fn = sharding_fns[path]
            val.value = fn(val.value)
        return val

    state.update(state.map(_map))
    others.update(others.map(_map))
    self = nn.merge(gdef, state, others)
    return self
def shard_model(
    self: Self,
    partition_rules: PartitionLike = None,
    mesh: Mesh | None = None,
    overlay_fns: tp.Mapping[str, tp.Callable] | None = None,
) -> Self:
    """Shard the model's parameters according to the given rules and mesh.

    Args:
        partition_rules (PartitionLike, optional): Partitioning rules. If None,
            the rules are resolved through ``_get_partition_rules``.
        mesh (tp.Optional[Mesh], optional): JAX device mesh. If None, the mesh
            is resolved through ``_get_mesh``.
        overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional):
            Extra functions merged over the generated shard functions with
            ``dict.update``, so they win on matching keys.

    Returns:
        Self: The sharded model instance.
    """
    resolved_mesh = self._get_mesh(mesh)
    resolved_rules = self._get_partition_rules(partition_rules)
    specs = match_partition_rules(rules=resolved_rules, tree=self.graphtree_params_shape)
    shard_fns = make_shard_and_gather_fns(partition_specs=specs, mesh=resolved_mesh)[0]
    if overlay_fns is not None:
        shard_fns.update(overlay_fns)
    return self._apply_sharding_fns(shard_fns)
def gather_model(
    self: Self,
    partition_rules: PartitionLike = None,
    mesh: Mesh | None = None,
    overlay_fns: tp.Mapping[str, tp.Callable] | None = None,
) -> Self:
    """Gather the model's parameters using the generated gather functions.

    Args:
        partition_rules (PartitionLike, optional): Partitioning rules describing
            how the parameters were sharded. If None, the rules are resolved
            through ``_get_partition_rules``.
        mesh (tp.Optional[Mesh], optional): JAX device mesh to gather from. If
            None, the mesh is resolved through ``_get_mesh``.
        overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional):
            Extra functions merged over the generated gather functions with
            ``dict.update``, so they win on matching keys.

    Returns:
        Self: The model instance with gathered parameters.
    """
    resolved_mesh = self._get_mesh(mesh)
    resolved_rules = self._get_partition_rules(partition_rules)
    specs = match_partition_rules(rules=resolved_rules, tree=self.graphtree_params_shape)
    gather_fns = make_shard_and_gather_fns(partition_specs=specs, mesh=resolved_mesh)[1]
    if overlay_fns is not None:
        gather_fns.update(overlay_fns)
    return self._apply_sharding_fns(gather_fns)
@property
def _shard_fns(self: Self):
    """Sharding functions derived from the module's own configuration.

    Mesh and partition rules are both resolved from the config (``None`` is
    passed to the resolvers).

    Returns:
        tp.Mapping: First element of ``make_shard_and_gather_fns``.
    """
    config_mesh = self._get_mesh(None)
    specs = match_partition_rules(
        rules=self._get_partition_rules(None),
        tree=self.graphtree_params_shape,
    )
    shard_fns, _unused_gather = make_shard_and_gather_fns(partition_specs=specs, mesh=config_mesh)
    return shard_fns

@property
def _gather_fns(self: Self):
    """Gathering functions derived from the module's own configuration.

    Mesh and partition rules are both resolved from the config (``None`` is
    passed to the resolvers).

    Returns:
        tp.Mapping: Second element of ``make_shard_and_gather_fns``.
    """
    config_mesh = self._get_mesh(None)
    specs = match_partition_rules(
        rules=self._get_partition_rules(None),
        tree=self.graphtree_params_shape,
    )
    _unused_shard, gather_fns = make_shard_and_gather_fns(partition_specs=specs, mesh=config_mesh)
    return gather_fns
def apply_out_shardings(self, out_shardings):
    """Apply output sharding specifications to the module state.

    The parameter state and the remaining state are passed through a jitted
    identity function compiled with ``out_shardings``, then merged back with
    the unchanged graph definition.

    Args:
        out_shardings: Sharding specifications handed to ``jax.jit`` as
            ``out_shardings`` for the ``(graphstate, graphother)`` pair
            returned by the jitted function.

    Returns:
        Module rebuilt from the graph definition and the re-sharded states.
    """
    # ``split_module`` returns a tuple, so the pieces are unpacked into names;
    # slice-assigning into the tuple would raise TypeError.
    graphdef, graphstate, graphother = self.split_module()

    @partial(jax.jit, out_shardings=out_shardings)
    def _call(graphstate, graphother):
        return graphstate, graphother

    graphstate, graphother = _call(graphstate, graphother)
    return self.merge_module(graphdef, graphstate, graphother)
def fully_shard(self: Self, partition_rules: PartitionLike = None) -> Self:
    """Place the whole module state according to the partition rules.

    The module is split into graph definition and state, wrapped in a small
    pytree, and passed through an identity function compiled by ``jax.jit``
    with ``out_shardings`` set to one ``NamedSharding`` (on ``self.mesh``) per
    matched partition spec.

    Args:
        partition_rules (PartitionLike, optional): Partitioning rules. If None,
            the rules are resolved through ``_get_partition_rules``.

    Returns:
        Self: The module merged back from the re-sharded state.
    """

    class ShardState(flax.struct.PyTreeNode):
        graphdef: nn.GraphDef
        graphstate: nn.GraphState

    graph_definition, graph_state = nn.split(self)
    carrier = ShardState(graphdef=graph_definition, graphstate=graph_state)

    def _to_named(spec):
        return NamedSharding(mesh=self.mesh, spec=spec)

    matched_specs = match_partition_rules(
        self._get_partition_rules(partition_rules),
        nn.eval_shape(lambda: carrier),
    )
    shardings = jax.tree_util.tree_map(_to_named, matched_specs)

    # Identity under jit: only the output placement changes.
    place = jax.jit(lambda cl: cl, out_shardings=shardings)
    carrier = place(carrier)
    return nn.merge(carrier.graphdef, carrier.graphstate)
def fully_gather(self: Self) -> Self:
    """Place the whole module state with an empty ``PartitionSpec``.

    Mirrors ``fully_shard``, except that every matched entry is mapped to
    ``NamedSharding(self.mesh, PartitionSpec())``; the partition rules from the
    config are only used to obtain a tree of the right structure. The state is
    then passed through an identity function compiled by ``jax.jit`` with
    those ``out_shardings``.

    Returns:
        Self: The module merged back from the re-placed state.
    """

    class ShardState(flax.struct.PyTreeNode):
        graphdef: nn.GraphDef
        graphstate: nn.GraphState

    graph_definition, graph_state = nn.split(self)
    carrier = ShardState(graphdef=graph_definition, graphstate=graph_state)

    def _to_replicated(_spec):
        return NamedSharding(mesh=self.mesh, spec=PartitionSpec())

    matched_specs = match_partition_rules(
        self._get_partition_rules(None),
        nn.eval_shape(lambda: carrier),
    )
    shardings = jax.tree_util.tree_map(_to_replicated, matched_specs)

    # Identity under jit: only the output placement changes.
    place = jax.jit(lambda cl: cl, out_shardings=shardings)
    carrier = place(carrier)
    return nn.merge(carrier.graphdef, carrier.graphstate)
def quantize(
    self: Self,
    quantization_config: EasyDeLQuantizationConfig | None = None,
    quantize_tensors: bool = True,
    verbose: bool | None = None,
) -> Self:
    """Apply quantization to the module's linear layers.

    Args:
        quantization_config: The quantization configuration. If None, a config
            with ``dtype=QuantizationType.INT8`` is created.
        quantize_tensors: If False, the module is passed through
            ``EasyQuantizer.quantize_linears``. If True, this method currently
            applies no transformation and returns the module unchanged.
            NOTE(review): the True branch is an empty placeholder — confirm
            whether tensor quantization is meant to happen here.
        verbose: Forwarded to ``quantize_linears``. If None, it is True only on
            process index 0.

    Returns:
        Self: The (possibly quantized) model instance.

    Example:
        >>> from easydel.layers.quantization import EasyDeLQuantizationConfig, QuantizationType
        >>> config = EasyDeLQuantizationConfig(dtype=QuantizationType.NF4, block_size=64)
        >>> quantized_model = model.quantize(quantization_config=config, quantize_tensors=False)
    """
    from easydel.layers.quantization import EasyDeLQuantizationConfig, EasyQuantizer, QuantizationType

    if quantization_config is None:
        quantization_config = EasyDeLQuantizationConfig(dtype=QuantizationType.INT8)
    quantizer = EasyQuantizer(quantization_config=quantization_config)

    if verbose is None:
        verbose = jax.process_index() == 0

    if not quantize_tensors:
        return quantizer.quantize_linears(self, verbose=verbose)
    return self
def to_state(self, state_class: type[EasyDeLState] | None = None) -> EasyDeLState:
    """Convert the current module instance into a state object.

    The module is split, and the state is created inside a jitted function in
    which the graph definition is a static argument and both state components
    are donated.

    Args:
        state_class: Optional custom state class to use. If None, defaults to
            ``EasyDeLState``.

    Returns:
        EasyDeLState: Result of ``state_class.create(step=0, model=...)`` for
        the re-merged module.
    """
    if state_class is None:
        from easydel.infra.base_state import EasyDeLState

        state_class = EasyDeLState

    def _build(structure, param_state, other_state):
        rebuilt = self.merge_module(structure, param_state, other_state)
        return state_class.create(step=0, model=rebuilt)

    create_state = jax.jit(_build, donate_argnums=(1, 2), static_argnums=(0,))
    graphdef, graphstate, graphother = self.split_module()
    return create_state(graphdef, graphstate, graphother)
def to_torch(self, **kwargs):
    """Convert the EasyDeL module to its Hugging Face PyTorch counterpart.

    The target class is looked up in the torch loader's ``_model_mapping``
    under the type of ``self.config``; the conversion itself is delegated to
    ``ModelConverter.easydel_to_huggingface``.

    Args:
        **kwargs: Additional keyword arguments forwarded to the converter.

    Returns:
        torch.nn.Module: The equivalent Hugging Face PyTorch model with loaded weights.
    """
    from easydel.utils.parameters_transformation import ModelConverter

    huggingface_cls = self.get_torch_loader()._model_mapping[type(self.config)]
    reform_param = self._get_reform_param()
    return ModelConverter.easydel_to_huggingface(
        module=self,
        base_huggingface_module=huggingface_cls,
        config=self.config,
        dtype=self.param_dtype,
        reform_param=reform_param,
        **kwargs,
    )
def prepare_inputs_for_call(self, **kwargs):
    """Hook for adjusting keyword arguments before they reach ``__call__``.

    The base implementation is a pass-through; subclasses may override it to
    modify or add arguments (e.g., for generation).

    Args:
        **kwargs: The keyword arguments intended for ``__call__``.

    Returns:
        dict: The keyword arguments, unchanged.
    """
    prepared = kwargs
    return prepared
[docs] def get_static_arguments(self: Self) -> tuple: """ Returns a tuple of static arguments required by the module's `__call__` method. Static arguments are those that don't change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments. Returns: tp.Tuple: A tuple containing static arguments. """ return ()
def get_encoder(self: Self) -> nn.Module | EasyDeLBaseModule:
    """Return the encoder component of the model.

    Intended to be overridden by encoder-decoder models; the base
    implementation always raises.

    Returns:
        nn.Module | EasyDeLBaseModule: The encoder module.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError
def get_decoder(self: Self) -> nn.Module | EasyDeLBaseModule:
    """Return the decoder component of the model.

    Placeholder in the base class: models that have a decoder are
    expected to override it.

    Returns:
        nn.Module | EasyDeLBaseModule: The decoder module (in overrides).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def get_lm_head(self: Self) -> ParallelLinear:
    """Return the language-model head of the model.

    Placeholder in the base class: models with an output projection are
    expected to override it.

    Returns:
        ParallelLinear: The language-model head layer (in overrides).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def get_embedding(self: Self) -> nn.Module | nn.Embed:
    """Return the input embedding layer of the model.

    Placeholder in the base class: models with an embedding layer are
    expected to override it.

    Returns:
        nn.Module | nn.Embed: The embedding layer (in overrides).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
@classmethod
def sequential_init(cls: type[Self], **kwargs) -> Self:
    """Initialize model parameters one module at a time, sharding each as it is created.

    The method:
    1. Builds a lazy (shape-only) model through ``lazy_init``.
    2. Walks every module found by ``iter_module_search`` and, for each
       ``kernel`` / ``bias`` / ``embedding`` / ``scale`` attribute that has
       a matching ``*_init`` initializer, creates the array and places it
       with the sharding derived from the config's partition rules.
    3. Re-forks ``rngs`` on modules that carry one and calls the ``resure``
       hook on modules that expose it.
    4. Logs a warning for every parameter that is still a
       ``jax.ShapeDtypeStruct`` afterwards.

    Args:
        **kwargs: Arguments passed to ``lazy_init``, including:
            - config: Model configuration
            - dtype: Computation dtype
            - param_dtype: Parameter dtype
            - precision: JAX precision setting
            - rngs: Random number generators. When omitted or ``None`` a
              fixed-seed ``flax.nnx.Rngs(44)`` is used, both for the lazy
              construction and for the parameter initialization.

    Returns:
        Self: The lazily constructed model with its parameters filled in.

    Example:
        >>> config = LlamaConfig(hidden_size=1024, num_hidden_layers=4)
        >>> model = LlamaModel.sequential_init(
        ...     config=config,
        ...     dtype=jnp.float32,
        ...     param_dtype=jnp.float32,
        ...     rngs=nn.Rngs(0)
        ... )
    """
    from easydel.utils.traversals import iter_module_search

    def _shard(x):
        # Identity function; jitting it with `out_shardings` is what places the array.
        return x

    # Resolve the generator once and make sure `lazy_init` sees the same object.
    # Previously the fallback was only used below while `lazy_init` got `rngs=None`.
    rng = kwargs.get("rngs")
    if rng is None:
        rng = flax.nnx.Rngs(44)
        kwargs["rngs"] = rng

    def _materialize(init_fn, abstract_value, sharding):
        # Create one concrete array from its initializer and place it with `sharding`.
        arr = init_fn(
            key=rng.param(),
            shape=abstract_value.shape,
            dtype=abstract_value.dtype,
        )
        return jax.jit(_shard, out_shardings=sharding)(arr)

    lazy_model = cls.lazy_init(**kwargs)
    partition_rules = lazy_model.config.get_partition_rules()
    # Dummy leaf used only so the partition rules can be matched against path names.
    placeholder = jnp.ones((1,))

    for path, module in iter_module_search(lazy_model, (flax.nnx.Module, ArrayParam)):
        if not path:
            continue
        joined_path = "/".join([str(p) for p in path])
        partition_spec = jax.tree_util.tree_map(
            lambda x: NamedSharding(lazy_model.mesh, x),
            match_partition_rules(
                partition_rules,
                {
                    joined_path + "/kernel": placeholder,
                    joined_path + "/bias": placeholder,
                    joined_path + "/embedding": placeholder,
                    joined_path + "/scale": placeholder,
                    joined_path: placeholder,
                },
                strict=False,
            ),
        )
        shardings = {
            "kernel": partition_spec[joined_path + "/kernel"],
            "bias": partition_spec[joined_path + "/bias"],
            "embedding": partition_spec[joined_path + "/embedding"],
            "scale": partition_spec[joined_path + "/scale"],
            "raw": partition_spec[joined_path],
        }

        # The order kernel -> bias -> embedding -> scale is kept so the sequence of
        # keys drawn from `rng` is unchanged.
        if hasattr(module, "kernel") and hasattr(module, "kernel_init"):
            arr = _materialize(module.kernel_init, module.kernel.value, shardings["kernel"])
            if isinstance(module.kernel, flax.nnx.Param):
                module.kernel.value = arr
            else:
                module.kernel = arr
        if hasattr(module, "bias") and hasattr(module, "bias_init") and module.bias is not None:
            module.bias.value = _materialize(module.bias_init, module.bias.value, shardings["bias"])
        if hasattr(module, "embedding") and hasattr(module, "embedding_init"):
            module.embedding.value = _materialize(
                module.embedding_init,
                module.embedding.value,
                shardings["embedding"],
            )
        if hasattr(module, "scale") and hasattr(module, "scale_init"):
            module.scale.value = _materialize(module.scale_init, module.scale.value, shardings["scale"])
        if hasattr(module, "rngs"):
            module.rngs = rng.fork()
        if hasattr(module, "resure"):
            module.resure(rng.param(), shard_fn=jax.jit(_shard, out_shardings=shardings["raw"]))

    # Report parameters that are still abstract after the pass above.
    for path, module in iter_module_search(lazy_model, nn.Param):
        if path and type(module.value) is jax.ShapeDtypeStruct:
            logger.warning(f"({type(module).__name__}) found empty array at " + ("/".join([str(s) for s in path])))

    return lazy_model
@classmethod
def lazy_init(cls: type[Self], **kwargs) -> Self:
    """Construct the module under ``nn.eval_shape``.

    ``rngs`` is removed from ``kwargs`` (defaulting to ``None``) and passed
    to ``nn.eval_shape`` as an argument of the constructor wrapper; all
    remaining keyword arguments are captured by the wrapper.

    Args:
        **kwargs: Keyword arguments for the class constructor.

    Returns:
        Self: The result of ``nn.eval_shape`` applied to the constructor.
    """
    generators = kwargs.pop("rngs", None)

    def _construct(rngs):
        # `kwargs` no longer contains "rngs", so it is supplied explicitly here.
        return cls(**kwargs, rngs=rngs)

    return nn.eval_shape(_construct, rngs=generators)
def merge_lora_params(self: Self, pytree: dict) -> Self:
    """Merge LoRA parameters from ``pytree`` into this module.

    Thin wrapper around ``easydel.infra.utils.merge_lora_params``.

    Args:
        pytree (dict): Pytree holding the LoRA parameters to merge.

    Returns:
        Self: The value returned by the underlying utility.
    """
    from easydel.infra.utils import merge_lora_params as _merge

    return _merge(self, pytree)
def split_lora_params(self: Self) -> dict:
    """Extract LoRA parameters from this module.

    Thin wrapper around ``easydel.infra.utils.split_lora_params``.

    Returns:
        dict: The pytree returned by the underlying utility.
    """
    from easydel.infra.utils import split_lora_params as _split

    return _split(self)
@property
def lora_is_enabled(self: Self):
    """Report whether the module graph contains any ``nn.LoRAParam``.

    Returns:
        True if at least one node yielded by ``nn.iter_graph`` is an
        ``nn.LoRAParam``, False otherwise.
    """
    return any(isinstance(node, nn.LoRAParam) for _, node in nn.iter_graph(self))


@property
def is_quantized(self) -> bool:
    """Report whether the module graph contains quantized layers or arrays.

    A node counts as quantized when it is a ``Linear8bit`` / ``LinearNF4``
    layer, or when its ``value`` (falling back to ``raw_value``, then to the
    node itself) is an ``Array8B`` / ``ArrayNF4``.

    Returns:
        bool: True if any quantized component is found, False otherwise.
    """
    from easydel.layers.quantization import Array8B, ArrayNF4, Linear8bit, LinearNF4

    quantized_layers = (Linear8bit, LinearNF4)
    quantized_arrays = (Array8B, ArrayNF4)
    for _, node in nn.iter_graph(self):
        if isinstance(node, quantized_layers):
            return True
        payload = getattr(node, "value", getattr(node, "raw_value", node))
        if isinstance(payload, quantized_arrays):
            return True
    return False
def apply_lora_to_layers(
    self: Self,
    lora_rank: int,
    lora_pattern: str | None = None,
    verbose: bool = False,
    rngs: nn.Rngs | None = None,
) -> Self:
    """Apply LoRA layers to this module.

    Thin wrapper around ``easydel.infra.utils.apply_lora_to_layers``; all
    arguments are forwarded by keyword.

    Args:
        lora_rank (int): Rank of the LoRA decomposition.
        lora_pattern (str | None): Pattern selecting the layers to adapt;
            forwarded as-is (``None`` by default).
        verbose (bool): Forwarded verbosity flag. Defaults to False.
        rngs (nn.Rngs | None): Random number generators forwarded to the
            utility. Defaults to None.

    Returns:
        Self: The value returned by the underlying utility.
    """
    from easydel.infra.utils import apply_lora_to_layers as _apply

    return _apply(
        self,
        lora_pattern=lora_pattern,
        lora_rank=lora_rank,
        rngs=rngs,
        verbose=verbose,
    )
def unwrap_lora_to_layers(self: Self, verbose: bool = False) -> Self:
    """Undo a previous LoRA application on this module.

    Thin wrapper around ``easydel.infra.utils.unwrap_lora_to_layers``.

    Args:
        verbose (bool): Forwarded verbosity flag. Defaults to False.

    Returns:
        Self: The value returned by the underlying utility.
    """
    from easydel.infra.utils import unwrap_lora_to_layers as _unwrap

    return _unwrap(self, verbose=verbose)
def _get_reform_param(self) -> dict[str, tp.Any]:
    """Collect the ``reform_param`` entries of all sub-modules under qualified names.

    Every key of a sub-module's ``reform_param`` and every ``name`` inside
    its ``splits`` list is prefixed with the dotted path of that sub-module
    (no prefix when the path is empty).

    Returns:
        dict[str, tp.Any]: Mapping from qualified key to a shallow copy of
        the reform entry whose ``splits`` carry qualified names.
    """
    from easydel.utils import traversals

    collected = {}
    for path, module in traversals.iter_module_search(self, nn.Module):
        if not getattr(module, "reform_param", None):
            continue
        prefix = ".".join(map(str, path))
        for key, value in module.reform_param.items():
            qualified_key = f"{prefix}.{key}" if prefix else key
            entry = value.copy()
            qualified_splits = []
            for split in value["splits"]:
                renamed = split.copy()
                split_name = split["name"]
                renamed["name"] = f"{prefix}.{split_name}" if prefix else split_name
                qualified_splits.append(renamed)
            entry["splits"] = qualified_splits
            collected[qualified_key] = entry
    return collected


@property
def transform_fn(self):
    """Build the HuggingFace-to-EasyDeL state-dict conversion function, with sharding.

    Dotted paths of the ``nn.Embed``, ``nn.LayerNorm``, ``ParallelMoELinear``
    and ``BaseMoeModule`` sub-modules are gathered and bound, together with
    ``param_dtype``, ``_shard_fns`` and the reform parameters, onto
    ``StateDictConverter.huggingface_to_easydel``.

    Returns:
        A ``functools.partial`` of ``StateDictConverter.huggingface_to_easydel``.
    """
    from easydel.layers.moe import BaseMoeModule, ParallelMoELinear
    from easydel.utils import traversals
    from easydel.utils.parameters_transformation import StateDictConverter

    def _dotted_paths(kind):
        # Dotted path of every sub-module of the given type.
        return [".".join(tuple(map(str, pa))) for pa, _ in traversals.iter_module_search(self, kind)]

    embedding_path = _dotted_paths(nn.Embed)
    layernorm_path = _dotted_paths(nn.LayerNorm)
    moe_path = _dotted_paths(ParallelMoELinear)
    moe_block_path = _dotted_paths(BaseMoeModule)

    # Only the last path component is used as the layer "name".
    moe_names = list(set([names.split(".")[-1] for names in moe_path]))
    moe_block_names = list(set([names.split(".")[-1] for names in moe_block_path]))

    return partial(
        StateDictConverter.huggingface_to_easydel,
        embedding_layer_names=embedding_path,
        layernorm_names=layernorm_path,
        moe_names=moe_names,
        moe_block_names=moe_block_names,
        moe_block_path=moe_block_path,
        moe_path=moe_path,
        dtype=self.param_dtype,
        shard_fns=self._shard_fns,
        reform_param=self._get_reform_param(),
    )


@property
def _generate_compatible_graphdef(self: Self):
    """Graph definition of this model rebuilt with gradient checkpointing disabled.

    A deep copy of the config gets ``gradient_checkpointing`` set to
    ``EasyDeLGradientCheckPointers.NONE``; the class is lazily initialized
    with it and the graph definition of that instance is returned.

    Returns:
        The graph definition produced by ``nn.split`` on the lazy instance.
    """
    generation_config = deepcopy(self.config)
    generation_config.gradient_checkpointing = EasyDeLGradientCheckPointers.NONE
    lazy_twin = type(self).lazy_init(
        config=generation_config,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
        rngs=self.rngs,
    )
    graphdef, _, _ = nn.split(lazy_twin, self.graphstate_type, ...)
    return graphdef


@property
def _generate_compatible_graphother(self: Self):
    """Non-parameter state of this model rebuilt with gradient checkpointing disabled.

    Built the same way as ``_generate_compatible_graphdef``, but returns the
    remaining ("other") state after passing it through
    ``traversals.recreate_meta_values``.

    Returns:
        The "other" state produced by ``nn.split`` on the lazy instance,
        post-processed by ``traversals.recreate_meta_values``.
    """
    generation_config = deepcopy(self.config)
    generation_config.gradient_checkpointing = EasyDeLGradientCheckPointers.NONE
    lazy_twin = type(self).lazy_init(
        config=generation_config,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
        rngs=self.rngs,
    )
    _, _, other_state = nn.split(lazy_twin, self.graphstate_type, ...)
    return traversals.recreate_meta_values(other_state)


@property
def params_sharding(self: Self) -> dict:
    """Sharding of every parameter, mirroring the parameter dictionary layout.

    Returns:
        dict: The structure of ``split_params_dict()`` with each leaf
        replaced by its ``sharding`` attribute, or ``None`` when the leaf
        has no such attribute.
    """
    params = self.split_params_dict()
    return jax.tree_util.tree_map(lambda leaf: getattr(leaf, "sharding", None), params)
def merge_params(self, tree):
    """Rebuild the module with ``tree`` as its parameter state.

    The module is split into graph definition, parameter state and other
    state; the parameter state is discarded and replaced by ``tree``.

    Args:
        tree: Parameter state to merge in place of the current one.

    Returns:
        The module produced by ``nn.merge``.
    """
    graphdef, _discarded_params, other_state = nn.split(self, self.graphstate_type, ...)
    return nn.merge(graphdef, tree, other_state)
def split_params(self: Self):
    """Return the parameter state of the module.

    Returns:
        The second element of ``nn.split(self, self.graphstate_type, ...)``,
        i.e. the state selected by ``self.graphstate_type``.
    """
    split_result = nn.split(self, self.graphstate_type, ...)
    return split_result[1]
def split_params_dict(
    self,
    extract_fn: tp.Callable | None = None,
    remove_none: bool = True,
) -> dict:
    """Return the module parameters as a nested plain dictionary.

    The parameter state is converted with ``to_pure_dict``, flattened, each
    leaf is unwrapped to its ``.value`` when it has one, and the result is
    unflattened again.

    Args:
        extract_fn (tp.Callable | None): Forwarded to ``to_pure_dict``.
            Defaults to None.
        remove_none (bool): When True, entries whose unwrapped value is
            ``None`` are dropped. Defaults to True.

    Returns:
        dict: Nested dictionary of parameter values.
    """
    pure_tree = self.split_params().to_pure_dict(extract_fn=extract_fn)
    unwrapped = {}
    for key, leaf in flatten_dict(pure_tree).items():
        payload = leaf.value if hasattr(leaf, "value") else leaf
        if remove_none and payload is None:
            continue
        unwrapped[key] = payload
    return unflatten_dict(unwrapped)
def merge_params_dict(self: Self, params_dict: dict) -> Self:
    """Write the values of ``params_dict`` into the module's parameter state.

    Args:
        params_dict (dict): Nested or already flattened dictionary of
            parameter values; it is flattened first when needed.

    Returns:
        Self: The module rebuilt through ``merge_params`` with the updated
        state.

    Raises:
        KeyError: If a key of ``params_dict`` is not present in the current
            state. Keys processed before the failing one have already been
            written to the state at that point.
    """
    state = self.split_params().flat_state()
    flat_updates = params_dict if is_flatten(params_dict) else flatten_dict(params_dict)
    for key, value in flat_updates.items():
        if key not in state:
            raise KeyError(f"Parameter key {key} not found in the current model state.")
        state[key].value = value
    return self.merge_params(unflatten_dict(state))
def flops_per_token(
    self,
    sequence_length: int | None = None,
    include_loss: bool = True,
    include_backward: bool = False,
) -> float:
    """Estimate the FLOPs per token from the model configuration.

    A ``FlopCalcConfig`` is filled from ``self.config`` (using its
    ``text_config`` / ``vision_config`` sub-configs when present, otherwise
    the config itself) and handed to ``utils.flops_per_token``.

    The estimate is best-effort: if anything goes wrong while reading the
    configuration or computing the value, a warning naming the cause is
    emitted and ``1`` is returned instead of raising.

    Args:
        sequence_length (int | None): Sequence length used for the estimate.
            When None, ``text_config.granted_mask_max_position_embedding``
            is used.
        include_loss (bool): Forwarded to ``FlopCalcConfig.include_loss``.
            Defaults to True.
        include_backward (bool): When True the forward estimate is
            multiplied by 3. Defaults to False.

    Returns:
        float: The estimated FLOPs per token, or ``1`` when the estimate
        could not be computed.
    """
    from .utils import ActivationType, FlopCalcConfig, flops_per_token

    try:
        config = self.config
        text_config = getattr(config, "text_config", config)
        vision_config = getattr(config, "vision_config", config)
        if sequence_length is None:
            sequence_length = text_config.granted_mask_max_position_embedding
        num_heads = text_config.num_attention_heads
        hidden_dim = text_config.hidden_size
        fconf = FlopCalcConfig(
            hidden_dim=hidden_dim,
            intermediate_dim=text_config.intermediate_size,
            num_layers=text_config.num_hidden_layers,
            num_heads=num_heads,
            activation_type=getattr(text_config, "hidden_act", ActivationType.SILU),
            head_dim=getattr(text_config, "head_dim", hidden_dim // num_heads),
            kv_heads=getattr(text_config, "num_key_value_heads", num_heads),
            seq_len=sequence_length,
            task=self._model_task,
            vocab_size=text_config.vocab_size,
            include_loss=include_loss,
            num_labels=getattr(text_config, "num_labels", 0),
            num_experts=getattr(text_config, "num_local_experts", 0),
            num_experts_per_tok=getattr(text_config, "num_experts_per_tok", 0),
            glu=getattr(text_config, "glu_mlp", True),
            vision_hidden_dim=getattr(vision_config, "hidden_size", 0),
            vision_intermediate_dim=getattr(vision_config, "intermediate_size", 0),
            vision_num_heads=getattr(vision_config, "num_attention_heads", 0),
            vision_num_layers=getattr(vision_config, "num_hidden_layers", 0),
            vision_seq_len=getattr(vision_config, "max_position_embeddings", 0),
        )
        flops = flops_per_token(fconf)
        if include_backward:
            flops *= 3
    except Exception as exc:
        # Deliberate best-effort fallback: never fail the caller over a FLOPs
        # estimate, but surface the cause instead of silently discarding it.
        warnings.warn(f"Calculating Flops Failed! ({type(exc).__name__}: {exc})", stacklevel=1)
        flops = 1
    return flops
def _flop(self, *args, **kwargs) -> float | None:
    """Estimate the FLOPs of one ``__call__`` from its jaxpr.

    ``__call__`` is traced with ``jax.make_jaxpr`` on the given arguments
    and the resulting jaxpr is passed to ``utils.count_flop_jaxpr``.

    Args:
        *args: Positional arguments for ``__call__``.
        **kwargs: Keyword arguments for ``__call__``.

    Returns:
        float | None: The value returned by ``count_flop_jaxpr``.
    """
    from .utils import count_flop_jaxpr

    traced = jax.make_jaxpr(self.__call__)(*args, **kwargs)
    return count_flop_jaxpr(traced)


@property
def pure_transform_fn(self: Self):
    """Build the HuggingFace-to-EasyDeL state-dict conversion function, without sharding.

    Same binding as ``transform_fn`` except that no ``shard_fns`` argument
    is supplied.

    Returns:
        A ``functools.partial`` of ``StateDictConverter.huggingface_to_easydel``.
    """
    from easydel.layers.moe import BaseMoeModule, ParallelMoELinear
    from easydel.utils import traversals
    from easydel.utils.parameters_transformation import StateDictConverter

    def _dotted_paths(kind):
        # Dotted path of every sub-module of the given type.
        return [".".join(tuple(map(str, pa))) for pa, _ in traversals.iter_module_search(self, kind)]

    embedding_path = _dotted_paths(nn.Embed)
    layernorm_path = _dotted_paths(nn.LayerNorm)
    moe_path = _dotted_paths(ParallelMoELinear)
    moe_block_path = _dotted_paths(BaseMoeModule)

    # Only the last path component is used as the layer "name".
    moe_names = list(set([names.split(".")[-1] for names in moe_path]))
    moe_block_names = list(set([names.split(".")[-1] for names in moe_block_path]))

    return partial(
        StateDictConverter.huggingface_to_easydel,
        embedding_layer_names=embedding_path,
        layernorm_names=layernorm_path,
        moe_names=moe_names,
        moe_block_names=moe_block_names,
        moe_block_path=moe_block_path,
        moe_path=moe_path,
        dtype=self.param_dtype,
        reform_param=self._get_reform_param(),
    )


@property
def _default_loss_config(self: Self) -> LossConfig | None:
    """Default ``LossConfig`` of the module.

    The base implementation has none; subclasses may override the property
    to supply one.

    Returns:
        LossConfig | None: Always ``None`` here.
    """
    return None


@_default_loss_config.setter
def _default_loss_config(self, val):
    """Accept an assignment without storing it (the value is only returned)."""
    return val
def compute_loss(
    self,
    *,
    labels: chex.Array | None = None,
    loss_config: LossConfig | None = None,
    loss_kwargs: dict | None = None,
    **batch,
) -> tuple[tp.Any, LossMetrics]:
    """Run a forward pass on ``batch`` and compute the loss.

    Args:
        labels (chex.Array | None): Target labels. When None and the loss
            function is the causal-LM loss, ``batch["input_ids"]`` is used
            (if present).
        loss_config (LossConfig | None): Loss configuration. When None and
            the loss function is the sequence-classification loss, a
            ``LossConfig(num_labels=self.config.num_labels)`` is created.
        loss_kwargs (dict | None): Extra keyword arguments for the loss
            function. Defaults to None (treated as empty).
        **batch: Model inputs; passed to ``self(...)`` and also forwarded to
            the loss function.

    Returns:
        tuple[tp.Any, LossMetrics]: The model outputs with ``loss`` replaced
        by the computed loss, and the loss metrics. When the outputs carry a
        non-None ``aux_loss`` it is added to the loss first.

    Raises:
        AssertionError: If no labels are available, or if the
            sequence-classification loss is used without ``num_labels`` on
            the config.
    """
    loss_fn_name = self.loss_function.__name__

    if labels is None and loss_fn_name == ForCausalLMLoss.__name__:
        labels = batch.get("input_ids", None)

    if loss_fn_name == ForSequenceClassificationLoss.__name__ and loss_config is None:
        assert hasattr(self.config, "num_labels"), (
            "in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach"
            " `num_labels` to model `config`"
        )
        loss_config = LossConfig(num_labels=self.config.num_labels)

    assert labels is not None, "`labels` can not be `None` for computing loss."

    extra_loss_kwargs = loss_kwargs or {}
    outputs = self(**batch)

    loss_output: LossMetrics = self.loss_function(
        labels=labels,
        config=loss_config,
        paxis=self.config.partition_axis,
        **extra_loss_kwargs,
        **outputs,
        **batch,
    )

    # Fold an auxiliary loss reported by the model into the main loss.
    aux_loss = getattr(outputs, "aux_loss", None)
    if aux_loss is not None:
        loss_output.loss = loss_output.loss + aux_loss

    return outputs.replace(loss=loss_output.loss), loss_output
def apply_lm_head(self, hidden_states: chex.Array) -> chex.Array:
    """Project hidden states through the language-model head.

    The first of the config attributes ``tie_word_embeddings``,
    ``use_lm_head`` and ``share_input_output_layers`` that exists decides
    whether the transposed embedding matrix is passed to the head as ``w``;
    when none exists, or the value is falsy, ``w`` is ``None``.

    Args:
        hidden_states: Hidden states to project.

    Returns:
        The output of the language-model head.
    """
    tie_embeddings = False
    for key in ("tie_word_embeddings", "use_lm_head", "share_input_output_layers"):
        if hasattr(self.config, key):
            # Only the first attribute that is present is consulted.
            tie_embeddings = getattr(self.config, key)
            break

    w = None
    if tie_embeddings:
        w = self.get_embedding().embedding.value.T

    lm_head = self.get_lm_head()
    return lm_head(hidden_states, w=w)
def update_module(self, **kwargs: Unpack[EasyDeLBaseConfigDict]):
    """Apply config overrides in place and rebuild the module around them.

    Each keyword is set as an attribute on ``self.config`` (the existing
    config object is mutated). A lazy instance is then created from that
    config, and its graph definition is merged with the current
    ``graphstate`` and ``graphother`` through ``merge_module``.

    Args:
        **kwargs: Config attributes to overwrite.

    Returns:
        The module returned by ``merge_module``.
    """
    updated_config = self.config
    for name, value in kwargs.items():
        setattr(updated_config, name, value)

    lazy_twin = self.lazy_init(
        config=updated_config,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
        rngs=self.rngs,
    )
    return self.merge_module(lazy_twin.graphdef, self.graphstate, self.graphother)
def new_graphdef(self, **kwargs: Unpack[EasyDeLBaseConfigDict]):
    """Return the graph definition for a modified copy of the configuration.

    The config is deep-copied, the given overrides are set on the copy (the
    module's own config is left untouched), a lazy instance is created from
    it, and that instance's ``graphdef`` is returned.

    Args:
        **kwargs: Config attributes to overwrite on the copy.

    Returns:
        The ``graphdef`` of the lazily initialized instance.
    """
    derived_config = deepcopy(self.config)
    for name, value in kwargs.items():
        setattr(derived_config, name, value)

    lazy_twin = self.lazy_init(
        config=derived_config,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
        rngs=self.rngs,
    )
    return lazy_twin.graphdef
def __hash__(self):
    """Hash the module via ``static_hash`` with no config keys excluded.

    Returns:
        The value of ``self.static_hash(None)``.
    """
    module_hash = self.static_hash(None)
    return module_hash
def static_hash(self, pop_things: list[str] | None = None):
    """Compute a deterministic hash of the module's state and configuration.

    The signature of ``(self.graphstate, self.graphother)`` together with
    the config dictionary is obtained from ``_get_args_signature`` and
    reduced to an integer through MD5.

    Args:
        pop_things: Optional list of configuration keys to leave out of the
            hash (e.g. ``["attn_mechanism"]``). Keys that are not present in
            the config dictionary are ignored.

    Returns:
        A signed integer derived from the MD5 digest of the signature.

    Example:
        >>> # Hash without excluding any config keys
        >>> hash1 = model.static_hash()
        >>>
        >>> # Hash excluding attention mechanism from consideration
        >>> hash2 = model.static_hash(["attn_mechanism"])
    """
    from ejkernel.callib._ejit import _get_args_signature

    dict_config = self.config.to_dict()
    for excluded_key in pop_things or ():
        # Excluding a key that is absent is a no-op rather than a KeyError.
        dict_config.pop(excluded_key, None)

    tree_hash = _get_args_signature((self.graphstate, self.graphother), dict_config)
    digest = hashlib.md5(tree_hash.encode("utf-8")).digest()
    return int.from_bytes(digest, byteorder="big", signed=True)