# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
import typing as tp
import warnings
from functools import cached_property, partial
import chex
import flax
import flax.struct
import jax
import jax.extend
import jax.tree_util
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from easydel.utils.helpers import get_logger
from easydel.utils.traversals import flatten_dict, is_flatten, unflatten_dict
from .base_config import EasyDeLBaseConfig
from .etils import EasyDeLQuantizationMethods
from .loss_utils import (
LOSS_MAPPING,
ForCausalLMLoss,
ForSequenceClassificationLoss,
LossConfig,
LossMetrics,
)
from .mixins import (
BaseModuleProtocol,
EasyBridgeMixin,
EasyGenerationMixin,
)
if tp.TYPE_CHECKING:
    # Imported for static type checking only.
    from easydel.infra.base_state import EasyDeLState
else:
    # Runtime placeholder so the name exists without importing `base_state` at module import.
    EasyDeLState = tp.Any

# Accepted forms of partition rules: a mapping keyed by `str` or by `tuple`, or `None`.
PartitionLike = tp.Optional[
    tp.Union[
        tp.Mapping[str, tp.Callable],
        tp.Mapping[tuple, tp.Callable],
    ]
]

logger = get_logger(__name__)

# The string passed to `TypeVar` must equal the variable name it is assigned to.
_CP = tp.TypeVar("_CP")
SELF = tp.TypeVar("SELF")
class EasyDeLBaseModule(
    nn.Module,
    BaseModuleProtocol,
    EasyBridgeMixin,
    EasyGenerationMixin,
):
    """
    Base class for EasyDeL modules, providing common functionalities for model initialization,
    parameter handling, and integration with the EasyDeL ecosystem.

    Several helpers (dtype casting, sharding, parameter merging, ...) rebuild the module with
    `nn.merge` and return the rebuilt object, so callers should always use the returned value.
    """

    # Configuration class associated with the module (declared only; no value is set here).
    config_class: tp.Type[EasyDeLBaseConfig]
    # Prefix of the base model (declared only; no value is set here).
    base_model_prefix: str
    # Values exposed through the `model_task` / `model_type` properties.
    _model_task: tp.Optional[str] = None
    _model_type: tp.Optional[str] = None

    def __init__(
        self,
        config: tp.Union[EasyDeLBaseConfig, _CP],
        dtype: jnp.dtype,
        param_dtype: jnp.dtype,
        precision: lax.PrecisionLike,
        rngs: nn.Rngs,
    ):
        """Initializes the EasyDeLBaseModule.

        Args:
            config (EasyDeLBaseConfig): The model configuration.
            dtype (jnp.dtype): The data type for computation.
            param_dtype (jnp.dtype): The data type for parameters.
            precision (jax.lax.PrecisionLike): The numerical precision.
            rngs (nn.Rngs): The random number generators.
        """
        self.config: tp.Union[EasyDeLBaseConfig, _CP] = config
        self.dtype: jnp.dtype = dtype
        self.param_dtype: jnp.dtype = param_dtype
        self.precision: lax.PrecisionLike = precision
        self.rngs: nn.Rngs = rngs

        # These otherwise-unused reads are only here to initialise values in the graphdef.
        _ = self.graphtree_shape
        _ = self.graphtree_params_shape
        _ = self.mesh
        _ = self.model_task
        _ = self.model_type

    @staticmethod
    def _graphtree_to_shape_dict(graphtree) -> tp.Dict:
        """Flattens `graphtree`, replaces every entry by its `.value` and re-nests the result."""
        flattened_tree = flatten_dict(graphtree)
        return unflatten_dict({key: val.value for key, val in flattened_tree.items()})

    @property
    def parameters(self) -> tp.Dict:
        """Returns a dict mapping every key yielded by `iter_module_search(self, nn.Param)` to its value."""
        from easydel.utils.graph_utils import iter_module_search

        return {key: value.value for key, value in iter_module_search(self, nn.Param)}

    @property
    def graphtree_params_shape(self) -> tp.Dict:
        """Evaluates the shape of the model's parameters and returns a dictionary."""
        graphtree = nn.eval_shape(lambda: nn.split(self, nn.Param, ...)[1])
        return self._graphtree_to_shape_dict(graphtree)

    @property
    def graphtree_shape(self) -> tp.Dict:
        """Evaluates the shape of the model and returns a dictionary."""
        graphtree = nn.eval_shape(lambda: nn.split(self)[1])
        return self._graphtree_to_shape_dict(graphtree)

    @property
    def mesh(self) -> jax.sharding.Mesh:
        """Returns the mesh from the config."""
        return self.config.mesh

    @property
    def model_task(self) -> tp.Optional[str]:
        """Returns the model task."""
        return self._model_task

    @property
    def model_type(self) -> tp.Optional[str]:
        """Returns the model type."""
        return self._model_type

    @property
    def params(self) -> tp.Dict:
        """Returns the last element of `nn.split(self)`."""
        return nn.split(self)[-1]

    @cached_property
    def causal_mask(self) -> jnp.ndarray:
        """Returns a causal mask from the config."""
        return self.config.get_basic_causal_mask()

    @cached_property
    def frequencies(self) -> jnp.ndarray:
        """Returns frequency values from the config."""
        return self.config.get_basic_frequencies()

    @cached_property
    def static_arguments(self) -> tp.Tuple:
        """Returns the (cached) result of `get_static_arguments()`."""
        return self.get_static_arguments()

    @cached_property
    def loss_function(self):
        """Resolves the loss function for this module from `LOSS_MAPPING`.

        The loss type is taken, in order of priority, from `config.loss_type`, from a `loss_type`
        attribute on the module, or from the class name. If that name is not a key of
        `LOSS_MAPPING`, the class name is searched for any known key. When nothing matches, a
        warning is emitted and `"ForCausalLM"` is used.
        """
        if getattr(self.config, "loss_type", None) is not None:
            loss_type = self.config.loss_type
        elif getattr(self, "loss_type", None) is not None:
            loss_type = self.loss_type
        else:
            loss_type = self.__class__.__name__

        # Remember what was asked for so the warning below can report it.
        requested_loss_type = loss_type

        if loss_type not in LOSS_MAPPING:
            loss_groups = f"({'|'.join(LOSS_MAPPING)})"
            matches = re.findall(loss_groups, self.__class__.__name__)
            loss_type = matches[0] if matches else None

        # Covers both "nothing matched" (None) and any value that is still not a known key.
        if loss_type not in LOSS_MAPPING:
            warnings.warn(
                f"`loss_type={requested_loss_type}` is unrecognised and no known loss type could be "
                f"inferred from the class name `{self.__class__.__name__}`. "
                f"Using the default loss: `ForCausalLMLoss`.",
                stacklevel=1,
            )
            loss_type = "ForCausalLM"
        return LOSS_MAPPING[loss_type]

    @property
    def module_dtype(self) -> jnp.dtype:
        """Returns the dtype of the first leaf of the module's `nn.Param` state."""
        params_state = nn.split(self, nn.Param, ...)[1].flat_state()
        return jax.tree_util.tree_leaves(params_state)[0].dtype

    def to_dtype(self: SELF, dtype: jnp.dtype) -> SELF:
        """Casts the model's parameters to `dtype`.

        Parameters whose last path component starts with `"quant_"` and parameters whose value is
        `None` are left untouched. Every submodule exposing `param_dtype` gets it set to `dtype`.

        Args:
            dtype (jnp.dtype): The target parameter dtype.

        Returns:
            EasyDeLBaseModule: The rebuilt model with casted parameters.
        """
        from easydel.utils.graph_utils import iter_module_search

        gdef, state, others = nn.split(self, nn.Param, ...)

        def _map(path, val: nn.VariableState):
            # `str(...)` keeps non-string path components (e.g. integer indices) from raising.
            if val.value is not None and not str(path[-1]).startswith("quant_"):
                val.value = val.value.astype(dtype)
            return val

        state.update(state.map(_map))
        self = nn.merge(gdef, state, others)

        for _, module in iter_module_search(self):
            if hasattr(module, "param_dtype"):
                module.param_dtype = dtype
        return self

    def half(self: SELF, change_runtime_dtype: bool = True) -> SELF:
        """Converts floating-point parameters to `jnp.float16`.

        Args:
            change_runtime_dtype (bool): If True, also switch the runtime `dtype` to `jnp.float16`.

        Returns:
            EasyDeLBaseModule: The converted model.
        """
        if change_runtime_dtype:
            self = self._reformat_runtime_dtype(jnp.float16)
        return self._reformat_dtype(jnp.float16)

    def float(self: SELF, change_runtime_dtype: bool = True) -> SELF:
        """Converts floating-point parameters to `jnp.float32`.

        Args:
            change_runtime_dtype (bool): If True, also switch the runtime `dtype` to `jnp.float32`.

        Returns:
            EasyDeLBaseModule: The converted model.
        """
        if change_runtime_dtype:
            self = self._reformat_runtime_dtype(jnp.float32)
        return self._reformat_dtype(jnp.float32)

    def _reformat_runtime_dtype(self: SELF, dtype) -> SELF:
        """Sets `dtype` on this module and, in place, on submodules whose `dtype` is a jax scalar type."""
        from easydel.utils.graph_utils import iter_module_search

        for _, module in iter_module_search(self):
            if hasattr(module, "dtype"):
                if str(type(module.dtype)).endswith(
                    "lax_numpy._ScalarMeta'>"
                ):  # dont change numpy based dtypes
                    module.dtype = dtype
        self.dtype = dtype
        return self

    def _reformat_dtype(self: SELF, dtype) -> SELF:
        """Casts floating-point `nn.Param` leaves to `dtype` and updates `param_dtype` attributes.

        Only leaves whose dtype is one of bfloat16/float16/float32/float64 are casted; all other
        leaves are returned unchanged.
        """
        from easydel.utils.graph_utils import iter_module_search

        castable_dtypes = (
            jnp.bfloat16,
            jnp.float16,
            jnp.float32,
            jnp.float64,
            jnp.float_,
        )

        gdef, gtree, others = nn.split(self, nn.Param, ...)

        def _map(array):
            if array.dtype in castable_dtypes:
                array = array.astype(dtype)
            return array

        gtree = jax.tree_util.tree_map(_map, gtree)
        self = nn.merge(gdef, gtree, others)

        for _, module in iter_module_search(self):
            if hasattr(module, "param_dtype"):
                if isinstance(module.param_dtype, jnp.dtype):
                    module.param_dtype = dtype

        self.param_dtype = dtype
        return self

    def _match_partition_rules(self, partition_rules: tp.Any = None):
        """Matches the (given or config-provided) partition rules against the parameter shapes."""
        return match_partition_rules(
            rules=self._get_partition_rules(partition_rules),
            tree=self.graphtree_params_shape,
        )

    @property
    def _specs_sharding(self):
        """Returns the module tree with each leaf replaced by a `PartitionSpec`.

        Leaves carrying a `NamedSharding` yield its `spec`; every other leaf yields an empty
        `PartitionSpec()`.
        """

        def _map(array):
            if hasattr(array, "sharding"):
                sharding = array.sharding
                if isinstance(sharding, NamedSharding):
                    return sharding.spec
            return PartitionSpec()

        return nn.from_tree(
            jax.tree_util.tree_map(
                _map,
                nn.to_tree(self),
            )
        )

    @property
    def _shardings(self):
        """Returns the module tree with each leaf replaced by its `.sharding`, or `PartitionSpec()` if absent."""
        return nn.from_tree(
            jax.tree_util.tree_map(
                lambda x: x.sharding if hasattr(x, "sharding") else PartitionSpec(),
                nn.to_tree(self),
            )
        )

    @property
    def _named_shardings(self):
        """Returns the module tree with each leaf replaced by its `.sharding`, or `None` if absent."""
        return nn.from_tree(
            jax.tree_util.tree_map(
                lambda x: x.sharding if hasattr(x, "sharding") else None,
                nn.to_tree(self),
            )
        )

    def _get_mesh(self, mesh: tp.Optional[Mesh] = None) -> Mesh:
        """Retrieves the mesh, either from the provided argument or the config.

        Raises:
            ValueError: If `mesh` is None and the config does not provide one.
        """
        if mesh is None:
            if (
                not hasattr(self, "config")
                or not hasattr(self.config, "mesh")
                or self.config.mesh is None
            ):
                raise ValueError(
                    "A mesh must be provided, either as an argument or through the model config."
                )
            return self.config.mesh
        return mesh

    def _get_partition_rules(self, partition_rules: PartitionLike) -> PartitionLike:
        """Retrieves the partition rules from input or the config.

        When `partition_rules` is None, the config's rules are requested with
        `fully_sharded_data_parallel=True`.

        Raises:
            ValueError: If `partition_rules` is None and the module has no `config`.
        """
        if partition_rules is None:
            if not hasattr(self, "config"):
                raise ValueError(
                    "Partition rules must be provided either as an argument or through the model config."
                )
            return self.config.get_partition_rules(fully_sharded_data_parallel=True)
        return partition_rules

    def _make_sharding_fns(
        self,
        partition_rules: PartitionLike = None,
        mesh: tp.Optional[Mesh] = None,
    ):
        """Builds the `(shard_fns, gather_fns)` pair for the parameters of this module.

        The mesh and partition rules fall back to the config when not given; the rules are matched
        against `graphtree_params_shape` before being handed to `make_shard_and_gather_fns`.
        """
        mesh = self._get_mesh(mesh)
        partition_specs = match_partition_rules(
            rules=self._get_partition_rules(partition_rules),
            tree=self.graphtree_params_shape,
        )
        return make_shard_and_gather_fns(
            partition_specs=partition_specs,
            mesh=mesh,
        )

    def _apply_sharding_fns(
        self: SELF,
        sharding_fns: tp.Mapping[str, tp.Callable],
    ) -> SELF:
        """Applies sharding functions to the model's state.

        Each parameter whose path is a key of the flattened `sharding_fns` (and whose value is not
        None) is passed through the corresponding function. A `TypeError` raised by a function is
        reported as a warning and the parameter is left unchanged.

        Args:
            sharding_fns: Possibly nested mapping from parameter path to a callable.

        Returns:
            EasyDeLBaseModule: The rebuilt model.
        """
        gdef, state, others = nn.split(self, nn.Param, ...)
        sharding_fns = flatten_dict(sharding_fns)

        def _map(path, val: nn.VariableState):
            # Direct mapping membership avoids scanning a list of keys for every parameter.
            if val.value is not None and path in sharding_fns:
                try:
                    val.value = sharding_fns[path](val.value)
                except TypeError:
                    dotted_path = ".".join(map(str, path))
                    warnings.warn(f"couldn't shard/gather {dotted_path}", stacklevel=1)
            return val

        state.update(state.map(_map))
        self = nn.merge(gdef, state, others)
        return self

    def shard_model(
        self: SELF,
        partition_rules: PartitionLike = None,
        mesh: tp.Optional[Mesh] = None,
        overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None,
    ) -> SELF:
        """Shards the model's parameters using the specified partitioning rules and mesh.

        Args:
            partition_rules (PartitionLike, optional): Partitioning rules for sharding.
            mesh (jax.sharding.Mesh, optional): The mesh to shard across.
            overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]): Overlay functions to apply to the model.

        Returns:
            EasyDeLBaseModule: The sharded model.
        """
        shard_fns, _ = self._make_sharding_fns(partition_rules, mesh)
        if overlay_fns is not None:
            shard_fns.update(overlay_fns)
        self = self._apply_sharding_fns(shard_fns)
        return self

    def gather_model(
        self: SELF,
        partition_rules: PartitionLike = None,
        mesh: tp.Optional[Mesh] = None,
        overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None,
    ) -> SELF:
        """Gathers the model's parameters based on the specified partitioning rules and mesh.

        Args:
            partition_rules (PartitionLike, optional): Partitioning rules for gathering.
            mesh (jax.sharding.Mesh, optional): The mesh to gather from.
            overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]): Overlay functions to apply to the model.

        Returns:
            EasyDeLBaseModule: The gathered model.
        """
        _, gather_fns = self._make_sharding_fns(partition_rules, mesh)
        if overlay_fns is not None:
            gather_fns.update(overlay_fns)
        return self._apply_sharding_fns(gather_fns)

    @property
    def _shard_fns(self):
        """Returns the shard functions built from the config's mesh and partition rules."""
        return self._make_sharding_fns(None, None)[0]

    @property
    def _gather_fns(self):
        """Returns the gather functions built from the config's mesh and partition rules."""
        return self._make_sharding_fns(None, None)[1]

    def fully_shard(self: SELF, partition_rules: PartitionLike = None) -> SELF:
        """Reshards the whole module state through a jitted identity function.

        The graph definition and state from `nn.split(self)` are wrapped in a small pytree
        container, partition specs are matched against its evaluated shapes, and a jitted identity
        whose `out_shardings` are the resulting `NamedSharding`s (on `self.mesh`) is applied.

        Args:
            partition_rules (PartitionLike, optional): Rules to use; defaults to the config's rules.

        Returns:
            EasyDeLBaseModule: The model rebuilt from the resharded state.
        """

        class ShardState(flax.struct.PyTreeNode):
            graphdef: nn.GraphDef
            graphstate: nn.GraphState

        gdef, gstate = nn.split(self)
        mock = ShardState(graphdef=gdef, graphstate=gstate)

        shardings = jax.tree_util.tree_map(
            lambda x: NamedSharding(mesh=self.mesh, spec=x),
            match_partition_rules(
                self._get_partition_rules(partition_rules), nn.eval_shape(lambda: mock)
            ),
        )

        @partial(jax.jit, out_shardings=shardings)
        def _call(cl):
            return cl

        mock = _call(mock)
        self = nn.merge(mock.graphdef, mock.graphstate)
        return self

    def fully_gather(self: SELF) -> SELF:
        """Moves the whole module state to an empty `PartitionSpec()` through a jitted identity.

        Mirrors `fully_shard`, except that every leaf of the matched spec tree is mapped to
        `NamedSharding(mesh=self.mesh, spec=PartitionSpec())`; the matched rules are only used to
        obtain a tree with the right structure.

        Returns:
            EasyDeLBaseModule: The model rebuilt from the gathered state.
        """

        class ShardState(flax.struct.PyTreeNode):
            graphdef: nn.GraphDef
            graphstate: nn.GraphState

        gdef, gstate = nn.split(self)
        mock = ShardState(graphdef=gdef, graphstate=gstate)

        shardings = jax.tree_util.tree_map(
            lambda x: NamedSharding(mesh=self.mesh, spec=PartitionSpec()),
            match_partition_rules(
                self._get_partition_rules(None), nn.eval_shape(lambda: mock)
            ),
        )

        @partial(jax.jit, out_shardings=shardings)
        def _call(cl):
            return cl

        mock = _call(mock)
        self = nn.merge(mock.graphdef, mock.graphstate)
        return self

    def quantize(
        self: SELF,
        method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT,
        block_size: int = 128,
        quantization_pattern: tp.Optional[str] = None,
        quantize_tensors: bool = True,
        verbose: tp.Optional[bool] = None,
    ) -> SELF:
        """Quantizes the model's linear layers.

        Args:
            method (EasyDeLQuantizationMethods, optional): The quantization method to use.
            block_size (int, optional): The block size for quantization.
            quantization_pattern (str, optional): The quantization pattern to use.
            quantize_tensors (bool): Whether to quantize tensors or quantize Linear layers.
                Tensor quantization is not implemented in this method: with `True` a warning is
                emitted and the model is returned unchanged.
            verbose (bool, optional): Verbose quantizing process. Defaults to True only on the
                process with index 0.

        Returns:
            EasyDeLBaseModule: The quantized model.
        """
        from easydel.layers.quantization.quantizers import EasyQuantizer

        quantizer = EasyQuantizer(
            quantization_method=method,
            block_size=block_size,
            quantization_pattern=quantization_pattern,
        )
        if verbose is None:
            verbose = jax.process_index() == 0
        if quantize_tensors:
            # This branch used to be a silent no-op; make the missing implementation visible.
            warnings.warn(
                "`quantize_tensors=True` is not implemented; the model is returned unquantized. "
                "Pass `quantize_tensors=False` to quantize linear layers.",
                stacklevel=2,
            )
        else:
            self = quantizer.quantize_linears(
                self,
                quantization_pattern=quantization_pattern,
                verbose=verbose,
            )
        return self

    def to_state(self) -> EasyDeLState:
        """Converts the current model to an `EasyDeLState` created at step 0."""
        from easydel.infra.base_state import EasyDeLState

        return EasyDeLState.create(step=0, model=self)

    def to_torch(self, **kwargs):
        """Converts the module to a HuggingFace model.

        The HuggingFace class is looked up in the torch loader's `_model_mapping` by the type of
        `self.config`; extra keyword arguments are forwarded to `module_to_huggingface_model`.
        """
        from easydel.utils.parameters_transformation import module_to_huggingface_model

        hf_autoloader = self.get_torch_loader()
        model_class = hf_autoloader._model_mapping[type(self.config)]
        hf_model = module_to_huggingface_model(
            module=self,
            base_huggingface_module=model_class,
            config=self.config,
            dtype=self.param_dtype,
            **kwargs,
        )
        return hf_model

    def get_static_arguments(self) -> tp.Tuple:
        """Returns the static arguments of the module; the base implementation returns `()`."""
        return ()

    @classmethod
    def lazy_init(cls: tp.Type[SELF], *args, **kwargs) -> SELF:
        """Constructs the module under `nn.eval_shape`, forwarding all arguments to the constructor."""
        return nn.eval_shape(lambda: cls(*args, **kwargs))

    def merge_lora_params(self: SELF, pytree: tp.Dict) -> SELF:
        """Merges LoRA parameters from `pytree` into the model via `easydel.infra.utils.merge_lora_params`."""
        from easydel.infra.utils import merge_lora_params

        self = merge_lora_params(self, pytree)
        return self

    def split_lora_params(self: SELF) -> tp.Dict:
        """Returns the result of `easydel.infra.utils.split_lora_params` for this model."""
        from easydel.infra.utils import split_lora_params

        pytree = split_lora_params(self)
        return pytree

    def apply_lora_to_layers(
        self: SELF,
        lora_rank: int,
        lora_pattern: tp.Optional[str] = None,
        verbose: bool = False,
        rngs: tp.Optional[nn.Rngs] = None,
    ) -> SELF:
        """Applies LoRA to the model's layers via `easydel.infra.utils.apply_lora_to_layers`.

        Args:
            lora_rank (int): Rank forwarded to the LoRA helper.
            lora_pattern (str, optional): Pattern forwarded to the LoRA helper.
            verbose (bool): Verbosity flag forwarded to the LoRA helper.
            rngs (nn.Rngs, optional): Random number generators forwarded to the LoRA helper.

        Returns:
            EasyDeLBaseModule: The model returned by the LoRA helper.
        """
        from easydel.infra.utils import apply_lora_to_layers

        self = apply_lora_to_layers(
            self,
            lora_pattern=lora_pattern,
            lora_rank=lora_rank,
            rngs=rngs,
            verbose=verbose,
        )
        return self

    def unwrap_lora_to_layers(self: SELF, verbose: bool = False) -> SELF:
        """Returns the model produced by `easydel.infra.utils.unwrap_lora_to_layers`."""
        from easydel.infra.utils import unwrap_lora_to_layers

        self = unwrap_lora_to_layers(self, verbose=verbose)
        return self

    def _conversion_layer_names(self) -> tp.Tuple[tp.List, tp.List]:
        """Collects the last path component of every `nn.Embed` and `nn.LayerNorm` submodule.

        Integer path components are skipped. Returns `(embedding_names, layernorm_names)`.
        """
        from easydel.utils import graph_utils

        embedding_names = [
            pa[-1]
            for pa, _ in graph_utils.iter_module_search(self, nn.Embed)
            if not isinstance(pa[-1], int)
        ]
        layernorm_names = [
            pa[-1]
            for pa, _ in graph_utils.iter_module_search(self, nn.LayerNorm)
            if not isinstance(pa[-1], int)
        ]
        return embedding_names, layernorm_names

    @property
    def transform_fn(self):
        """Returns `torch_dict_to_easydel_params` pre-bound with layer names, dtype and shard functions."""
        from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

        embedding_path, layernorm_path = self._conversion_layer_names()
        return partial(
            torch_dict_to_easydel_params,
            embedding_layer_names=embedding_path,
            layernorm_names=layernorm_path,
            dtype=self.param_dtype,
            shard_fns=self._shard_fns,
        )

    @property
    def params_sharding(self) -> tp.Dict:
        """Returns the parameter dict with each leaf replaced by its `.sharding` (or `None` if absent)."""
        return jax.tree_util.tree_map(
            lambda x: x.sharding if hasattr(x, "sharding") else None,
            self.split_params_dict(),
        )

    def merge_params(self, tree):
        """Merges `tree` in place of the current `nn.Param` state and returns the rebuilt model."""
        gdef, _, gother = nn.split(self, nn.Param, ...)
        self = nn.merge(gdef, tree, gother)
        return self

    def split_params(self):
        """Returns the `nn.Param` state of the model."""
        return nn.split(self, nn.Param, ...)[1]

    def split_params_dict(
        self,
        extract_fn: tp.Optional[tp.Callable] = None,
        remove_none: bool = True,
    ) -> tp.Dict:
        """Splits the model parameters and returns them as a dictionary, removing `VariableState` from the tree.

        Args:
            extract_fn (tp.Optional[tp.Callable], optional): Function to extract values from the parameters.
            remove_none (bool, optional): Whether to remove `None` values from the dictionary.

        Returns:
            tp.Dict: The dictionary of split parameters.
        """
        flat_params = flatten_dict(self.split_params().to_pure_dict(extract_fn=extract_fn))
        # Unwrap anything that still exposes `.value`; plain leaves pass through unchanged.
        unwrapped = {
            k: (v.value if hasattr(v, "value") else v) for k, v in flat_params.items()
        }
        if remove_none:
            unwrapped = {k: v for k, v in unwrapped.items() if v is not None}
        return unflatten_dict(unwrapped)

    def merge_params_dict(self: SELF, params_dict: tp.Dict) -> SELF:
        """Merges the model parameters from a dictionary into the current model.

        Args:
            params_dict (tp.Dict): A dictionary containing the parameters to merge.

        Returns:
            EasyDeLBaseModule: The model with merged parameters.

        Raises:
            KeyError: If a key of `params_dict` is not present in the current parameter state.
        """
        current_state = self.split_params().flat_state()
        if not is_flatten(params_dict):
            params_dict = flatten_dict(params_dict)
        for key, value in params_dict.items():
            if key in current_state:
                current_state[key].value = value
            else:
                raise KeyError(f"Parameter key {key} not found in the current model state.")
        self = self.merge_params(unflatten_dict(current_state))
        return self

    def _flop(self, *args, **kwargs) -> tp.Optional[float]:
        """Calculates the FLOP (Floating Point Operations) from the JaxPr of `__call__`."""
        from .utils import count_flop_jaxpr

        return count_flop_jaxpr(jax.make_jaxpr(self.__call__)(*args, **kwargs))

    @property
    def pure_transform_fn(self):
        """Returns `torch_dict_to_easydel_params` pre-bound with layer names and dtype (no shard functions)."""
        from easydel.utils.parameters_transformation import torch_dict_to_easydel_params

        embedding_path, layernorm_path = self._conversion_layer_names()
        return partial(
            torch_dict_to_easydel_params,
            embedding_layer_names=embedding_path,
            layernorm_names=layernorm_path,
            dtype=self.param_dtype,
        )

    @property
    def _default_loss_config(self) -> tp.Optional[LossConfig]:
        """Default loss configuration; the base implementation returns `None`."""
        return None

    @_default_loss_config.setter
    def _default_loss_config(self, val):
        # Assignments are accepted but not stored: the getter above keeps returning None.
        return val

    def compute_loss(
        self,
        *,
        labels: tp.Optional[chex.Array] = None,
        loss_config: tp.Optional[LossConfig] = None,
        loss_kwargs: tp.Optional[tp.Dict] = None,
        **batch,
    ) -> tp.Tuple[tp.Any, LossMetrics]:
        """Runs the model on `batch` and computes the loss with `self.loss_function`.

        Args:
            labels: Target labels. For the causal-LM loss they default to `batch["input_ids"]`.
            loss_config: Loss configuration. For the sequence-classification loss it defaults to
                `LossConfig(num_labels=self.config.num_labels)`.
            loss_kwargs: Extra keyword arguments forwarded to the loss function.
            **batch: Model inputs; any `return_dict` entry is dropped and the model is always
                called with `return_dict=True`.

        Returns:
            A tuple `(outputs, loss_output)` where `outputs` is the model output with its `loss`
            field replaced by the computed loss (including `aux_loss` when the output provides
            one), and `loss_output` is the value returned by the loss function.

        Raises:
            AssertionError: If `labels` cannot be resolved, or if a sequence-classification model
                has no `num_labels` on its config and no `loss_config` was given.
        """
        loss_fn_name = self.loss_function.__name__
        if labels is None and loss_fn_name == ForCausalLMLoss.__name__:
            labels = batch.get("input_ids", None)

        if loss_fn_name == ForSequenceClassificationLoss.__name__:
            if loss_config is None:
                assert hasattr(self.config, "num_labels"), (
                    "in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach `num_labels` to model `config`"
                )
                loss_config = LossConfig(num_labels=self.config.num_labels)

        assert labels is not None, "`labels` can not be `None` for computing loss."
        loss_kwargs = loss_kwargs or {}
        batch.pop("return_dict", None)
        outputs = self(**batch, return_dict=True)

        loss_output: LossMetrics = self.loss_function(
            labels=labels,
            config=loss_config,
            paxis=self.config.partition_axis,
            **loss_kwargs,
            **outputs,
            **batch,
        )
        if hasattr(outputs, "aux_loss"):
            if outputs.aux_loss is not None:
                loss_output.loss = loss_output.loss + outputs.aux_loss
        outputs = outputs.replace(loss=loss_output.loss)
        return outputs, loss_output