# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
import typing as tp
import warnings
from functools import cached_property, partial
import chex
import flax
import flax.struct
import jax
import jax.extend
import jax.tree_util
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from easydel.utils import traversals
from easydel.utils.helpers import get_logger
from easydel.utils.traversals import flatten_dict, is_flatten, unflatten_dict
from .base_config import EasyDeLBaseConfig
from .etils import EasyDeLGradientCheckPointers, EasyDeLQuantizationMethods
from .loss_utils import (
LOSS_MAPPING,
ForCausalLMLoss,
ForSequenceClassificationLoss,
LossConfig,
LossMetrics,
)
from .mixins import (
BaseModuleProtocol,
EasyBridgeMixin,
EasyGenerationMixin,
)
if tp.TYPE_CHECKING:
    from easydel.infra.base_state import EasyDeLState
else:
    # At runtime annotations fall back to ``tp.Any``; the real class is imported
    # lazily inside the methods that need it (presumably to avoid an import cycle).
    EasyDeLState = tp.Any

# Sharding/partition rules: a mapping keyed by path string or path tuple whose
# values are callables. ``None`` means "use the rules from the model config".
PartitionLike = tp.Optional[
    tp.Union[
        tp.Mapping[str, tp.Callable],
        tp.Mapping[tuple, tp.Callable],
    ]
]

logger = get_logger(__name__)

# ``_CP`` stands in for config subclasses accepted by ``__init__``; ``SELF`` is
# the concrete module type returned by the fluent (self-returning) methods.
# The string passed to ``TypeVar`` must match the variable name for type checkers.
_CP = tp.TypeVar("_CP")
SELF = tp.TypeVar("SELF")
[docs]class EasyDeLBaseModule(
nn.Module,
BaseModuleProtocol,
EasyBridgeMixin,
EasyGenerationMixin,
):
"""
Base class for EasyDeL modules, providing common functionalities for model initialization,
parameter handling, and integration with the EasyDeL ecosystem.
"""
config_class: tp.Type[EasyDeLBaseConfig]
base_model_prefix: str
_model_task: tp.Optional[str] = None
_model_type: tp.Optional[str] = None
def __init__(
    self,
    config: tp.Union[EasyDeLBaseConfig, _CP],
    dtype: jnp.dtype,
    param_dtype: jnp.dtype,
    precision: lax.PrecisionLike,
    rngs: nn.Rngs,
):
    """Store the core construction arguments on the module.

    Args:
        config (EasyDeLBaseConfig): Model configuration object.
        dtype (jnp.dtype): Dtype used for computation.
        param_dtype (jnp.dtype): Dtype used for parameters.
        precision (jax.lax.PrecisionLike): Matmul precision.
        rngs (nn.Rngs): Random number generators.
    """
    self.config: tp.Union[EasyDeLBaseConfig, _CP] = config
    self.dtype: jnp.dtype = dtype
    self.param_dtype: jnp.dtype = param_dtype
    self.precision: lax.PrecisionLike = precision
    self.rngs: nn.Rngs = rngs
    # Each property is read once during construction and its result discarded;
    # per the original author the intent is to initialise these values in the graphdef.
    for attribute_name in (
        "graphtree_shape",
        "graphtree_params_shape",
        "mesh",
        "model_task",
        "model_type",
    ):
        getattr(self, attribute_name)
@property
def parameters(self) -> tp.Dict:
    """Map every ``nn.Param`` in the module tree to its raw value.

    Keys are the paths yielded by ``iter_module_search(self, nn.Param)``.

    Returns:
        tp.Dict: ``{path: param.value}`` for each parameter found.
    """
    from easydel.utils.graph_utils import iter_module_search

    return {path: param.value for path, param in iter_module_search(self, nn.Param)}
def split_module(self):
    """Split the module into ``(graphdef, param_state, other_state)``.

    Returns:
        tuple: The graph definition, the ``nn.Param`` state, and the state of
        all remaining variables, as produced by ``nn.split(self, nn.Param, ...)``.
    """
    graphdef, param_state, other_state = nn.split(self, nn.Param, ...)
    return graphdef, param_state, other_state
@staticmethod
def merge_module(
    graphdef: nn.GraphDef,
    graphstate: nn.GraphState,
    graphother: nn.GraphState,
):
    """Rebuild a module from the three pieces returned by ``split_module``.

    Args:
        graphdef (nn.GraphDef): Graph structure of the module.
        graphstate (nn.GraphState): Parameter state.
        graphother (nn.GraphState): State of all non-parameter variables.

    Returns:
        The module reconstructed by ``nn.merge``.
    """
    states = (graphstate, graphother)
    return nn.merge(graphdef, *states)
@property
def graphdef(self) -> nn.GraphDef:
    """Graph definition (structure only) from ``nn.split(self, nn.Param, ...)``.

    Returns:
        nn.GraphDef: The first element of the split.
    """
    definition, _params, _others = nn.split(self, nn.Param, ...)
    return definition
@property
def graphstate(self) -> nn.GraphState:
    """State holding only the ``nn.Param`` variables of the module.

    Returns:
        nn.GraphState: The second element of ``nn.split(self, nn.Param, ...)``.
    """
    _definition, param_state, _others = nn.split(self, nn.Param, ...)
    return param_state
@property
def graphother(self) -> nn.GraphState:
    """State holding every variable that is *not* an ``nn.Param``.

    Returns:
        nn.GraphState: The last element of ``nn.split(self, nn.Param, ...)``.
    """
    *_, other_state = nn.split(self, nn.Param, ...)
    return other_state
