# Source code for easydel.infra.base_module

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
import typing as tp
import warnings
from functools import cached_property, partial

import chex
import flax
import flax.struct
import jax
import jax.extend
import jax.tree_util
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from easydel.utils import traversals
from easydel.utils.helpers import get_logger
from easydel.utils.traversals import flatten_dict, is_flatten, unflatten_dict

from .base_config import EasyDeLBaseConfig
from .etils import EasyDeLGradientCheckPointers, EasyDeLQuantizationMethods
from .loss_utils import (
	LOSS_MAPPING,
	ForCausalLMLoss,
	ForSequenceClassificationLoss,
	LossConfig,
	LossMetrics,
)
from .mixins import (
	BaseModuleProtocol,
	EasyBridgeMixin,
	EasyGenerationMixin,
)

if not tp.TYPE_CHECKING:
	# At runtime the state class is just `tp.Any`; the real class is only imported for
	# static type checkers (presumably to avoid a circular import with `base_state` —
	# TODO confirm).
	EasyDeLState = tp.Any
else:
	from easydel.infra.base_state import EasyDeLState

# Optional mapping whose keys are either strings or tuples and whose values are
# callables. Used below as the type of `partition_rules` arguments.
PartitionLike = tp.Optional[
	tp.Union[tp.Mapping[str, tp.Callable], tp.Mapping[tuple, tp.Callable]]
]


# Module-level logger created through EasyDeL's `get_logger` helper and named after this module.
logger = get_logger(__name__)

# Type variable for a model-specific config type accepted alongside `EasyDeLBaseConfig`.
# PEP 484 requires the TypeVar's name string to equal the variable it is assigned to;
# the previous `tp.TypeVar("CP")` is rejected by type checkers.
_CP = tp.TypeVar("_CP")
# Type variable used to annotate `self` in methods that return the (possibly new) module.
SELF = tp.TypeVar("SELF")


