# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
import typing as tp
import warnings
from functools import cached_property, partial
import chex
import flax
import flax.struct
import jax
import jax.extend
import jax.tree_util
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from easydel.utils import traversals
from easydel.utils.helpers import get_logger
from easydel.utils.traversals import flatten_dict, is_flatten, unflatten_dict
from .base_config import EasyDeLBaseConfig
from .etils import EasyDeLGradientCheckPointers, EasyDeLQuantizationMethods
from .loss_utils import (
LOSS_MAPPING,
ForCausalLMLoss,
ForSequenceClassificationLoss,
LossConfig,
LossMetrics,
)
from .mixins import (
BaseModuleProtocol,
EasyBridgeMixin,
EasyGenerationMixin,
)
if tp.TYPE_CHECKING:
    from easydel.infra.base_state import EasyDeLState
else:
    # At runtime only a placeholder is needed for annotations; the real class
    # is imported locally where it is used (presumably to avoid an import
    # cycle — confirm).
    EasyDeLState = tp.Any

# Partition rules accepted by the sharding helpers: a mapping keyed by ``str``
# or by ``tuple`` whose values are callables, or ``None`` to fall back to the
# rules provided by the model config.
PartitionLike = tp.Optional[
    tp.Union[
        tp.Mapping[str, tp.Callable],
        tp.Mapping[tuple, tp.Callable],
    ]
]

logger = get_logger(__name__)

# A TypeVar's name must match the variable it is bound to; static type
# checkers reject ``_CP = TypeVar("CP")``.
_CP = tp.TypeVar("_CP")
SELF = tp.TypeVar("SELF")
class EasyDeLBaseModule(
    nn.Module,
    BaseModuleProtocol,
    EasyBridgeMixin,
    EasyGenerationMixin,
):
    """Base class for EasyDeL modules.

    Provides common functionality for model initialization, parameter
    handling (splitting/merging, dtype conversion, sharding and gathering,
    quantization, LoRA) and integration with the EasyDeL ecosystem.
    """

    # Configuration class associated with this model.
    config_class: tp.Type[EasyDeLBaseConfig]
    # Attribute-name prefix of the base model; presumably used when mapping
    # weights between task heads and the base model — confirm in subclasses.
    base_model_prefix: str
    # Read through the ``model_task`` / ``model_type`` properties; ``None``
    # unless a subclass overrides them.
    _model_task: tp.Optional[str] = None
    _model_type: tp.Optional[str] = None
def __init__(
    self,
    config: tp.Union[EasyDeLBaseConfig, _CP],
    dtype: jnp.dtype,
    param_dtype: jnp.dtype,
    precision: lax.PrecisionLike,
    rngs: nn.Rngs,
):
    """Store the shared model settings on the module.

    Args:
        config: The model configuration.
        dtype: Data type used for computation.
        param_dtype: Data type used for parameters.
        precision: Numerical precision (``jax.lax.PrecisionLike``).
        rngs: The ``nn.Rngs`` random-number generator collection.
    """
    self.config: tp.Union[EasyDeLBaseConfig, _CP] = config
    self.dtype: jnp.dtype = dtype
    self.param_dtype: jnp.dtype = param_dtype
    self.precision: lax.PrecisionLike = precision
    self.rngs: nn.Rngs = rngs

    # Touch each of these properties once, in this exact order; the original
    # author notes this is only done to initialise values in the graphdef.
    for _property_name in (
        "graphtree_shape",
        "graphtree_params_shape",
        "mesh",
        "model_task",
        "model_type",
    ):
        getattr(self, _property_name)
@property
def parameters(self) -> tp.Dict:
    """Map every ``nn.Param`` found by ``iter_module_search`` to its raw value."""
    from easydel.utils.graph_utils import iter_module_search

    return {path: param.value for path, param in iter_module_search(self, nn.Param)}
@property
def graphdef(self) -> nn.GraphDef:
    """Graph definition obtained by splitting the module on ``nn.Param``."""
    graph_definition, _param_state, _other_state = nn.split(self, nn.Param, ...)
    return graph_definition
@property
def graphstate(self) -> nn.GraphState:
    """The ``nn.Param`` state of the module (second item of the split)."""
    _graph_definition, param_state, _other_state = nn.split(self, nn.Param, ...)
    return param_state