@property
def graphtree_params_shape(self) -> tp.Dict:
    """Abstract (no-allocation) view of the ``nn.Param`` state as a nested dict.

    The param split is traced under ``nn.eval_shape``, so leaves are abstract
    values (presumably ``jax.ShapeDtypeStruct``) rather than real arrays.

    Returns:
        tp.Dict: Nested dict mirroring the parameter tree, holding each
        variable's abstract ``.value``.
    """

    def _param_state():
        return nn.split(self, nn.Param, ...)[1]

    flat_abstract = flatten_dict(nn.eval_shape(_param_state))
    shapes = {}
    for path, variable in flat_abstract.items():
        shapes[path] = variable.value
    return unflatten_dict(shapes)
@property
def graphtree_shape(self) -> tp.Dict:
    """Abstract view of *all* module variables (params and others) as a nested dict.

    ``nn.split(self)`` with no filters is traced under ``nn.eval_shape`` so no
    arrays are materialised.

    Returns:
        tp.Dict: Nested dict mirroring the full state, holding each variable's
        abstract ``.value``.
    """
    abstract_state = nn.eval_shape(lambda: nn.split(self)[1])
    return unflatten_dict(
        {path: variable.value for path, variable in flatten_dict(abstract_state).items()}
    )
@property
def mesh(self) -> jax.sharding.Mesh:
    """Device mesh taken straight from ``self.config.mesh``."""
    config = self.config
    return config.mesh
@property
def model_task(self) -> tp.Optional[str]:
    """Task identifier stored in ``_model_task`` (``None`` if unset)."""
    task = self._model_task
    return task
@property
def model_type(self) -> tp.Optional[str]:
    """Model-type identifier stored in ``_model_type`` (``None`` if unset)."""
    kind = self._model_type
    return kind
@property
def params(self) -> tp.Dict:
    """Full variable state of the module (parameters and all other variables).

    Returns:
        The last element of ``nn.split(self)``. With no filters this is the
        single state object holding every variable (a graph state, despite the
        ``tp.Dict`` annotation).
    """
    *_, full_state = nn.split(self)
    return full_state
@cached_property
def causal_mask(self) -> jnp.ndarray:
    """Causal mask from ``config.get_basic_causal_mask()``, cached per instance."""
    config = self.config
    return config.get_basic_causal_mask()
@cached_property
def frequencies(self) -> jnp.ndarray:
    """Frequency table from ``config.get_basic_frequencies()``, cached per instance."""
    config = self.config
    return config.get_basic_frequencies()
@cached_property
def inv_frequencies(self) -> jnp.ndarray:
    """Inverse-frequency table from ``config.get_basic_inv_frequencies()``, cached per instance."""
    config = self.config
    return config.get_basic_inv_frequencies()
@cached_property
def static_arguments(self) -> tp.Tuple:
    """Result of ``self.get_static_arguments()``, computed once and cached per instance."""
    computed = self.get_static_arguments()
    return computed
@cached_property
def loss_function(self):
    """Resolve the loss callable for this model (computed once, then cached).

    Resolution order:
      1. ``config.loss_type`` if set, else ``self.loss_type`` if set, else the
         class name.
      2. If that value is not a key of ``LOSS_MAPPING``, the class name is
         searched for the first occurrence of any ``LOSS_MAPPING`` key and that
         key is used.
      3. If nothing matches, a warning is emitted and ``"ForCausalLM"`` is used.

    An unrecognised explicit ``loss_type`` is replaced without a warning when
    step 2 finds a match in the class name.

    Returns:
        tp.Callable: ``LOSS_MAPPING[<resolved key>]``.
    """
    requested = getattr(self.config, "loss_type", None)
    if requested is None:
        requested = getattr(self, "loss_type", None)
    class_name = self.__class__.__name__
    loss_type = requested if requested is not None else class_name

    if loss_type not in LOSS_MAPPING:
        # Escape keys so they are matched literally inside the alternation.
        pattern = "(" + "|".join(re.escape(key) for key in LOSS_MAPPING) + ")"
        matches = re.findall(pattern, class_name)
        loss_type = matches[0] if matches else None

    # After the lookup above ``loss_type`` is either a valid key or None, so a
    # single None check covers every fallback case.
    if loss_type is None:
        if requested is not None:
            reason = f"`loss_type={requested}` was set but it is unrecognised"
        else:
            reason = f"no loss type could be inferred from `{class_name}`"
        warnings.warn(
            f"{reason}. Using the default loss: `ForCausalLMLoss`.",
            stacklevel=1,
        )
        loss_type = "ForCausalLM"
    return LOSS_MAPPING[loss_type]
@property
def module_dtype(self) -> jnp.dtype:
    """Dtype of the first parameter leaf found in the module.

    Leaves are taken from the flat ``nn.Param`` state via
    ``jax.tree_util.tree_leaves``. If the module holds no parameter leaves,
    ``self.param_dtype`` is returned instead of raising ``IndexError``.

    Returns:
        jnp.dtype: Dtype of the first parameter leaf, or ``self.param_dtype``.
    """
    params_state = nn.split(self, nn.Param, ...)[1].flat_state()
    leaves = jax.tree_util.tree_leaves(params_state)
    if not leaves:
        return self.param_dtype
    return leaves[0].dtype
def compute_complex_rotary(self, position_ids: jax.Array) -> jnp.ndarray:
    """Complex rotary factors ``exp(1j * position * inv_freq)``.

    ``self.inv_frequencies`` is indexed as a 1-D vector of F values and
    ``position_ids`` as a 2-D array of shape (B, S) (presumably batch x
    sequence; confirm with callers). Positions are cast to float32, and the
    outer product is formed with a batched matmul.

    Returns:
        jnp.ndarray: Complex array of shape (B, S, F).
    """
    inv_freq = self.inv_frequencies[None, :, None]  # (1, F, 1)
    positions = position_ids[:, None, :].astype("f4")  # (B, 1, S)
    angles = jnp.transpose(inv_freq @ positions, (0, 2, 1))  # (B, F, S) -> (B, S, F)
    return jnp.exp(1j * angles)