class EasyDeLBaseModule(
	nn.Module,
	BaseModuleProtocol,
	EasyBridgeMixin,
	EasyGenerationMixin,
):
	"""
	Base class for EasyDeL modules, providing common functionalities for model
	initialization, parameter handling, and integration with the EasyDeL ecosystem.
	"""

	config_class: tp.Type[EasyDeLBaseConfig]
	base_model_prefix: str
	_model_task: tp.Optional[str] = None
	_model_type: tp.Optional[str] = None

	def __init__(
		self,
		config: tp.Union[EasyDeLBaseConfig, _CP],
		dtype: jnp.dtype,
		param_dtype: jnp.dtype,
		precision: lax.PrecisionLike,
		rngs: nn.Rngs,
	):
		"""Initializes the EasyDeLBaseModule.

		Args:
			config (EasyDeLBaseConfig): The model configuration.
			dtype (jnp.dtype): The data type for computation.
			param_dtype (jnp.dtype): The data type for parameters.
			precision (jax.lax.PrecisionLike): The numerical precision.
			rngs (nn.Rngs): The random number generators.
		"""
		self.config: tp.Union[EasyDeLBaseConfig, _CP] = config
		self.dtype: jnp.dtype = dtype
		self.param_dtype: jnp.dtype = param_dtype
		self.precision: lax.PrecisionLike = precision
		self.rngs: nn.Rngs = rngs
		# The results are discarded on purpose: per the original author these property
		# accesses only exist to initialize the corresponding values in the graphdef.
		_ = self.graphtree_shape
		_ = self.graphtree_params_shape
		_ = self.mesh
		_ = self.model_task
		_ = self.model_type

	@property
	def parameters(self) -> tp.Dict:
		"""
		Collects every `nn.Param` found by `iter_module_search` into a flat dictionary.

		Returns:
			tp.Dict: Mapping from the path yielded by `iter_module_search` to the
				parameter's `.value`.
		"""
		from easydel.utils.graph_utils import iter_module_search

		parameters = {}
		for key, value in iter_module_search(self, nn.Param):
			parameters[key] = value.value
		return parameters

	@property
	def graphdef(self) -> nn.GraphDef:
		"""
		Returns the graph definition of the module.

		This is the first element of `nn.split(self, nn.Param, ...)`.

		Returns:
			nn.GraphDef: The graph definition of the module.
		"""
		return nn.split(self, nn.Param, ...)[0]

	@property
	def graphstate(self) -> nn.GraphState:
		"""
		Returns the `nn.Param` state of the module.

		This is the second element of `nn.split(self, nn.Param, ...)`.

		Returns:
			nn.GraphState: The parameter state.
		"""
		return nn.split(self, nn.Param, ...)[1]

	@property
	def graphother(self) -> nn.GraphState:
		"""
		Returns every non-`nn.Param` state variable of the module.

		This is the last element of `nn.split(self, nn.Param, ...)`.

		Returns:
			nn.GraphState: The state that is not `nn.Param`.
		"""
		return nn.split(self, nn.Param, ...)[-1]

	@property
	def graphtree_params_shape(self) -> tp.Dict:
		"""
		Abstract (shape/dtype-only) view of the parameter state as a nested dict.

		The parameter state is produced under `nn.eval_shape`, flattened, reduced to each
		entry's `.value` and unflattened again.

		Returns:
			tp.Dict: Nested dict mirroring the parameter structure, holding the abstract
				values produced by `nn.eval_shape`.
		"""
		graphtree = nn.eval_shape(lambda: nn.split(self, nn.Param, ...)[1])
		flattened_tree = flatten_dict(graphtree)
		param_shapes = {key: val.value for key, val in flattened_tree.items()}
		return unflatten_dict(param_shapes)

	@property
	def graphtree_shape(self) -> tp.Dict:
		"""
		Abstract (shape/dtype-only) view of the full module state as a nested dict.

		Same as `graphtree_params_shape`, but uses the whole state from `nn.split(self)`,
		not only `nn.Param` variables.

		Returns:
			tp.Dict: Nested dict mirroring the module state, holding abstract values.
		"""
		graphtree = nn.eval_shape(lambda: nn.split(self)[1])
		flattened_tree = flatten_dict(graphtree)
		param_shapes = {key: val.value for key, val in flattened_tree.items()}
		return unflatten_dict(param_shapes)

	@property
	def mesh(self) -> jax.sharding.Mesh:
		"""
		Returns the device mesh stored on the configuration.

		Returns:
			jax.sharding.Mesh: `self.config.mesh`.
		"""
		return self.config.mesh

	@property
	def model_task(self) -> tp.Optional[str]:
		"""
		Returns the task identifier stored in `_model_task`.

		Returns:
			tp.Optional[str]: The model task identifier, or None if not set.
		"""
		return self._model_task

	@property
	def model_type(self) -> tp.Optional[str]:
		"""
		Returns the model type identifier stored in `_model_type`.

		Returns:
			tp.Optional[str]: The model type identifier, or None if not set.
		"""
		return self._model_type

	@property
	def params(self) -> tp.Dict:
		"""
		Returns the module's complete state (parameters and all other variables).

		Returns:
			tp.Dict: The last element of `nn.split(self)`.
		"""
		return nn.split(self)[-1]

	@cached_property
	def causal_mask(self) -> jnp.ndarray:
		"""
		Causal attention mask from `self.config.get_basic_causal_mask()`, cached per instance.

		Returns:
			jnp.ndarray: The causal attention mask.
		"""
		return self.config.get_basic_causal_mask()

	@cached_property
	def frequencies(self) -> jnp.ndarray:
		"""
		Frequencies from `self.config.get_basic_frequencies()`, cached per instance.

		Returns:
			jnp.ndarray: The frequency components (e.g. for RoPE).
		"""
		return self.config.get_basic_frequencies()

	@cached_property
	def inv_frequencies(self) -> jnp.ndarray:
		"""
		Inverse frequencies from `self.config.get_basic_inv_frequencies()`, cached per instance.

		Returns:
			jnp.ndarray: The inverse-frequency components (e.g. for RoPE).
		"""
		return self.config.get_basic_inv_frequencies()

	@cached_property
	def static_arguments(self) -> tp.Tuple:
		"""
		Static `__call__` arguments from `self.get_static_arguments()`, cached per instance.

		Returns:
			tp.Tuple: A tuple of static arguments.
		"""
		return self.get_static_arguments()

	@cached_property
	def loss_function(self):
		"""
		Selects the loss function from `LOSS_MAPPING`.

		Resolution order:
		1. `config.loss_type` if set;
		2. otherwise `self.loss_type` if set;
		3. otherwise the class name.

		If the resolved name is not a `LOSS_MAPPING` key, the first key occurring in the
		class name is used. If that also fails, a warning is emitted and
		`LOSS_MAPPING["ForCausalLM"]` is returned.

		Returns:
			tp.Callable: The selected loss function (e.g. `ForCausalLMLoss`).
		"""
		config_loss_type = getattr(self.config, "loss_type", None)
		module_loss_type = getattr(self, "loss_type", None)
		if config_loss_type is not None:
			requested = config_loss_type
		elif module_loss_type is not None:
			requested = module_loss_type
		else:
			requested = self.__class__.__name__

		loss_type = requested
		if loss_type not in LOSS_MAPPING:
			# Fall back to the first LOSS_MAPPING key that appears in the class name.
			# Keys are escaped so they are matched literally.
			loss_groups = f"({'|'.join(map(re.escape, LOSS_MAPPING))})"
			matches = re.findall(loss_groups, self.__class__.__name__)
			loss_type = matches[0] if matches else None

		if loss_type is None or loss_type not in LOSS_MAPPING:
			# Report the value the user asked for, not the overwritten `None`.
			if config_loss_type is not None:
				reason = f"`loss_type={requested}` was set in the config but it is unrecognised."
			elif module_loss_type is not None:
				reason = f"`loss_type={requested}` was set on the model but it is unrecognised."
			else:
				reason = f"Could not infer a loss type from `{requested}`."
			warnings.warn(
				f"{reason} Using the default loss: `ForCausalLMLoss`.",
				stacklevel=1,
			)
			loss_type = "ForCausalLM"
		return LOSS_MAPPING[loss_type]

	@property
	def module_dtype(self) -> jnp.dtype:
		"""
		Returns the dtype of the first leaf of the flattened `nn.Param` state.

		Returns:
			jnp.dtype: The data type of that parameter. Raises IndexError if the module
				has no parameter leaves.
		"""
		params_state = nn.split(self, nn.Param, ...)[1].flat_state()
		return jax.tree_util.tree_leaves(params_state)[0].dtype