@property
def graphother(self) -> nn.GraphState:
    """All non-``nn.Param`` state of the module (last item of the split)."""
    _graph_definition, _param_state, other_state = nn.split(self, nn.Param, ...)
    return other_state
@property
def graphtree_params_shape(self) -> tp.Dict:
    """Nested dict of the abstract (shape-only) ``nn.Param`` values.

    The parameter state is produced under ``nn.eval_shape`` and unwrapped
    from its variable containers via ``.value``.
    """

    def _param_state():
        return nn.split(self, nn.Param, ...)[1]

    abstract_state = nn.eval_shape(_param_state)
    return unflatten_dict(
        {path: leaf.value for path, leaf in flatten_dict(abstract_state).items()}
    )
@property
def graphtree_shape(self) -> tp.Dict:
    """Nested dict of the abstract (shape-only) values of the whole model state.

    Like ``graphtree_params_shape`` but without filtering on ``nn.Param``.
    """

    def _full_state():
        return nn.split(self)[1]

    abstract_state = nn.eval_shape(_full_state)
    return unflatten_dict(
        {path: leaf.value for path, leaf in flatten_dict(abstract_state).items()}
    )
@property
def mesh(self) -> jax.sharding.Mesh:
    """The device mesh stored on ``self.config``."""
    config = self.config
    return config.mesh
@property
def model_task(self) -> tp.Optional[str]:
    """The task identifier held in the ``_model_task`` attribute."""
    task = self._model_task
    return task
@property
def model_type(self) -> tp.Optional[str]:
    """The model-type identifier held in the ``_model_type`` attribute."""
    kind = self._model_type
    return kind
@property
def params(self) -> tp.Dict:
    """The complete (unfiltered) state of the module from ``nn.split``."""
    _graph_definition, full_state = nn.split(self)
    return full_state
@cached_property
def causal_mask(self) -> jnp.ndarray:
    """Causal mask built by ``config.get_basic_causal_mask()``; cached per instance."""
    config = self.config
    return config.get_basic_causal_mask()
@cached_property
def frequencies(self) -> jnp.ndarray:
    """Frequencies built by ``config.get_basic_frequencies()``; cached per instance."""
    config = self.config
    return config.get_basic_frequencies()
@cached_property
def static_arguments(self) -> tp.Tuple:
    """Result of ``get_static_arguments()``, cached per instance."""
    getter = self.get_static_arguments
    return getter()
@cached_property
def loss_function(self):
    """Resolve the loss callable from ``LOSS_MAPPING`` (cached per instance).

    The loss name is taken from ``self.config.loss_type``, then
    ``self.loss_type``, then the class name. If that name is not a key of
    ``LOSS_MAPPING``, the class name is searched for any key of the mapping
    and the first match is used. If nothing matches, a warning is issued and
    ``"ForCausalLM"`` is used.

    Returns:
        The callable stored in ``LOSS_MAPPING`` for the resolved loss name.
    """
    if getattr(self.config, "loss_type", None) is not None:
        requested = self.config.loss_type
    elif getattr(self, "loss_type", None) is not None:
        requested = self.loss_type
    else:
        requested = self.__class__.__name__

    loss_type = requested
    if loss_type not in LOSS_MAPPING:
        # Escape the keys so each one is matched literally in the alternation.
        pattern = "(" + "|".join(re.escape(key) for key in LOSS_MAPPING) + ")"
        matches = re.findall(pattern, self.__class__.__name__)
        loss_type = matches[0] if matches else None

    # After the search above ``loss_type`` is either a key of LOSS_MAPPING or
    # None. The old condition ``a or b and c`` therefore reduced to ``a``,
    # because ``and`` binds tighter than ``or``.
    if loss_type is None:
        warnings.warn(
            f"`loss_type={requested}` is unrecognised. "
            f"Using the default loss: `ForCausalLMLoss`.",
            stacklevel=1,
        )
        loss_type = "ForCausalLM"
    return LOSS_MAPPING[loss_type]
@property
def module_dtype(self) -> jnp.dtype:
    """Dtype of the first leaf of the flattened ``nn.Param`` state.

    Only the first leaf is inspected; mixed-dtype models are not detected.
    """
    _graph_definition, param_state, _other_state = nn.split(self, nn.Param, ...)
    first_leaf = jax.tree_util.tree_leaves(param_state.flat_state())[0]
    return first_leaf.dtype