def to_dtype(self: SELF, dtype: jnp.dtype) -> SELF:
    """Cast parameter values to ``dtype`` and record it as ``param_dtype``.

    Every non-``None`` ``nn.Param`` value is cast with ``.astype(dtype)``,
    except those whose final path component starts with ``"quant_"``. Then
    ``param_dtype`` is set to ``dtype`` on every module yielded by
    ``iter_module_search`` that has that attribute.

    Args:
        dtype (jnp.dtype): Target parameter dtype.

    Returns:
        SELF: The re-merged module.
    """
    from easydel.utils.graph_utils import iter_module_search

    gdef, state, others = nn.split(self, nn.Param, ...)

    def _map(path, val: nn.VariableState):
        # Path components are not always strings (e.g. integer indices into
        # lists of layers), so normalise before the prefix check.
        if val.value is not None and not str(path[-1]).startswith("quant_"):
            val.value = val.value.astype(dtype)
        return val

    state.update(state.map(_map))
    self = nn.merge(gdef, state, others)
    for _path, module in iter_module_search(self):
        if hasattr(module, "param_dtype"):
            module.param_dtype = dtype
    return self
def half(self: SELF, change_runtime_dtype: bool = True) -> SELF:
    """Switch floating-point parameters (and optionally compute dtype) to float16.

    Args:
        change_runtime_dtype (bool): Also set the computation ``dtype`` to
            ``jnp.float16``. Defaults to True.

    Returns:
        SELF: The updated module.
    """
    target = jnp.float16
    module = self._reformat_runtime_dtype(target) if change_runtime_dtype else self
    return module._reformat_dtype(target)
def float(self: SELF, change_runtime_dtype: bool = True) -> SELF:
    """Switch floating-point parameters (and optionally compute dtype) to float32.

    Args:
        change_runtime_dtype (bool): Also set the computation ``dtype`` to
            ``jnp.float32``. Defaults to True.

    Returns:
        SELF: The updated module.
    """
    target = jnp.float32
    module = self._reformat_runtime_dtype(target) if change_runtime_dtype else self
    return module._reformat_dtype(target)
def _reformat_runtime_dtype(self: SELF, dtype) -> SELF:
    """Set the computation ``dtype`` on this module and on matching submodules.

    A submodule's ``dtype`` is replaced only when it currently holds a
    ``jax.numpy`` scalar type such as ``jnp.float32``, recognised by its
    metaclass name ``_ScalarMeta``. Other values, such as ``numpy.dtype``
    instances, are left unchanged. The module's own ``dtype`` is always set.

    Args:
        dtype (jnp.dtype): Target computation dtype.

    Returns:
        SELF: The updated module.
    """
    from easydel.utils.graph_utils import iter_module_search

    for _path, module in iter_module_search(self):
        if not hasattr(module, "dtype"):
            continue
        # Compare the metaclass *name*, not its string repr: the repr embeds the
        # defining module path ("lax_numpy" per the original check), which is a
        # JAX-internal detail that may change between releases.
        if type(module.dtype).__name__ == "_ScalarMeta":
            module.dtype = dtype
    self.dtype = dtype
    return self
def _reformat_dtype(self: SELF, dtype) -> SELF:
    """Cast floating-point ``nn.Param`` leaves to ``dtype`` and update ``param_dtype``.

    Only leaves whose dtype is bfloat16, float16, float32 or float64 are cast.
    Other leaves (e.g. integer tensors, or leaves without a ``dtype``) pass
    through unchanged. ``param_dtype`` is then updated on every submodule whose
    current value is dtype-like: a ``numpy.dtype`` instance or a ``jax.numpy``
    scalar type such as ``jnp.float32``. The module's own ``param_dtype`` is
    always set.

    Args:
        dtype (jnp.dtype): Target parameter dtype.

    Returns:
        SELF: The re-merged module.
    """
    from easydel.utils.graph_utils import iter_module_search

    # ``jnp.float_`` was dropped from this list: it is float64, which is
    # already listed.
    castable = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)

    gdef, gtree, others = nn.split(self, nn.Param, ...)

    def _map(array):
        leaf_dtype = getattr(array, "dtype", None)
        if leaf_dtype is not None and leaf_dtype in castable:
            array = array.astype(dtype)
        return array

    gtree = jax.tree_util.tree_map(_map, gtree)
    self = nn.merge(gdef, gtree, others)
    for _path, module in iter_module_search(self):
        if not hasattr(module, "param_dtype"):
            continue
        current = module.param_dtype
        # ``jnp.float32`` and friends are *types* (metaclass ``_ScalarMeta``),
        # not ``numpy.dtype`` instances, so the isinstance check alone misses
        # the most common way ``param_dtype`` is specified.
        if isinstance(current, jnp.dtype) or type(current).__name__ == "_ScalarMeta":
            module.param_dtype = dtype
    self.param_dtype = dtype
    return self
def _match_partition_rules(self, partition_rules: tp.Any = None):
    """Match partition rules against the abstract parameter tree.

    Args:
        partition_rules (tp.Any, optional): Rules to use. ``None`` defers to
            ``self._get_partition_rules`` (i.e. the config's rules).

    Returns:
        tp.Any: Whatever ``match_partition_rules`` produces for
        ``self.graphtree_params_shape``.
    """
    rules = self._get_partition_rules(partition_rules)
    abstract_params = self.graphtree_params_shape
    return match_partition_rules(rules=rules, tree=abstract_params)
@property
def _specs_sharding(self):
    """Module-shaped tree of ``PartitionSpec``s, one per leaf of ``nn.to_tree(self)``.

    Leaves carrying a ``NamedSharding`` contribute its ``.spec``. All other
    leaves become an empty ``PartitionSpec()``. The result is rebuilt with
    ``nn.from_tree``.
    """

    def _spec_of(leaf):
        sharding = getattr(leaf, "sharding", None)
        return sharding.spec if isinstance(sharding, NamedSharding) else PartitionSpec()

    as_tree = nn.to_tree(self)
    return nn.from_tree(jax.tree_util.tree_map(_spec_of, as_tree))