[docs] def compute_complex_rotary(self, position_ids: jax.Array) -> jnp.ndarray: frequencies = jnp.transpose( self.inv_frequencies[None, :, None] @ position_ids[:, None, :].astype("f4"), (0, 2, 1), ) return jnp.exp(1j * frequencies)
[docs] def to_dtype(self: SELF, dtype: jnp.dtype) -> SELF: """ Converts the module's parameters to the specified data type. It iterates through the module's parameters (excluding quantization-related ones) and casts them to the target `dtype`. It also updates the `param_dtype` attribute of the module and its submodules if they exist. Args: dtype (jnp.dtype): The target data type for the parameters. Returns: SELF: The module instance with parameters converted to the specified dtype. """ from easydel.utils.graph_utils import iter_module_search gdef, state, others = nn.split(self, nn.Param, ...) def _map(path, val: nn.VariableState): if val.value is not None: if not path[-1].startswith("quant_"): val.value = val.value.astype(dtype) return val state.update(state.map(_map)) self = nn.merge(gdef, state, others) for path, module in iter_module_search(self): if hasattr(module, "param_dtype"): module.param_dtype = dtype return self
[docs] def half(self: SELF, change_runtime_dtype: bool = True) -> SELF: """ Converts the module's parameters to half-precision (float16). Optionally also changes the runtime computation dtype (`self.dtype`) to float16. Args: change_runtime_dtype (bool): If True, also sets `self.dtype` to `jnp.float16`. Defaults to True. Returns: SELF: The module instance with parameters (and potentially runtime dtype) set to float16. """ if change_runtime_dtype: self = self._reformat_runtime_dtype(jnp.float16) return self._reformat_dtype(jnp.float16)
[docs] def float(self: SELF, change_runtime_dtype: bool = True) -> SELF: """ Converts the module's parameters to single-precision (float32). Optionally also changes the runtime computation dtype (`self.dtype`) to float32. Args: change_runtime_dtype (bool): If True, also sets `self.dtype` to `jnp.float32`. Defaults to True. Returns: SELF: The module instance with parameters (and potentially runtime dtype) set to float32. """ if change_runtime_dtype: self = self._reformat_runtime_dtype(jnp.float32) return self._reformat_dtype(jnp.float32)
def _reformat_runtime_dtype(self: SELF, dtype) -> SELF: """ Internal helper to change the runtime computation data type (`dtype`) of the module and its submodules. Args: dtype (jnp.dtype): The target runtime data type. Returns: SELF: The module instance with updated runtime dtype. """ from easydel.utils.graph_utils import iter_module_search for path, module in iter_module_search(self): if hasattr(module, "dtype"): if str(type(module.dtype)).endswith( "lax_numpy._ScalarMeta'>" ): # dont change numpy based dtypes module.dtype = dtype self.dtype = dtype return self def _reformat_dtype(self: SELF, dtype) -> SELF: """ Internal helper to change the data type of the module's parameters (`param_dtype`). Casts floating-point parameters to the target `dtype`. Args: dtype (jnp.dtype): The target parameter data type. Returns: SELF: The module instance with updated parameter dtype. """ from easydel.utils.graph_utils import iter_module_search gdef, gtree, others = nn.split(self, nn.Param, ...) def _map(array): if array.dtype in [ jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, jnp.float_, ]: array = array.astype(dtype) return array gtree = jax.tree_util.tree_map(_map, gtree) self = nn.merge(gdef, gtree, others) for path, module in iter_module_search(self): if hasattr(module, "param_dtype"): if isinstance(module.param_dtype, jnp.dtype): module.param_dtype = dtype self.param_dtype = dtype return self def _match_partition_rules(self, partition_rules: tp.Any = None): """ Matches the provided or configured partition rules against the module's parameter shapes. Args: partition_rules (tp.Any, optional): The partition rules to use. If None, uses rules from the configuration. Defaults to None. Returns: tp.Any: The partition specifications matched to the parameter tree. 
""" return match_partition_rules( rules=self._get_partition_rules(partition_rules), tree=self.graphtree_params_shape, ) @property def _specs_sharding(self): """ Extracts the PartitionSpec part from the NamedSharding of each parameter. Returns: tp.Dict: A nested dictionary mirroring the parameter structure, containing PartitionSpecs. """ def _map(array): if hasattr(array, "sharding"): sharding = array.sharding if isinstance(sharding, NamedSharding): return sharding.spec return PartitionSpec() return nn.from_tree( jax.tree_util.tree_map( _map, nn.to_tree(self), ) ) @property def _shardings(self): """ Extracts the sharding information (PartitionSpec or NamedSharding) for each parameter. Returns: tp.Dict: A nested dictionary mirroring the parameter structure, containing the sharding info. """ return nn.from_tree( jax.tree_util.tree_map( lambda x: x.sharding if hasattr(x, "sharding") else PartitionSpec(), nn.to_tree(self), ) ) @property def _named_shardings(self): """ Extracts the NamedSharding object (if present) for each parameter. Returns: tp.Dict: A nested dictionary mirroring the parameter structure, containing NamedSharding or None. """ return nn.from_tree( jax.tree_util.tree_map( lambda x: x.sharding if hasattr(x, "sharding") else None, nn.to_tree(self), ) ) def _get_mesh(self, mesh: tp.Optional[Mesh] = None) -> Mesh: """ Retrieves the JAX device mesh, prioritizing the provided argument over the configuration. Args: mesh (tp.Optional[Mesh]): A potential JAX device mesh. Returns: Mesh: The resolved JAX device mesh. Raises: ValueError: If no mesh is provided and none is found in the configuration. """ if mesh is None: if ( not hasattr(self, "config") or not hasattr(self.config, "mesh") or self.config.mesh is None ): raise ValueError( "A mesh must be provided, either as an argument or through the model config." 
) return self.config.mesh return mesh def _get_partition_rules(self, partition_rules: PartitionLike) -> PartitionLike: """ Retrieves the partitioning rules, prioritizing the provided argument over the configuration. Args: partition_rules (PartitionLike): Potential partitioning rules. Returns: PartitionLike: The resolved partitioning rules. Raises: ValueError: If no rules are provided and none can be obtained from the configuration. """ if partition_rules is None: if not hasattr(self, "config"): raise ValueError( "Partition rules must be provided either as an argument or through the model config." ) return self.config.get_partition_rules(fully_sharded_data_parallel=True) return partition_rules def _apply_sharding_fns( self: SELF, sharding_fns: tp.Mapping[str, tp.Callable], ) -> SELF: """ Applies sharding or gathering functions to the module's parameters. Args: sharding_fns (tp.Mapping[str, tp.Callable]): A mapping from flattened parameter paths to sharding or gathering functions. Returns: SELF: The module instance with sharding/gathering functions applied to its parameters. """ gdef, state, others = nn.split(self, nn.Param, ...) sharding_fns = flatten_dict(sharding_fns) _shard_keys = list(sharding_fns.keys()) def _map(path, val: nn.VariableState): if val.value is not None and path in _shard_keys: try: val.value = sharding_fns[path](val.value) except TypeError: path = map(str, path) warnings.warn(f"couldn't shard/gather {'.'.join(path)}", stacklevel=1) return val state.update(state.map(_map)) self = nn.merge(gdef, state, others) return self
[docs] def shard_model( self: SELF, partition_rules: PartitionLike = None, mesh: tp.Optional[Mesh] = None, overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None, ) -> SELF: """ Shards the model's parameters according to the specified rules and mesh. Args: partition_rules (PartitionLike, optional): Partitioning rules. If None, uses config rules. Defaults to None. mesh (tp.Optional[Mesh], optional): JAX device mesh. If None, uses config mesh. Defaults to None. overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional): Additional functions to apply, potentially overriding default sharding for specific parameters. Defaults to None. Returns: SELF: The sharded model instance. """ mesh = self._get_mesh(mesh) partition_rules = self._get_partition_rules(partition_rules) partition_specs = match_partition_rules( rules=partition_rules, tree=self.graphtree_params_shape, ) shard_fns, _ = make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, ) if overlay_fns is not None: shard_fns.update(overlay_fns) self = self._apply_sharding_fns(shard_fns) return self
[docs] def gather_model( self: SELF, partition_rules: PartitionLike = None, mesh: tp.Optional[Mesh] = None, overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None, ) -> SELF: """ Gathers the model's parameters from potentially distributed devices to the host or a single device. Args: partition_rules (PartitionLike, optional): Partitioning rules used to determine how parameters were originally sharded. If None, uses config rules. Defaults to None. mesh (tp.Optional[Mesh], optional): JAX device mesh from which to gather. If None, uses config mesh. Defaults to None. overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional): Additional functions to apply, potentially overriding default gathering for specific parameters. Defaults to None. Returns: SELF: The model instance with gathered parameters. """ mesh = self._get_mesh(mesh) partition_rules = self._get_partition_rules(partition_rules) partition_specs = match_partition_rules( rules=partition_rules, tree=self.graphtree_params_shape, ) _, gather_fns = make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, ) if overlay_fns is not None: gather_fns.update(overlay_fns) return self._apply_sharding_fns(gather_fns)
@property def _shard_fns(self): """ Generates the dictionary of sharding functions based on the module's configuration. Returns: tp.Mapping: A mapping from flattened parameter paths to sharding functions. """ mesh = self._get_mesh(None) partition_specs = match_partition_rules( rules=self._get_partition_rules(None), tree=self.graphtree_params_shape, ) return make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, )[0] @property def _gather_fns(self): """ Generates the dictionary of gathering functions based on the module's configuration. Returns: tp.Mapping: A mapping from flattened parameter paths to gathering functions. """ mesh = self._get_mesh(None) partition_specs = match_partition_rules( rules=self._get_partition_rules(None), tree=self.graphtree_params_shape, ) return make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, )[1]
[docs] def fully_shard(self: SELF, partition_rules: PartitionLike = None) -> SELF: """ Applies JAX sharding constraints to all parameters based on the partition rules. This function ensures that parameters are explicitly marked with their intended sharding, which can be useful for performance and correctness checks. It uses `jax.jit` with `out_shardings` to enforce the constraints. Args: partition_rules (PartitionLike, optional): Partitioning rules. If None, uses config rules. Defaults to None. Returns: SELF: The model instance with sharding constraints applied. """ class ShardState(flax.struct.PyTreeNode): graphdef: nn.GraphDef graphstate: nn.GraphState gdef, gstate = nn.split(self) mock = ShardState(graphdef=gdef, graphstate=gstate) shardings = jax.tree_util.tree_map( lambda x: NamedSharding(mesh=self.mesh, spec=x), match_partition_rules( self._get_partition_rules(partition_rules), nn.eval_shape(lambda: mock) ), ) @partial(jax.jit, out_shardings=shardings) def _call(cl): return cl mock = _call(mock) self = nn.merge(mock.graphdef, mock.graphstate) return self
[docs] def fully_gather(self: SELF) -> SELF: """ Applies JAX sharding constraints to gather all parameters onto the host or a single device. This function marks all parameters to have no sharding (PartitionSpec()). It uses `jax.jit` with `out_shardings` to enforce these gathering constraints. Returns: SELF: The model instance with gathering constraints applied. """ class ShardState(flax.struct.PyTreeNode): graphdef: nn.GraphDef graphstate: nn.GraphState gdef, gstate = nn.split(self) mock = ShardState(graphdef=gdef, graphstate=gstate) shardings = jax.tree_util.tree_map( lambda x: NamedSharding(mesh=self.mesh, spec=PartitionSpec()), match_partition_rules( self._get_partition_rules(None), nn.eval_shape(lambda: mock) ), ) @partial(jax.jit, out_shardings=shardings) def _call(cl): return cl mock = _call(mock) self = nn.merge(mock.graphdef, mock.graphstate) return self
[docs] def quantize( self: SELF, method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: tp.Optional[str] = None, quantize_tensors: bool = True, verbose: tp.Optional[bool] = None, ) -> SELF: """ Applies quantization to the module's linear layers or tensors. Args: method (EasyDeLQuantizationMethods, optional): The quantization algorithm to use (e.g., A8BIT, NF4). Defaults to EasyDeLQuantizationMethods.A8BIT. block_size (int, optional): The block size for quantization methods that support it. Defaults to 128. quantization_pattern (tp.Optional[str], optional): A regular expression to match parameter names that should be quantized. If None, uses a default pattern. Defaults to None. quantize_tensors (bool, optional): If True, quantizes the tensor values directly. If False (currently default behavior in implementation), replaces Linear layers with their quantized equivalents. Defaults to True (though implementation differs). verbose (tp.Optional[bool], optional): If True, logs information during the quantization process. Defaults to True only on process index 0. Returns: SELF: The quantized model instance. """ from easydel.layers.quantization.quantizers import EasyQuantizer quantizer = EasyQuantizer( quantization_method=method, block_size=block_size, quantization_pattern=quantization_pattern, ) if verbose is None: verbose = jax.process_index() == 0 if quantize_tensors: ... else: self = quantizer.quantize_linears( self, quantization_pattern=quantization_pattern, verbose=verbose, ) return self
[docs] def to_state(self) -> EasyDeLState: """ Converts the current module instance into an `EasyDeLState` object. This is useful for saving and managing the model's state, including parameters and potentially optimizer state (though optimizer state is typically added later). Returns: EasyDeLState: An EasyDeLState object representing the current model state. """ from easydel.infra.base_state import EasyDeLState return EasyDeLState.create(step=0, model=self)
[docs] def to_torch(self, **kwargs): """ Converts the EasyDeL module to its equivalent Hugging Face PyTorch model. Requires the corresponding PyTorch model class to be available and registered. Uses utility functions to transfer parameters from JAX to PyTorch format. Args: **kwargs: Additional keyword arguments passed to the parameter transformation function. Returns: torch.nn.Module: The equivalent Hugging Face PyTorch model with loaded weights. """ from easydel.utils.parameters_transformation import module_to_huggingface_model hf_autoloader = self.get_torch_loader() model_class = hf_autoloader._model_mapping[type(self.config)] hf_model = module_to_huggingface_model( module=self, base_huggingface_module=model_class, config=self.config, dtype=self.param_dtype, **kwargs, ) return hf_model
[docs] def prepare_inputs_for_call(self, **kwargs): """ Prepares keyword arguments before passing them to the module's `__call__` method. This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation). Args: **kwargs: The keyword arguments intended for `__call__`. Returns: dict: The prepared keyword arguments. """ return kwargs
[docs] def get_static_arguments(self) -> tp.Tuple: """ Returns a tuple of static arguments required by the module's `__call__` method. Static arguments are those that don't change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments. Returns: tp.Tuple: A tuple containing static arguments. """ return ()
[docs] @classmethod def lazy_init(cls: tp.Type[SELF], *args, **kwargs) -> SELF: """ Performs a "lazy" initialization using `nnx.eval_shape`. This initializes the module structure and determines parameter shapes without actually allocating memory for the parameters. Useful for inspecting the model structure or preparing for sharding. Args: *args: Positional arguments passed to the class constructor. **kwargs: Keyword arguments passed to the class constructor. Returns: SELF: A module instance with initialized structure but potentially abstract parameters. """ return nn.eval_shape(lambda: cls(*args, **kwargs))
[docs] def merge_lora_params(self: SELF, pytree: tp.Dict) -> SELF: """ Merges LoRA parameters from a pytree into the base model's parameters. Args: pytree (tp.Dict): A dictionary (pytree) containing the LoRA parameters (A and B matrices) structured similarly to the base model's parameters. Returns: SELF: The module instance with LoRA parameters merged into the base weights. """ from easydel.infra.utils import merge_lora_params self = merge_lora_params(self, pytree) return self
[docs] def split_lora_params(self: SELF) -> tp.Dict: """ Splits merged LoRA parameters back out from the base model's parameters. This function assumes LoRA parameters were previously merged using `merge_lora_params` or a similar process that stored the original base weights and LoRA weights appropriately. Returns: tp.Dict: A pytree containing the extracted LoRA parameters (A and B matrices). The base model parameters are restored to their original (pre-merge) state. """ from easydel.infra.utils import split_lora_params pytree = split_lora_params(self) return pytree
[docs] def apply_lora_to_layers( self: SELF, lora_rank: int, lora_pattern: tp.Optional[str] = None, verbose: bool = False, rngs: tp.Optional[nn.Rngs] = None, ) -> SELF: """ Applies Low-Rank Adaptation (LoRA) layers to the specified linear layers within the module. Replaces targeted `flax.linen.Dense` layers with `easydel.layers.lora.LoraLinear` layers, initializing the LoRA matrices (A and B). Args: lora_rank (int): The rank of the LoRA decomposition. lora_pattern (tp.Optional[str], optional): A regular expression to match the names of the `Dense` layers to apply LoRA to. If None, applies to common attention and MLP layers. Defaults to None. verbose (bool, optional): If True, prints information about which layers are being modified. Defaults to False. rngs (tp.Optional[nn.Rngs], optional): JAX random number generators for initializing LoRA matrices. If None, default RNGs might be used. Defaults to None. Returns: SELF: The module instance with LoRA layers applied. """ from easydel.infra.utils import apply_lora_to_layers self = apply_lora_to_layers( self, lora_pattern=lora_pattern, lora_rank=lora_rank, rngs=rngs, verbose=verbose, ) return self
[docs] def unwrap_lora_to_layers(self: SELF, verbose: bool = False) -> SELF: """ Reverts the application of LoRA layers, restoring the original linear layers. Replaces `easydel.layers.lora.LoraLinear` layers with their original `flax.linen.Dense` counterparts, discarding the LoRA matrices. Args: verbose (bool, optional): If True, prints information about which layers are being reverted. Defaults to False. Returns: SELF: The module instance with LoRA layers removed and original layers restored. """ from easydel.infra.utils import unwrap_lora_to_layers self = unwrap_lora_to_layers(self, verbose=verbose) return self
@property
def transform_fn(self):
	"""
	Build a PyTorch-state-dict converter pre-bound to this module.

	Collects the dotted paths of every `nn.Embed` and `nn.LayerNorm` submodule
	(found with `graph_utils.iter_module_search`). It then binds them, together
	with `self.param_dtype` and `self._shard_fns`, into
	`torch_dict_to_easydel_params` via `functools.partial`.

	Returns:
		tp.Callable: The partially applied converter; presumably called with
		a PyTorch state dict (see `torch_dict_to_easydel_params` for the exact
		remaining arguments).
	"""
	from easydel.utils import graph_utils
	from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

	def _dotted_paths(module_type):
		# "a.b.0.c"-style names for every submodule of the given type.
		return [
			".".join(map(str, path))
			for path, _ in graph_utils.iter_module_search(self, module_type)
		]

	return partial(
		torch_dict_to_easydel_params,
		embedding_layer_names=_dotted_paths(nn.Embed),
		layernorm_names=_dotted_paths(nn.LayerNorm),
		dtype=self.param_dtype,
		shard_fns=self._shard_fns,
	)