def to_dtype(self: SELF, dtype: jnp.dtype) -> SELF:
    """Cast the model's parameters to ``dtype`` and record it as ``param_dtype``.

    Every non-``None`` ``nn.Param`` value is cast with ``astype(dtype)``,
    except those whose last path element starts with ``"quant_"``. After that,
    every submodule that has a ``param_dtype`` attribute has it set to
    ``dtype``.

    Args:
        dtype: Target parameter dtype.

    Returns:
        The model re-merged with the cast parameters.
    """
    from easydel.utils.graph_utils import iter_module_search

    gdef, state, others = nn.split(self, nn.Param, ...)

    def _cast(path, val: nn.VariableState):
        # ``quant_*`` tensors keep their storage dtype. ``str()`` guards
        # against non-string path entries (e.g. integer list indices), which
        # have no ``startswith``.
        if val.value is not None and not str(path[-1]).startswith("quant_"):
            val.value = val.value.astype(dtype)
        return val

    state.update(state.map(_cast))
    self = nn.merge(gdef, state, others)
    for _, module in iter_module_search(self):
        if hasattr(module, "param_dtype"):
            module.param_dtype = dtype
    return self
def half(self: SELF, change_runtime_dtype: bool = True) -> SELF:
    """Convert the model to ``float16``.

    Args:
        change_runtime_dtype: When True, first switch the computation dtype
            with ``_reformat_runtime_dtype``.

    Returns:
        The model after ``_reformat_dtype(jnp.float16)``.
    """
    target = jnp.float16
    model = self._reformat_runtime_dtype(target) if change_runtime_dtype else self
    return model._reformat_dtype(target)
def float(self: SELF, change_runtime_dtype: bool = True) -> SELF:
    """Convert the model to ``float32``.

    Args:
        change_runtime_dtype: When True, first switch the computation dtype
            with ``_reformat_runtime_dtype``.

    Returns:
        The model after ``_reformat_dtype(jnp.float32)``.
    """
    target = jnp.float32
    model = self._reformat_runtime_dtype(target) if change_runtime_dtype else self
    return model._reformat_dtype(target)
def _reformat_runtime_dtype(self: SELF, dtype) -> SELF:
    """Switch the computation ``dtype`` of this model and its submodules.

    A submodule's ``dtype`` is replaced only when it is a JAX scalar type
    (e.g. ``jnp.bfloat16``, whose metaclass is ``_ScalarMeta``). Other dtype
    objects, such as NumPy dtype instances, are left untouched. The root
    module's ``dtype`` is always set.

    Args:
        dtype: New computation dtype.

    Returns:
        ``self`` with updated ``dtype`` attributes.
    """
    from easydel.utils.graph_utils import iter_module_search

    for _, module in iter_module_search(self):
        # Match the metaclass by name rather than by the repr of its private
        # defining module: that module path is not a stable JAX API, so the
        # previous ``endswith("lax_numpy._ScalarMeta'>")`` check can silently
        # stop matching after a JAX upgrade.
        if hasattr(module, "dtype") and type(module.dtype).__name__ == "_ScalarMeta":
            module.dtype = dtype
    self.dtype = dtype
    return self
def _reformat_dtype(self: SELF, dtype) -> SELF:
    """Cast floating-point parameters to ``dtype`` and update ``param_dtype``.

    ``nn.Param`` leaves whose dtype is bfloat16, float16, float32 or float64
    are cast with ``astype(dtype)``; other leaves (e.g. integer or quantized
    storage) are left as they are. Every submodule whose ``param_dtype`` holds
    a dtype-like value (a ``jnp.dtype`` instance or a JAX scalar type such as
    ``jnp.bfloat16``) gets it set to ``dtype``, as does the root module.

    Args:
        dtype: Target parameter dtype.

    Returns:
        The model re-merged with the cast parameters.
    """
    from easydel.utils.graph_utils import iter_module_search

    float_dtypes = [jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, jnp.float_]
    gdef, gtree, others = nn.split(self, nn.Param, ...)

    def _cast(array):
        if array.dtype in float_dtypes:
            return array.astype(dtype)
        return array

    gtree = jax.tree_util.tree_map(_cast, gtree)
    self = nn.merge(gdef, gtree, others)

    for _, module in iter_module_search(self):
        if not hasattr(module, "param_dtype"):
            continue
        current = module.param_dtype
        # ``jnp.bfloat16`` & co. are classes (metaclass ``_ScalarMeta``), not
        # ``jnp.dtype`` instances. An ``isinstance`` check alone therefore
        # skipped the common case of modules built with
        # ``param_dtype=jnp.<type>``.
        if isinstance(current, jnp.dtype) or type(current).__name__ == "_ScalarMeta":
            module.param_dtype = dtype
    self.param_dtype = dtype
    return self