@property
def _shardings(self):
    """Module-shaped tree of each leaf's ``.sharding``.

    Leaves without a ``sharding`` attribute become ``PartitionSpec()``. The
    tree comes from ``nn.to_tree(self)`` and is rebuilt with ``nn.from_tree``.
    """

    def _sharding_of(leaf):
        if hasattr(leaf, "sharding"):
            return leaf.sharding
        return PartitionSpec()

    as_tree = nn.to_tree(self)
    return nn.from_tree(jax.tree_util.tree_map(_sharding_of, as_tree))
@property
def _named_shardings(self):
    """Module-shaped tree of each leaf's ``.sharding``, or ``None`` when a leaf has none.

    The tree comes from ``nn.to_tree(self)`` and is rebuilt with ``nn.from_tree``.
    """

    def _sharding_or_none(leaf):
        if hasattr(leaf, "sharding"):
            return leaf.sharding
        return None

    as_tree = nn.to_tree(self)
    return nn.from_tree(jax.tree_util.tree_map(_sharding_or_none, as_tree))
def _get_mesh(self, mesh: tp.Optional[Mesh] = None) -> Mesh:
    """Return ``mesh`` if given, otherwise ``self.config.mesh``.

    Args:
        mesh (tp.Optional[Mesh]): Explicit mesh; takes priority.

    Returns:
        Mesh: The resolved device mesh.

    Raises:
        ValueError: If ``mesh`` is None and the config has no (non-None) mesh.
    """
    if mesh is not None:
        return mesh
    config_has_mesh = (
        hasattr(self, "config")
        and hasattr(self.config, "mesh")
        and self.config.mesh is not None
    )
    if not config_has_mesh:
        raise ValueError(
            "A mesh must be provided, either as an argument or through the model config."
        )
    return self.config.mesh
def _get_partition_rules(self, partition_rules: PartitionLike) -> PartitionLike:
    """Return ``partition_rules`` if given, otherwise the config's FSDP rules.

    Args:
        partition_rules (PartitionLike): Explicit rules; take priority.

    Returns:
        PartitionLike: The explicit rules, or
        ``self.config.get_partition_rules(fully_sharded_data_parallel=True)``.

    Raises:
        ValueError: If no rules are given and the module has no ``config``.
    """
    if partition_rules is not None:
        return partition_rules
    if not hasattr(self, "config"):
        raise ValueError(
            "Partition rules must be provided either as an argument or through the model config."
        )
    return self.config.get_partition_rules(fully_sharded_data_parallel=True)
def _apply_sharding_fns(
    self: SELF,
    sharding_fns: tp.Mapping[str, tp.Callable],
) -> SELF:
    """Apply per-parameter shard/gather functions to the ``nn.Param`` state.

    ``sharding_fns`` is flattened to ``{path_tuple: fn}``. Each non-``None``
    parameter whose path is a key has its value replaced by ``fn(value)``. A
    ``TypeError`` raised by ``fn`` is turned into a warning, and that parameter
    is left as it was.

    Args:
        sharding_fns (tp.Mapping[str, tp.Callable]): Nested or flat mapping
            from parameter path to shard/gather function.

    Returns:
        SELF: The re-merged module.
    """
    gdef, state, others = nn.split(self, nn.Param, ...)
    flat_fns = flatten_dict(sharding_fns)

    def _map(path, val: nn.VariableState):
        # Dict membership is O(1); scanning a list of keys for every parameter
        # made this O(n^2) in the number of parameters.
        if val.value is not None and path in flat_fns:
            try:
                val.value = flat_fns[path](val.value)
            except TypeError:
                warnings.warn(
                    f"couldn't shard/gather {'.'.join(map(str, path))}", stacklevel=1
                )
        return val

    state.update(state.map(_map))
    self = nn.merge(gdef, state, others)
    return self
def shard_model(
    self: SELF,
    partition_rules: PartitionLike = None,
    mesh: tp.Optional[Mesh] = None,
    overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None,
) -> SELF:
    """Shard the parameters onto ``mesh`` according to ``partition_rules``.

    Args:
        partition_rules (PartitionLike, optional): Rules; ``None`` uses the config's.
        mesh (tp.Optional[Mesh], optional): Mesh; ``None`` uses the config's.
        overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional):
            Entries merged (top-level ``dict.update``) over the generated shard
            functions.

    Returns:
        SELF: The module with sharded parameters.
    """
    target_mesh = self._get_mesh(mesh)
    specs = match_partition_rules(
        rules=self._get_partition_rules(partition_rules),
        tree=self.graphtree_params_shape,
    )
    shard_fns = make_shard_and_gather_fns(partition_specs=specs, mesh=target_mesh)[0]
    if overlay_fns is not None:
        shard_fns.update(overlay_fns)
    return self._apply_sharding_fns(shard_fns)
def gather_model(
    self: SELF,
    partition_rules: PartitionLike = None,
    mesh: tp.Optional[Mesh] = None,
    overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None,
) -> SELF:
    """Apply the gather functions derived from ``partition_rules`` and ``mesh``.

    Args:
        partition_rules (PartitionLike, optional): Rules the parameters were
            sharded with; ``None`` uses the config's.
        mesh (tp.Optional[Mesh], optional): Mesh; ``None`` uses the config's.
        overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional):
            Entries merged (top-level ``dict.update``) over the generated
            gather functions.

    Returns:
        SELF: The module with gathered parameters.
    """
    target_mesh = self._get_mesh(mesh)
    specs = match_partition_rules(
        rules=self._get_partition_rules(partition_rules),
        tree=self.graphtree_params_shape,
    )
    gather_fns = make_shard_and_gather_fns(partition_specs=specs, mesh=target_mesh)[1]
    if overlay_fns is not None:
        gather_fns.update(overlay_fns)
    return self._apply_sharding_fns(gather_fns)