def _generation_ready_clone(self):
	"""
	Lazily construct a fresh instance of this class for generation.

	The config is deep-copied (so `self.config` is untouched), its
	`gradient_checkpointing` is set to `EasyDeLGradientCheckPointers.NONE`, and
	`type(self).lazy_init` is called with this module's dtype, param_dtype,
	precision and rngs.
	"""
	from copy import deepcopy

	generation_config = deepcopy(self.config)
	generation_config.gradient_checkpointing = EasyDeLGradientCheckPointers.NONE
	return type(self).lazy_init(
		config=generation_config,
		dtype=self.dtype,
		param_dtype=self.param_dtype,
		precision=self.precision,
		rngs=self.rngs,
	)


@property
def _generate_compatible_graphdef(self):
	"""
	Graph definition of a generation-ready clone of this module.

	Builds a lazily initialized clone with gradient checkpointing disabled and
	returns the graph definition from `nn.split(clone, nn.Param, ...)`.

	Returns:
		nn.GraphDef: Structure-only description of the clone.
	"""
	graph_def, _params, _others = nn.split(self._generation_ready_clone(), nn.Param, ...)
	return graph_def


@property
def _generate_compatible_graphother(self):
	"""
	Non-parameter state of a generation-ready clone of this module.

	Builds a lazily initialized clone with gradient checkpointing disabled,
	takes the non-`nn.Param` part of `nn.split(clone, nn.Param, ...)`, and
	passes it through `traversals.recreate_meta_values`. Going by its name,
	that helper presumably replaces lazy/meta placeholders with concrete
	values -- confirm in `easydel.utils.traversals`.

	Returns:
		nn.GraphState: The processed non-parameter state.
	"""
	_graph_def, _params, other_state = nn.split(self._generation_ready_clone(), nn.Param, ...)
	return traversals.recreate_meta_values(other_state)