def _match_partition_rules(self, partition_rules: tp.Any = None):
    """Match partition rules (explicit, or from the config) against the param shape tree."""
    rules = self._get_partition_rules(partition_rules)
    shape_tree = self.graphtree_params_shape
    return match_partition_rules(rules=rules, tree=shape_tree)
@property
def _specs_sharding(self):
    """Module tree whose leaves are replaced by their ``PartitionSpec``.

    Leaves carrying a ``NamedSharding`` yield its ``spec``; every other leaf
    yields an empty ``PartitionSpec()``.
    """

    def _spec_of(leaf):
        sharding = getattr(leaf, "sharding", None)
        if isinstance(sharding, NamedSharding):
            return sharding.spec
        return PartitionSpec()

    leaf_tree = nn.to_tree(self)
    return nn.from_tree(jax.tree_util.tree_map(_spec_of, leaf_tree))
@property
def _shardings(self):
    """Module tree of each leaf's ``.sharding``, or ``PartitionSpec()`` when it has none."""

    def _sharding_or_empty_spec(leaf):
        if hasattr(leaf, "sharding"):
            return leaf.sharding
        return PartitionSpec()

    leaf_tree = nn.to_tree(self)
    return nn.from_tree(jax.tree_util.tree_map(_sharding_or_empty_spec, leaf_tree))
@property
def _named_shardings(self):
    """Module tree of each leaf's ``.sharding``, or ``None`` when it has none."""

    def _sharding_or_none(leaf):
        return getattr(leaf, "sharding", None)

    leaf_tree = nn.to_tree(self)
    return nn.from_tree(jax.tree_util.tree_map(_sharding_or_none, leaf_tree))
def _get_mesh(self, mesh: tp.Optional[Mesh] = None) -> Mesh:
"""Retrieves the mesh, either from the provided argument or the config."""
if mesh is None:
if (
not hasattr(self, "config")
or not hasattr(self.config, "mesh")
or self.config.mesh is None
):
raise ValueError(
"A mesh must be provided, either as an argument or through the model config."
)
return self.config.mesh
return mesh
def _get_partition_rules(self, partition_rules: PartitionLike) -> PartitionLike:
"""Retrieves the partition rules from input or the config"""
if partition_rules is None:
if not hasattr(self, "config"):
raise ValueError(
"Partition rules must be provided either as an argument or through the model config."
)
return self.config.get_partition_rules(fully_sharded_data_parallel=True)
return partition_rules
def _apply_sharding_fns(
    self: SELF,
    sharding_fns: tp.Mapping[str, tp.Callable],
) -> SELF:
    """Apply per-parameter shard/gather functions to the model's ``nn.Param`` state.

    Args:
        sharding_fns: Possibly nested mapping from parameter path to a
            callable. It is flattened with ``flatten_dict`` so its keys can be
            compared with the tuple paths given by ``State.map``.

    Returns:
        The model re-merged with the transformed parameters. A parameter whose
        function raises ``TypeError`` is left unchanged and a warning is issued.
    """
    gdef, state, others = nn.split(self, nn.Param, ...)
    # Test membership on the flattened dict itself (O(1) per lookup). The
    # previous ``list(keys())`` made every lookup O(n), i.e. O(n^2) overall.
    flat_fns = flatten_dict(sharding_fns)

    def _apply(path, val: nn.VariableState):
        if val.value is not None and path in flat_fns:
            try:
                val.value = flat_fns[path](val.value)
            except TypeError:
                warnings.warn(
                    f"couldn't shard/gather {'.'.join(map(str, path))}",
                    stacklevel=1,
                )
        return val

    state.update(state.map(_apply))
    return nn.merge(gdef, state, others)