@property
def _shard_fns(self):
    """Shard functions built from the config's mesh and partition rules.

    Returns:
        tp.Mapping: First element of ``make_shard_and_gather_fns`` for the
        abstract parameter tree.
    """
    config_mesh = self._get_mesh(None)
    rules = self._get_partition_rules(None)
    specs = match_partition_rules(rules=rules, tree=self.graphtree_params_shape)
    shard_fns, _gather_fns = make_shard_and_gather_fns(
        partition_specs=specs, mesh=config_mesh
    )
    return shard_fns
@property
def _gather_fns(self):
    """Gather functions built from the config's mesh and partition rules.

    Returns:
        tp.Mapping: Second element of ``make_shard_and_gather_fns`` for the
        abstract parameter tree.
    """
    config_mesh = self._get_mesh(None)
    rules = self._get_partition_rules(None)
    specs = match_partition_rules(rules=rules, tree=self.graphtree_params_shape)
    _shard_fns, gather_fns = make_shard_and_gather_fns(
        partition_specs=specs, mesh=config_mesh
    )
    return gather_fns
def apply_out_shardings(self, out_shardings):
    """Re-place the module's state by passing it through ``jax.jit`` with ``out_shardings``.

    The module is split into ``(graphdef, graphstate, graphother)``. The two
    states are returned unchanged from a jitted identity function whose output
    shardings are ``out_shardings``, and the module is merged back together.

    Args:
        out_shardings: Passed to ``jax.jit(out_shardings=...)`` for a function
            returning the tuple ``(graphstate, graphother)``.

    Returns:
        The merged module with re-placed state.
    """
    # ``nn.split`` returns a tuple, so unpack it instead of slice-assigning into it
    # (the original ``splits[1:] = ...`` raised TypeError).
    graphdef, graphstate, graphother = self.split_module()

    @partial(jax.jit, out_shardings=out_shardings)
    def _call(graphstate, graphother):
        return graphstate, graphother

    graphstate, graphother = _call(graphstate, graphother)
    return self.merge_module(graphdef, graphstate, graphother)
def fully_shard(self: SELF, partition_rules: PartitionLike = None) -> SELF:
    """Place every variable according to the partition rules on ``self.mesh``.

    The module is split into ``(graphdef, state)`` and wrapped in a small
    pytree. Rules are matched against its abstract (``nn.eval_shape``) form,
    and each matched spec becomes a ``NamedSharding`` on ``self.mesh``. The
    wrapper is passed through an identity ``jax.jit`` with those
    ``out_shardings`` and merged back.

    Args:
        partition_rules (PartitionLike, optional): Rules; ``None`` uses the config's.

    Returns:
        SELF: The re-merged, resharded module.
    """

    class ShardState(flax.struct.PyTreeNode):
        graphdef: nn.GraphDef
        graphstate: nn.GraphState

    graphdef, graphstate = nn.split(self)
    bundle = ShardState(graphdef=graphdef, graphstate=graphstate)
    spec_tree = match_partition_rules(
        self._get_partition_rules(partition_rules), nn.eval_shape(lambda: bundle)
    )

    def _named(spec):
        return NamedSharding(mesh=self.mesh, spec=spec)

    shardings = jax.tree_util.tree_map(_named, spec_tree)
    placed = jax.jit(lambda state: state, out_shardings=shardings)(bundle)
    return nn.merge(placed.graphdef, placed.graphstate)
def fully_gather(self: SELF) -> SELF:
    """Replicate every variable across ``self.mesh`` (empty ``PartitionSpec``).

    The config's partition rules are matched only to get a sharding tree of
    the right structure. Every matched spec is replaced by a
    ``NamedSharding(self.mesh, PartitionSpec())``, and the state is passed
    through an identity ``jax.jit`` with those ``out_shardings``.

    Returns:
        SELF: The re-merged module with fully replicated variables.
    """

    class ShardState(flax.struct.PyTreeNode):
        graphdef: nn.GraphDef
        graphstate: nn.GraphState

    graphdef, graphstate = nn.split(self)
    bundle = ShardState(graphdef=graphdef, graphstate=graphstate)
    spec_tree = match_partition_rules(
        self._get_partition_rules(None), nn.eval_shape(lambda: bundle)
    )

    def _replicated(_spec):
        return NamedSharding(mesh=self.mesh, spec=PartitionSpec())

    shardings = jax.tree_util.tree_map(_replicated, spec_tree)
    placed = jax.jit(lambda state: state, out_shardings=shardings)(bundle)
    return nn.merge(placed.graphdef, placed.graphstate)
def quantize(
    self: SELF,
    method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT,
    block_size: int = 128,
    quantization_pattern: tp.Optional[str] = None,
    quantize_tensors: bool = True,
    verbose: tp.Optional[bool] = None,
) -> SELF:
    """Quantize the module with an ``EasyQuantizer``.

    Args:
        method (EasyDeLQuantizationMethods, optional): Quantization method.
            Defaults to ``A8BIT``.
        block_size (int, optional): Block size passed to the quantizer.
            Defaults to 128.
        quantization_pattern (tp.Optional[str], optional): Pattern passed to
            the quantizer and to ``quantize_linears``.
        quantize_tensors (bool, optional): Direct tensor quantization is NOT
            implemented. When True (the default) a ``UserWarning`` is emitted
            and the module is returned unchanged. Pass False to quantize linear
            layers via ``quantizer.quantize_linears``.
        verbose (tp.Optional[bool], optional): Passed to ``quantize_linears``.
            Defaults to ``jax.process_index() == 0``.

    Returns:
        SELF: The (possibly) quantized module.
    """
    from easydel.layers.quantization.quantizers import EasyQuantizer

    quantizer = EasyQuantizer(
        quantization_method=method,
        block_size=block_size,
        quantization_pattern=quantization_pattern,
    )
    if verbose is None:
        verbose = jax.process_index() == 0
    if quantize_tensors:
        # Previously a silent no-op; make the unimplemented path visible.
        warnings.warn(
            "`quantize_tensors=True` is not implemented yet; the model is returned "
            "unchanged. Pass `quantize_tensors=False` to quantize linear layers.",
            stacklevel=2,
        )
        return self
    return quantizer.quantize_linears(
        self,
        quantization_pattern=quantization_pattern,
        verbose=verbose,
    )