@property
def params_sharding(self) -> tp.Dict:
	"""
	Report the sharding of every parameter leaf.

	Maps over `self.split_params_dict()` with `jax.tree_util.tree_map`. Leaves
	that expose a `sharding` attribute (e.g. JAX arrays) report it; all other
	leaves map to None.

	Returns:
		tp.Dict: Nested dict shaped like `split_params_dict()` holding each
		leaf's sharding or None.
	"""

	def _sharding_of(leaf):
		return getattr(leaf, "sharding", None)

	return jax.tree_util.tree_map(_sharding_of, self.split_params_dict())
def merge_params(self, tree):
	"""
	Rebuild the module with `tree` as its parameter state.

	The module is split into (graph definition, `nn.Param` state, remaining
	state). The current parameter state is discarded, and the module is
	re-assembled from the graph definition, `tree`, and the remaining state.

	Args:
		tree: Parameter state to merge in. Presumably an `nn.GraphState` (or a
			compatible pytree) shaped like `self.split_params()` -- TODO
			confirm which inputs `nn.merge` accepts for the installed flax.

	Returns:
		EasyDeLBaseModule: The module produced by `nn.merge`. Callers should
		use this return value; the rebinding happens on a local name only.
	"""
	graph_def, _old_params, non_param_state = nn.split(self, nn.Param, ...)
	return nn.merge(graph_def, tree, non_param_state)