def shard_model(
    self: SELF,
    partition_rules: PartitionLike = None,
    mesh: tp.Optional[Mesh] = None,
    overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None,
) -> SELF:
    """Shard the model's parameters across ``mesh`` according to partition rules.

    Args:
        partition_rules: Rules to match against the parameter tree; ``None``
            uses the config's rules.
        mesh: Mesh to shard across; ``None`` uses ``config.mesh``.
        overlay_fns: Extra functions merged (``dict.update``) over the
            generated shard functions before they are applied.

    Returns:
        The sharded model.
    """
    target_mesh = self._get_mesh(mesh)
    rules = self._get_partition_rules(partition_rules)
    specs = match_partition_rules(rules=rules, tree=self.graphtree_params_shape)
    shard_fns, _unused_gather_fns = make_shard_and_gather_fns(
        partition_specs=specs,
        mesh=target_mesh,
    )
    if overlay_fns is not None:
        shard_fns.update(overlay_fns)
    return self._apply_sharding_fns(shard_fns)
def gather_model(
    self: SELF,
    partition_rules: PartitionLike = None,
    mesh: tp.Optional[Mesh] = None,
    overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None,
) -> SELF:
    """Gather the model's parameters using the gather functions for ``mesh``.

    Args:
        partition_rules: Rules to match against the parameter tree; ``None``
            uses the config's rules.
        mesh: Mesh to gather from; ``None`` uses ``config.mesh``.
        overlay_fns: Extra functions merged (``dict.update``) over the
            generated gather functions before they are applied.

    Returns:
        The gathered model.
    """
    source_mesh = self._get_mesh(mesh)
    rules = self._get_partition_rules(partition_rules)
    specs = match_partition_rules(rules=rules, tree=self.graphtree_params_shape)
    _unused_shard_fns, gather_fns = make_shard_and_gather_fns(
        partition_specs=specs,
        mesh=source_mesh,
    )
    if overlay_fns is not None:
        gather_fns.update(overlay_fns)
    return self._apply_sharding_fns(gather_fns)
@property
def _shard_fns(self):
    """Shard functions for the config's mesh and partition rules."""
    config_mesh = self._get_mesh(None)
    specs = match_partition_rules(
        rules=self._get_partition_rules(None),
        tree=self.graphtree_params_shape,
    )
    both_fns = make_shard_and_gather_fns(partition_specs=specs, mesh=config_mesh)
    return both_fns[0]
@property
def _gather_fns(self):
    """Gather functions for the config's mesh and partition rules."""
    config_mesh = self._get_mesh(None)
    specs = match_partition_rules(
        rules=self._get_partition_rules(None),
        tree=self.graphtree_params_shape,
    )
    both_fns = make_shard_and_gather_fns(partition_specs=specs, mesh=config_mesh)
    return both_fns[1]
def fully_shard(self: SELF, partition_rules: PartitionLike = None) -> SELF:
    """Reshard the whole module (graphdef + full state) on ``self.mesh``.

    The split module is wrapped in a pytree node. Partition rules are matched
    against its abstract shape, and each spec becomes a ``NamedSharding`` on
    ``self.mesh``. A jitted identity with those ``out_shardings`` then moves
    the data.

    Args:
        partition_rules: Rules to use; ``None`` uses the config's rules.

    Returns:
        The module rebuilt from the resharded graphdef and state.
    """

    class ShardState(flax.struct.PyTreeNode):
        graphdef: nn.GraphDef
        graphstate: nn.GraphState

    graph_def, graph_state = nn.split(self)
    packed = ShardState(graphdef=graph_def, graphstate=graph_state)
    specs = match_partition_rules(
        self._get_partition_rules(partition_rules), nn.eval_shape(lambda: packed)
    )
    target_shardings = jax.tree_util.tree_map(
        lambda spec: NamedSharding(mesh=self.mesh, spec=spec),
        specs,
    )
    reshard = jax.jit(lambda tree: tree, out_shardings=target_shardings)
    resharded = reshard(packed)
    return nn.merge(resharded.graphdef, resharded.graphstate)
def fully_gather(self: SELF) -> SELF:
    """Replicate the whole module (graphdef + full state) on ``self.mesh``.

    The config's partition rules are matched only to get a tree of the right
    structure. Every leaf is mapped to
    ``NamedSharding(self.mesh, PartitionSpec())`` and applied with a jitted
    identity function.

    Returns:
        The module rebuilt from the gathered graphdef and state.
    """

    class ShardState(flax.struct.PyTreeNode):
        graphdef: nn.GraphDef
        graphstate: nn.GraphState

    graph_def, graph_state = nn.split(self)
    packed = ShardState(graphdef=graph_def, graphstate=graph_state)
    structure = match_partition_rules(
        self._get_partition_rules(None), nn.eval_shape(lambda: packed)
    )
    replicated = jax.tree_util.tree_map(
        lambda _spec: NamedSharding(mesh=self.mesh, spec=PartitionSpec()),
        structure,
    )
    gather = jax.jit(lambda tree: tree, out_shardings=replicated)
    gathered = gather(packed)
    return nn.merge(gathered.graphdef, gathered.graphstate)