def to_state(self) -> EasyDeLState:
    """Wrap this module in an ``EasyDeLState`` created with ``step=0``.

    Returns:
        EasyDeLState: ``EasyDeLState.create(step=0, model=self)``.
    """
    from easydel.infra.base_state import EasyDeLState as _State

    initial_step = 0
    return _State.create(step=initial_step, model=self)
def to_torch(self, **kwargs):
    """Convert this module into its Hugging Face PyTorch counterpart.

    The target class is looked up in ``self.get_torch_loader()._model_mapping``
    by config type. Conversion is delegated to ``module_to_huggingface_model``
    with ``dtype=self.param_dtype``.

    Args:
        **kwargs: Forwarded to ``module_to_huggingface_model``.

    Returns:
        Whatever ``module_to_huggingface_model`` returns (presumably a
        ``torch.nn.Module`` with loaded weights).
    """
    from easydel.utils.parameters_transformation import module_to_huggingface_model

    loader = self.get_torch_loader()
    hf_class = loader._model_mapping[type(self.config)]
    return module_to_huggingface_model(
        module=self,
        base_huggingface_module=hf_class,
        config=self.config,
        dtype=self.param_dtype,
        **kwargs,
    )
def get_static_arguments(self) -> tp.Tuple:
    """Static arguments for ``__call__``; the base class has none.

    Subclasses override this to supply values that stay fixed across calls.

    Returns:
        tp.Tuple: An empty tuple.
    """
    return tuple()
@classmethod
def lazy_init(cls: tp.Type[SELF], *args, **kwargs) -> SELF:
    """Construct ``cls(*args, **kwargs)`` under ``nn.eval_shape``.

    Tracing construction with ``eval_shape`` yields the module structure with
    abstract variables instead of allocated arrays.

    Returns:
        SELF: The abstractly-initialised module.
    """

    def _build():
        return cls(*args, **kwargs)

    return nn.eval_shape(_build)
def merge_lora_params(self: SELF, pytree: tp.Dict) -> SELF:
    """Merge LoRA parameters from ``pytree`` into the model.

    Delegates to ``easydel.infra.utils.merge_lora_params(self, pytree)``.

    Args:
        pytree (tp.Dict): LoRA parameters (structure as expected by the helper).

    Returns:
        SELF: The module returned by the helper.
    """
    from easydel.infra.utils import merge_lora_params as _merge_lora

    return _merge_lora(self, pytree)
def split_lora_params(self: SELF) -> tp.Dict:
    """Extract LoRA parameters from the model.

    Delegates to ``easydel.infra.utils.split_lora_params(self)``. Whether the
    base weights are restored is up to that helper (not visible here).

    Returns:
        tp.Dict: The pytree returned by the helper.
    """
    from easydel.infra.utils import split_lora_params as _split_lora

    return _split_lora(self)
def apply_lora_to_layers(
    self: SELF,
    lora_rank: int,
    lora_pattern: tp.Optional[str] = None,
    verbose: bool = False,
    rngs: tp.Optional[nn.Rngs] = None,
) -> SELF:
    """Attach LoRA adapters to layers selected by ``lora_pattern``.

    Delegates to ``easydel.infra.utils.apply_lora_to_layers``. Layer selection
    and initialisation details (including the behaviour when ``lora_pattern``
    or ``rngs`` is None) are defined by that helper.

    Args:
        lora_rank (int): Rank of the LoRA decomposition.
        lora_pattern (tp.Optional[str], optional): Layer-name pattern.
        verbose (bool, optional): Forwarded verbosity flag. Defaults to False.
        rngs (tp.Optional[nn.Rngs], optional): RNGs for adapter initialisation.

    Returns:
        SELF: The module returned by the helper.
    """
    from easydel.infra.utils import apply_lora_to_layers as _apply_lora

    return _apply_lora(
        self,
        lora_pattern=lora_pattern,
        lora_rank=lora_rank,
        rngs=rngs,
        verbose=verbose,
    )
def unwrap_lora_to_layers(self: SELF, verbose: bool = False) -> SELF:
    """Remove LoRA adapters, restoring plain layers.

    Delegates to ``easydel.infra.utils.unwrap_lora_to_layers``.

    Args:
        verbose (bool, optional): Forwarded verbosity flag. Defaults to False.

    Returns:
        SELF: The module returned by the helper.
    """
    from easydel.infra.utils import unwrap_lora_to_layers as _unwrap_lora

    return _unwrap_lora(self, verbose=verbose)
@property
def transform_fn(self):
    """``torch_dict_to_easydel_params`` pre-bound with this model's metadata.

    Bound keyword arguments:
      - ``embedding_layer_names``: dotted paths of all ``nn.Embed`` submodules.
      - ``layernorm_names``: dotted paths of all ``nn.LayerNorm`` submodules.
      - ``dtype``: ``self.param_dtype``.
      - ``shard_fns``: ``self._shard_fns``.

    Returns:
        functools.partial: Callable awaiting the PyTorch state dict.
    """
    from easydel.utils import graph_utils
    from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

    def _dotted_paths(module_type):
        return [
            ".".join(str(part) for part in path)
            for path, _module in graph_utils.iter_module_search(self, module_type)
        ]

    return partial(
        torch_dict_to_easydel_params,
        embedding_layer_names=_dotted_paths(nn.Embed),
        layernorm_names=_dotted_paths(nn.LayerNorm),
        dtype=self.param_dtype,
        shard_fns=self._shard_fns,
    )