def split_params(self):
	"""
	Return only the parameter (`nn.Param`) state of this module.

	Splits the module with the filters `(nn.Param, ...)` and keeps the middle
	element. The graph definition and the remaining non-parameter state are
	discarded.

	Returns:
		nn.GraphState: The `nn.Param` partition of the module's state.
	"""
	_graph_def, param_state, _remaining = nn.split(self, nn.Param, ...)
	return param_state
def split_params_dict(
	self,
	extract_fn: tp.Optional[tp.Callable] = None,
	remove_none: bool = True,
) -> tp.Dict:
	"""
	Return the module's parameters as a plain nested dictionary.

	Processing steps:
	  1. Convert the parameter state from `split_params()` with
	     `to_pure_dict(extract_fn=extract_fn)` and flatten it.
	  2. Unwrap any leaf that still exposes a `.value` attribute to that value.
	  3. Optionally drop leaves whose unwrapped value is None.
	  4. Unflatten the result back into a nested dict.

	Args:
		extract_fn (tp.Optional[tp.Callable]): Forwarded unchanged to
			`to_pure_dict`. Defaults to None.
		remove_none (bool): Drop leaves whose unwrapped value is None.
			Defaults to True.

	Returns:
		tp.Dict: Nested dict of parameter values keyed like the module tree.
	"""

	def _unwrap(leaf):
		# Defensive: strip any remaining variable wrapper exposing `.value`.
		return leaf.value if hasattr(leaf, "value") else leaf

	pure_tree = self.split_params().to_pure_dict(extract_fn=extract_fn)
	leaves = {path: _unwrap(leaf) for path, leaf in flatten_dict(pure_tree).items()}
	if remove_none:
		leaves = {path: leaf for path, leaf in leaves.items() if leaf is not None}
	return unflatten_dict(leaves)