def quantize(
    self: SELF,
    method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT,
    block_size: int = 128,
    quantization_pattern: tp.Optional[str] = None,
    quantize_tensors: bool = True,
    verbose: tp.Optional[bool] = None,
) -> SELF:
    """Quantize the model.

    Args:
        method: Quantization method passed to ``EasyQuantizer``.
        block_size: Block size for quantization.
        quantization_pattern: Pattern passed to both ``EasyQuantizer`` and
            ``quantize_linears``.
        quantize_tensors: Whether to quantize tensors (True) or to quantize
            linear layers via ``EasyQuantizer.quantize_linears`` (False). The
            tensor path is not implemented here: the model is returned
            unchanged and a warning is issued.
        verbose: Verbose quantization; defaults to True only on process 0.

    Returns:
        The quantized model, or the unchanged model when
        ``quantize_tensors=True``.
    """
    from easydel.layers.quantization.quantizers import EasyQuantizer

    quantizer = EasyQuantizer(
        quantization_method=method,
        block_size=block_size,
        quantization_pattern=quantization_pattern,
    )
    if verbose is None:
        verbose = jax.process_index() == 0
    if quantize_tensors:
        # This branch used to be a bare ``...`` placeholder, so the default
        # call silently returned an unquantized model. Make that explicit.
        warnings.warn(
            "`quantize(quantize_tensors=True)` is not implemented; the model is "
            "returned unquantized. Pass `quantize_tensors=False` to quantize "
            "linear layers.",
            stacklevel=2,
        )
        return self
    return quantizer.quantize_linears(
        self,
        quantization_pattern=quantization_pattern,
        verbose=verbose,
    )
def to_state(self) -> EasyDeLState:
    """Wrap this model in a new ``EasyDeLState`` starting at step 0."""
    from easydel.infra.base_state import EasyDeLState as _StateClass

    return _StateClass.create(model=self, step=0)
def to_torch(self, **kwargs):
    """Convert this model into a HuggingFace PyTorch model.

    The HF model class is looked up in the torch auto-loader's
    ``_model_mapping`` by config type. Conversion is delegated to
    ``module_to_huggingface_model`` with ``dtype=self.param_dtype``; extra
    ``kwargs`` are forwarded unchanged.
    """
    from easydel.utils.parameters_transformation import module_to_huggingface_model

    model_mapping = self.get_torch_loader()._model_mapping
    hf_class = model_mapping[type(self.config)]
    return module_to_huggingface_model(
        module=self,
        base_huggingface_module=hf_class,
        config=self.config,
        dtype=self.param_dtype,
        **kwargs,
    )
def get_static_arguments(self) -> tp.Tuple:
    """Static arguments for this model; the base class has none."""
    return tuple()
@classmethod
def lazy_init(cls: tp.Type[SELF], *args, **kwargs) -> SELF:
    """Construct ``cls(*args, **kwargs)`` under ``nn.eval_shape`` (shape-only)."""

    def _construct():
        return cls(*args, **kwargs)

    return nn.eval_shape(_construct)
def merge_lora_params(self: SELF, pytree: tp.Dict) -> SELF:
    """Merge LoRA parameters from ``pytree``; delegates to ``easydel.infra.utils``."""
    from easydel.infra.utils import merge_lora_params as _merge_lora

    return _merge_lora(self, pytree)
def split_lora_params(self: SELF) -> tp.Dict:
    """Extract LoRA parameters as a pytree; delegates to ``easydel.infra.utils``."""
    from easydel.infra.utils import split_lora_params as _split_lora

    return _split_lora(self)