def _build_generation_dummy(self):
    """Abstractly construct this model with gradient checkpointing disabled.

    The config is deep-copied (the live config is not mutated), its
    ``gradient_checkpointing`` is set to ``EasyDeLGradientCheckPointers.NONE``,
    and ``type(self)`` is built via ``lazy_init`` (``nn.eval_shape``) with this
    module's dtype, param_dtype, precision and rngs.
    """
    from copy import deepcopy

    adjusted_config = deepcopy(self.config)
    adjusted_config.gradient_checkpointing = EasyDeLGradientCheckPointers.NONE
    return type(self).lazy_init(
        config=adjusted_config,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
        rngs=self.rngs,
    )

@property
def _generate_compatible_graphdef(self):
    """Graph definition of a gradient-checkpointing-free copy of this model.

    Returns:
        nn.GraphDef: First element of ``nn.split(dummy, nn.Param, ...)``.
    """
    gdef, _, _ = nn.split(self._build_generation_dummy(), nn.Param, ...)
    return gdef

@property
def _generate_compatible_graphother(self):
    """Non-parameter state of a gradient-checkpointing-free copy of this model.

    The abstract values from the lazy construction are replaced via
    ``traversals.recreate_meta_values``.

    Returns:
        nn.GraphState: Last element of ``nn.split(dummy, nn.Param, ...)``,
        passed through ``recreate_meta_values``.
    """
    _, _, gother = nn.split(self._build_generation_dummy(), nn.Param, ...)
    return traversals.recreate_meta_values(gother)
@property
def params_sharding(self) -> tp.Dict:
    """
    Map every parameter to its sharding annotation.

    Each leaf of ``self.split_params_dict()`` is replaced by its ``.sharding`` attribute.
    Leaves without such an attribute become ``None``.

    Returns:
        tp.Dict: Nested dictionary with the same structure as the parameters, holding
        each leaf's sharding object (or ``None``).
    """

    def _sharding_of(leaf):
        return getattr(leaf, "sharding", None)

    parameter_tree = self.split_params_dict()
    return jax.tree_util.tree_map(_sharding_of, parameter_tree)
def merge_params(self, tree):
    """
    Rebuild the module using ``tree`` as its parameter state.

    The module is split with ``nn.split(self, nn.Param, ...)``. Its current parameters
    are discarded, and ``nn.merge`` is called with the graph definition, ``tree`` and the
    non-parameter state.

    Args:
        tree: Parameter state to merge in, e.g. the ``nn.GraphState`` produced by
            ``split_params``.

    Returns:
        EasyDeLBaseModule: The module returned by ``nn.merge``.
    """
    graphdef, _current_params, other_state = nn.split(self, nn.Param, ...)
    return nn.merge(graphdef, tree, other_state)
def split_params(self):
    """
    Return only the parameter state of this module.

    The module is split with ``nn.split(self, nn.Param, ...)``. The graph definition and
    the non-parameter state are discarded.

    Returns:
        nn.GraphState: State holding the module's ``nn.Param`` variables.
    """
    _graphdef, param_state, _others = nn.split(self, nn.Param, ...)
    return param_state
def split_params_dict(
    self,
    extract_fn: tp.Optional[tp.Callable] = None,
    remove_none: bool = True,
) -> tp.Dict:
    """
    Return the module's parameters as a plain nested dictionary.

    Steps:
    1. Convert the parameter state from ``split_params`` with
       ``to_pure_dict(extract_fn=extract_fn)`` and flatten it.
    2. Replace any leaf that still exposes a ``.value`` attribute by that value.
    3. If ``remove_none`` is true, drop leaves whose unwrapped value is ``None``.
    4. Unflatten the result back into a nested dictionary.

    Args:
        extract_fn (tp.Optional[tp.Callable], optional): Forwarded to ``to_pure_dict``.
            Defaults to None.
        remove_none (bool, optional): Drop entries whose value is ``None``.
            Defaults to True.

    Returns:
        tp.Dict: Nested dictionary of parameter values.
    """
    pure_tree = self.split_params().to_pure_dict(extract_fn=extract_fn)
    plain_leaves = {}
    for path, leaf in flatten_dict(pure_tree).items():
        # Unwrap any remaining variable container down to its raw value.
        unwrapped = leaf.value if hasattr(leaf, "value") else leaf
        if remove_none and unwrapped is None:
            continue
        plain_leaves[path] = unwrapped
    return unflatten_dict(plain_leaves)
def merge_params_dict(self: SELF, params_dict: tp.Dict) -> SELF:
    """
    Overwrite parameters of the module with values from a dictionary.

    The current parameter state is flattened with ``flat_state()``. ``params_dict`` is
    flattened too, unless ``is_flatten`` reports that it already is. Each entry then
    overwrites the ``.value`` of the matching variable state. Finally the updated state
    is unflattened and handed to ``merge_params``.

    Args:
        params_dict (tp.Dict): Nested (or already flattened) dictionary of new parameter
            values, keyed like the module's parameter structure.

    Returns:
        SELF: The module returned by ``merge_params`` with the updated parameters.

    Raises:
        KeyError: On the first key of ``params_dict`` that does not exist in the
            module's current parameter state.
    """
    flat_state = self.split_params().flat_state()
    flat_updates = params_dict if is_flatten(params_dict) else flatten_dict(params_dict)
    for key, new_value in flat_updates.items():
        if key not in flat_state:
            raise KeyError(f"Parameter key {key} not found in the current model state.")
        # Write into the existing variable state so its metadata is kept.
        flat_state[key].value = new_value
    return self.merge_params(unflatten_dict(flat_state))