def merge_params_dict(self: SELF, params_dict: tp.Dict) -> SELF:
	"""
	Overwrite selected parameters with values from a dictionary.

	The current parameter state is flattened with `flat_state()`.
	`params_dict` is flattened too unless `is_flatten` reports it already
	is. Each listed entry's `.value` is replaced, and the updated state is
	merged back with `merge_params`.

	NOTE(review): this assumes `flat_state()` returns a mapping from path
	tuples to variable states that supports `in` and `[...]` by path; that
	depends on the installed flax version -- confirm.

	Args:
		params_dict (tp.Dict): Nested or already-flat dict of new parameter
			values, keyed by the module's parameter paths.

	Returns:
		SELF: The module returned by `merge_params` with the updates applied.

	Raises:
		KeyError: If any key of `params_dict` is not an existing parameter
			path.
	"""
	param_state = self.split_params().flat_state()
	flat_updates = params_dict if is_flatten(params_dict) else flatten_dict(params_dict)

	for key, new_value in flat_updates.items():
		if key not in param_state:
			raise KeyError(f"Parameter key {key} not found in the current model state.")
		param_state[key].value = new_value

	return self.merge_params(unflatten_dict(param_state))
def _flop(self, *args, **kwargs) -> tp.Optional[float]:
	"""
	Estimate the FLOP count of one forward pass (`__call__`).

	Traces `self.__call__` on the given arguments with `jax.make_jaxpr` and
	passes the resulting jaxpr to `count_flop_jaxpr` from the sibling `utils`
	module.

	Args:
		*args: Positional arguments forwarded to `__call__` while tracing.
		**kwargs: Keyword arguments forwarded to `__call__` while tracing.

	Returns:
		tp.Optional[float]: Whatever `count_flop_jaxpr` reports. No exception
		handling happens here, so tracing or counting errors propagate to the
		caller instead of producing None.
	"""
	from .utils import count_flop_jaxpr

	tracer = jax.make_jaxpr(self.__call__)
	closed_jaxpr = tracer(*args, **kwargs)
	return count_flop_jaxpr(closed_jaxpr)