def apply_lora_to_layers(
    self: SELF,
    lora_rank: int,
    lora_pattern: tp.Optional[str] = None,
    verbose: bool = False,
    rngs: tp.Optional[nn.Rngs] = None,
) -> SELF:
    """Wrap matching layers with LoRA adapters; delegates to ``easydel.infra.utils``.

    Args:
        lora_rank: Rank of the LoRA adapters.
        lora_pattern: Optional pattern selecting which layers are wrapped.
        verbose: Forwarded verbosity flag.
        rngs: Optional random-number generators forwarded to the helper.

    Returns:
        The model returned by the helper.
    """
    from easydel.infra.utils import apply_lora_to_layers as _apply_lora

    return _apply_lora(
        self,
        lora_pattern=lora_pattern,
        lora_rank=lora_rank,
        rngs=rngs,
        verbose=verbose,
    )
def unwrap_lora_to_layers(self: SELF, verbose: bool = False) -> SELF:
    """Remove LoRA wrappers from layers; delegates to ``easydel.infra.utils``."""
    from easydel.infra.utils import unwrap_lora_to_layers as _unwrap_lora

    return _unwrap_lora(self, verbose=verbose)
@property
def transform_fn(self):
    """``torch_dict_to_easydel_params`` pre-bound for this model, with sharding.

    Binds the dotted paths of all ``nn.Embed`` and ``nn.LayerNorm`` submodules
    (skipping paths whose last element is an int), ``self.param_dtype``, and
    the config-derived shard functions.
    """
    from easydel.utils import graph_utils
    from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

    def _dotted_paths(module_type):
        return [
            ".".join(map(str, path))
            for path, _ in graph_utils.iter_module_search(self, module_type)
            if not isinstance(path[-1], int)
        ]

    return partial(
        torch_dict_to_easydel_params,
        embedding_layer_names=_dotted_paths(nn.Embed),
        layernorm_names=_dotted_paths(nn.LayerNorm),
        dtype=self.param_dtype,
        shard_fns=self._shard_fns,
    )
@property
def _generate_compatible_graphdef(self):
    """GraphDef of a shape-only copy of this model with gradient checkpointing off."""
    from copy import deepcopy

    config_copy = deepcopy(self.config)
    config_copy.gradient_checkpointing = EasyDeLGradientCheckPointers.NONE
    abstract_model = type(self).lazy_init(
        config=config_copy,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
        rngs=self.rngs,
    )
    graph_definition, _param_state, _other_state = nn.split(abstract_model, nn.Param, ...)
    return graph_definition
@property
def _generate_compatible_graphother(self):
    """Non-param state of a shape-only copy of this model with gradient checkpointing off.

    The state is passed through ``traversals.recreate_meta_values`` before it
    is returned.
    """
    from copy import deepcopy

    config_copy = deepcopy(self.config)
    config_copy.gradient_checkpointing = EasyDeLGradientCheckPointers.NONE
    abstract_model = type(self).lazy_init(
        config=config_copy,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
        rngs=self.rngs,
    )
    _graph_definition, _param_state, other_state = nn.split(abstract_model, nn.Param, ...)
    return traversals.recreate_meta_values(other_state)
@property
def params_sharding(self) -> tp.Dict:
    """Nested dict mapping each parameter to its ``.sharding`` (``None`` if absent)."""

    def _sharding_of(leaf):
        return getattr(leaf, "sharding", None)

    return jax.tree_util.tree_map(_sharding_of, self.split_params_dict())
def merge_params(self, tree):
    """Rebuild the model using ``tree`` as its ``nn.Param`` state.

    The graphdef and non-param state come from the current model.
    """
    graph_definition, _old_params, other_state = nn.split(self, nn.Param, ...)
    return nn.merge(graph_definition, tree, other_state)
def split_params(self):
    """Return only the ``nn.Param`` state of the model."""
    _graph_definition, param_state, _other_state = nn.split(self, nn.Param, ...)
    return param_state
def split_params_dict(
    self,
    extract_fn: tp.Optional[tp.Callable] = None,
    remove_none: bool = True,
) -> tp.Dict:
    """Return the parameters as a plain nested dict without ``VariableState`` wrappers.

    Args:
        extract_fn: Forwarded to ``State.to_pure_dict``.
        remove_none: Drop entries whose (unwrapped) value is ``None``.

    Returns:
        Nested dict of parameter values.
    """

    def _unwrap(leaf):
        return leaf.value if hasattr(leaf, "value") else leaf

    pure_dict = self.split_params().to_pure_dict(extract_fn=extract_fn)
    flat_values = {path: _unwrap(leaf) for path, leaf in flatten_dict(pure_dict).items()}
    if remove_none:
        flat_values = {
            path: value for path, value in flat_values.items() if value is not None
        }
    return unflatten_dict(flat_values)