def _flop(self, *args, **kwargs) -> tp.Optional[float]:
    """
    Estimate the FLOPs of a single forward pass (``__call__``).

    ``self.__call__`` is traced with ``jax.make_jaxpr`` on the given arguments. The
    resulting jaxpr is passed to ``count_flop_jaxpr`` from the sibling ``utils`` module.

    Args:
        *args: Positional arguments forwarded to ``__call__``.
        **kwargs: Keyword arguments forwarded to ``__call__``.

    Returns:
        tp.Optional[float]: Whatever ``count_flop_jaxpr`` reports for the traced graph.

    Raises:
        Exception: Nothing is caught here. Errors raised while tracing ``__call__``, or
            inside ``count_flop_jaxpr``, propagate to the caller; failures are not
            turned into ``None``.
    """
    from .utils import count_flop_jaxpr

    jaxpr = jax.make_jaxpr(self.__call__)(*args, **kwargs)
    return count_flop_jaxpr(jaxpr)
@property
def pure_transform_fn(self):
    """
    Build a PyTorch-state-dict to EasyDeL-params converter that applies no sharding.

    The dotted paths of every ``nn.Embed`` and every ``nn.LayerNorm`` submodule are
    collected via ``graph_utils.iter_module_search``. They are bound, together with
    ``dtype=self.param_dtype``, into a ``functools.partial`` of
    ``torch_dict_to_easydel_params``. Unlike ``transform_fn``, no ``shard_fns`` are bound.

    Returns:
        tp.Callable: Partial of ``torch_dict_to_easydel_params`` ready to be called
        with a PyTorch state dict.
    """
    from easydel.utils import graph_utils
    from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

    def dotted_paths(module_type):
        # Render each path tuple as "part0.part1....".
        return [
            ".".join(map(str, path))
            for path, _module in graph_utils.iter_module_search(self, module_type)
        ]

    embedding_names = dotted_paths(nn.Embed)
    layernorm_names = dotted_paths(nn.LayerNorm)
    return partial(
        torch_dict_to_easydel_params,
        embedding_layer_names=embedding_names,
        layernorm_names=layernorm_names,
        dtype=self.param_dtype,
    )
@property
def _default_loss_config(self) -> tp.Optional[LossConfig]:
"""
Provides a default LossConfig for the module, if applicable.
Subclasses can override this property to return a default `LossConfig`
instance specific to the model's task (e.g., setting `num_labels` for
sequence classification).
Returns:
tp.Optional[LossConfig]: The default loss configuration, or None.
"""
return None
@_default_loss_config.setter
def _default_loss_config(self, val):
"""Setter for the default loss config (internal use)."""
return val
def compute_loss(
    self,
    *,
    labels: tp.Optional[chex.Array] = None,
    loss_config: tp.Optional[LossConfig] = None,
    loss_kwargs: tp.Optional[tp.Dict] = None,
    **batch,
) -> tp.Tuple[tp.Any, LossMetrics]:
    """
    Run a forward pass on ``batch`` and compute the loss with ``self.loss_function``.

    Label / config inference (the loss function is identified by ``__name__``):
      * ``ForCausalLMLoss`` with ``labels=None``: ``batch.get("input_ids")`` is used as
        the labels.
      * ``ForSequenceClassificationLoss`` with ``loss_config=None``: a
        ``LossConfig(num_labels=self.config.num_labels)`` is built.

    The forward pass is ``self(**batch)``. The loss function then receives ``labels``,
    ``config=loss_config`` and ``paxis=self.config.partition_axis``, plus the entries of
    ``loss_kwargs``, the model outputs and ``batch`` as keyword arguments. If the outputs
    expose a non-``None`` ``aux_loss``, it is added to the loss. The outputs are returned
    with their ``loss`` field replaced by the final loss.

    Args:
        labels (tp.Optional[chex.Array], optional): Target labels. May be inferred for
            causal LM as described above. Defaults to None.
        loss_config (tp.Optional[LossConfig], optional): Loss configuration. May be
            inferred for sequence classification. Defaults to None.
        loss_kwargs (tp.Optional[tp.Dict], optional): Extra keyword arguments for the
            loss function. Defaults to None.
        **batch: Model inputs (e.g. ``input_ids``, ``attention_mask``).

    Returns:
        tp.Tuple[tp.Any, LossMetrics]: The model outputs (with ``loss`` set) and the
        ``LossMetrics`` returned by the loss function.

    Raises:
        AssertionError: If ``ForSequenceClassificationLoss`` is used without a
            ``loss_config`` and ``self.config`` has no ``num_labels``.
        AssertionError: If no labels were given or inferred.
        Both are raised explicitly, so these checks also run under ``python -O``.
    """
    loss_name = self.loss_function.__name__
    if labels is None and loss_name == ForCausalLMLoss.__name__:
        labels = batch.get("input_ids", None)
    if loss_name == ForSequenceClassificationLoss.__name__ and loss_config is None:
        if not hasattr(self.config, "num_labels"):
            raise AssertionError(
                "in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach `num_labels` to model `config`"
            )
        loss_config = LossConfig(num_labels=self.config.num_labels)
    if labels is None:
        raise AssertionError("`labels` can not be `None` for computing loss.")
    loss_kwargs = loss_kwargs or {}
    outputs = self(**batch)
    # NOTE(review): loss_kwargs, outputs and batch are all splatted into one call; a key
    # present in more than one of them raises TypeError — confirm this cannot happen
    # for the model outputs in use.
    loss_output: LossMetrics = self.loss_function(
        labels=labels,
        config=loss_config,
        paxis=self.config.partition_axis,
        **loss_kwargs,
        **outputs,
        **batch,
    )
    aux_loss = getattr(outputs, "aux_loss", None)
    if aux_loss is not None:
        loss_output.loss = loss_output.loss + aux_loss
    outputs = outputs.replace(loss=loss_output.loss)
    return outputs, loss_output