@property
def pure_transform_fn(self):
	"""
	Build a PyTorch-state-dict converter without sharding functions.

	Like `transform_fn`, it binds the dotted paths of every `nn.Embed` and
	`nn.LayerNorm` submodule, plus `self.param_dtype`, into
	`torch_dict_to_easydel_params`. No `shard_fns` are bound.

	Returns:
		tp.Callable: The partially applied converter.
	"""
	from easydel.utils import graph_utils
	from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

	def _dotted_paths(module_type):
		# "a.b.0.c"-style names for every submodule of the given type.
		return [
			".".join(map(str, path))
			for path, _ in graph_utils.iter_module_search(self, module_type)
		]

	return partial(
		torch_dict_to_easydel_params,
		embedding_layer_names=_dotted_paths(nn.Embed),
		layernorm_names=_dotted_paths(nn.LayerNorm),
		dtype=self.param_dtype,
	)


@property
def _default_loss_config(self) -> tp.Optional[LossConfig]:
	"""
	Default `LossConfig` for this module; the base implementation has none.

	Subclasses may override this property to supply a task-specific default.

	Returns:
		tp.Optional[LossConfig]: Always None in this base class.
	"""
	return None


@_default_loss_config.setter
def _default_loss_config(self, val):
	"""
	Accept and discard assignments to `_default_loss_config`.

	The value is only returned, which Python ignores for attribute
	assignment. Nothing is stored, so the getter keeps returning its own
	result.
	"""
	return val
def compute_loss(
	self,
	*,
	labels: tp.Optional[chex.Array] = None,
	loss_config: tp.Optional[LossConfig] = None,
	loss_kwargs: tp.Optional[tp.Dict] = None,
	**batch,
) -> tp.Tuple[tp.Any, LossMetrics]:
	"""
	Run a forward pass on `batch` and score it with `self.loss_function`.

	Flow:
	  1. If `labels` is None and the active loss function is the causal-LM
	     loss (matched by `__name__`), `batch["input_ids"]` is used as the
	     labels (None if the batch has no `input_ids`).
	  2. If the active loss function is the sequence-classification loss and
	     no `loss_config` was supplied, a `LossConfig` is built from
	     `self.config.num_labels`.
	  3. Any `return_dict` entry is dropped from `batch`, and the model is
	     called as `self(**batch, return_dict=True)`.
	  4. The loss function is called with `labels`, `config`, and `paxis`
	     (`self.config.partition_axis`). Every entry of `loss_kwargs`, of the
	     model outputs, and of `batch` is also passed as a keyword argument.
	  5. A non-None `outputs.aux_loss` is added to the loss. The final loss
	     is stored on the outputs via `outputs.replace(loss=...)`.

	Args:
		labels (tp.Optional[chex.Array]): Target labels; see step 1 for the
			causal-LM fallback. Defaults to None.
		loss_config (tp.Optional[LossConfig]): Passed to the loss function as
			`config`; see step 2 for the sequence-classification default.
			Defaults to None.
		loss_kwargs (tp.Optional[tp.Dict]): Extra keyword arguments for the
			loss function. Defaults to None, which is treated as `{}`.
		**batch: Model inputs (e.g. `input_ids`, `attention_mask`), forwarded
			both to the model and to the loss function.

	Returns:
		tp.Tuple[tp.Any, LossMetrics]: The model outputs (with `loss` set to
		the final loss) and the `LossMetrics` produced by the loss function.

	Raises:
		AssertionError: In either of these cases:
			- the sequence-classification loss is active, no `loss_config`
			  was given, and the config has no `num_labels`;
			- no labels were provided or inferred.
			These are explicit raises, not `assert` statements, so they still
			fire under `python -O`.
		TypeError: If the same key appears more than once across `labels` /
			`config` / `paxis`, `loss_kwargs`, the model outputs, and `batch`.
			Python rejects duplicate keyword arguments.
	"""
	# Resolve the loss function once and reuse it for both dispatch and the call.
	loss_fn = self.loss_function
	loss_name = loss_fn.__name__

	if labels is None and loss_name == ForCausalLMLoss.__name__:
		labels = batch.get("input_ids", None)

	if loss_name == ForSequenceClassificationLoss.__name__ and loss_config is None:
		if not hasattr(self.config, "num_labels"):
			# AssertionError is kept so existing callers that catch it still work;
			# an explicit raise (unlike `assert`) is not stripped by `python -O`.
			raise AssertionError(
				"in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach `num_labels` to model `config`"
			)
		loss_config = LossConfig(num_labels=self.config.num_labels)

	if labels is None:
		raise AssertionError("`labels` can not be `None` for computing loss.")

	loss_kwargs = loss_kwargs or {}
	# The model is always asked for dict-style outputs, so drop any caller override.
	batch.pop("return_dict", None)
	outputs = self(**batch, return_dict=True)

	loss_output: LossMetrics = loss_fn(
		labels=labels,
		config=loss_config,
		paxis=self.config.partition_axis,
		**loss_kwargs,
		**outputs,
		**batch,
	)

	aux_loss = getattr(outputs, "aux_loss", None)
	if aux_loss is not None:
		loss_output.loss = loss_output.loss + aux_loss

	outputs = outputs.replace(loss=loss_output.loss)
	return outputs, loss_output