def merge_params_dict(self: SELF, params_dict: tp.Dict) -> SELF:
    """Overwrite model parameters with values from ``params_dict``.

    Args:
        params_dict: Nested or already-flattened dict of parameter values.

    Returns:
        The model with the merged parameters.

    Raises:
        KeyError: If a key in ``params_dict`` is not a parameter of the model.
    """
    flat_state = self.split_params().flat_state()
    flat_updates = params_dict if is_flatten(params_dict) else flatten_dict(params_dict)
    for key, new_value in flat_updates.items():
        if key not in flat_state:
            raise KeyError(f"Parameter key {key} not found in the current model state.")
        flat_state[key].value = new_value
    return self.merge_params(unflatten_dict(flat_state))
def _flop(self, *args, **kwargs) -> tp.Optional[float]:
    """Count floating-point operations of ``self.__call__`` from its jaxpr."""
    from .utils import count_flop_jaxpr

    traced = jax.make_jaxpr(self.__call__)(*args, **kwargs)
    return count_flop_jaxpr(traced)
@property
def pure_transform_fn(self):
    """``torch_dict_to_easydel_params`` pre-bound for this model, without sharding.

    Binds the dotted paths of all ``nn.Embed`` and ``nn.LayerNorm`` submodules
    (skipping paths whose last element is an int) and ``self.param_dtype``.
    """
    from easydel.utils import graph_utils
    from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

    def _dotted_paths(module_type):
        return [
            ".".join(map(str, path))
            for path, _ in graph_utils.iter_module_search(self, module_type)
            if not isinstance(path[-1], int)
        ]

    return partial(
        torch_dict_to_easydel_params,
        embedding_layer_names=_dotted_paths(nn.Embed),
        layernorm_names=_dotted_paths(nn.LayerNorm),
        dtype=self.param_dtype,
    )
@property
def _default_loss_config(self) -> tp.Optional[LossConfig]:
return None
@_default_loss_config.setter
def _default_loss_config(self, val):
return val
def compute_loss(
    self,
    *,
    labels: tp.Optional[chex.Array] = None,
    loss_config: tp.Optional[LossConfig] = None,
    loss_kwargs: tp.Optional[tp.Dict] = None,
    **batch,
) -> tp.Tuple[tp.Any, LossMetrics]:
    """Run the model on ``batch`` and compute its loss with ``self.loss_function``.

    Args:
        labels: Target labels. When omitted and the loss is the causal-LM
            loss, ``batch["input_ids"]`` is used.
        loss_config: Loss configuration. When omitted for the
            sequence-classification loss, ``LossConfig(num_labels=
            self.config.num_labels)`` is built.
        loss_kwargs: Extra keyword arguments forwarded to the loss function.
        **batch: Model inputs; any ``return_dict`` entry is dropped and the
            model is always called with ``return_dict=True``.

    Returns:
        ``(outputs, loss_output)``. ``outputs.loss`` is replaced by the
        computed loss, which includes ``outputs.aux_loss`` when present and
        not None.

    Raises:
        AssertionError: If labels cannot be determined, or a
            sequence-classification config lacks ``num_labels``.
    """
    loss_name = self.loss_function.__name__

    if labels is None and loss_name == ForCausalLMLoss.__name__:
        labels = batch.get("input_ids", None)

    if loss_name == ForSequenceClassificationLoss.__name__ and loss_config is None:
        assert hasattr(self.config, "num_labels"), (
            "in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach `num_labels` to model `config`"
        )
        loss_config = LossConfig(num_labels=self.config.num_labels)

    assert labels is not None, "`labels` can not be `None` for computing loss."
    batch.pop("return_dict", None)
    outputs = self(**batch, return_dict=True)

    loss_output: LossMetrics = self.loss_function(
        labels=labels,
        config=loss_config,
        paxis=self.config.partition_axis,
        **(loss_kwargs or {}),
        **outputs,
        **batch,
    )

    aux_loss = getattr(outputs, "aux_loss", None)
    if aux_loss is not None:
        loss_output.loss = loss_output.loss + aux_loss
    return outputs.replace(loss=loss_output.loss), loss_output