Source code for easydel.infra.base_module

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
import typing as tp
import warnings
from functools import cached_property, partial

import chex
import flax
import flax.struct
import jax
import jax.extend
import jax.tree_util
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from easydel.utils.helpers import get_logger
from easydel.utils.traversals import flatten_dict, is_flatten, unflatten_dict

from .base_config import EasyDeLBaseConfig
from .etils import EasyDeLQuantizationMethods
from .loss_utils import (
	LOSS_MAPPING,
	ForCausalLMLoss,
	ForSequenceClassificationLoss,
	LossConfig,
	LossMetrics,
)
from .mixins import (
	BaseModuleProtocol,
	EasyBridgeMixin,
	EasyGenerationMixin,
)

if tp.TYPE_CHECKING:
	from easydel.infra.base_state import EasyDeLState
else:
	EasyDeLState = tp.Any

PartitionLike = tp.Optional[
	tp.Union[
		tp.Mapping[str, tp.Callable],
		tp.Mapping[tuple, tp.Callable],
	]
]


logger = get_logger(__name__)

_CP = tp.TypeVar("CP")
SELF = tp.TypeVar("SELF")


[docs]class EasyDeLBaseModule( nn.Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin, ): """ Base class for EasyDeL modules, providing common functionalities for model initialization, parameter handling, and integration with the EasyDeL ecosystem. """ config_class: tp.Type[EasyDeLBaseConfig] base_model_prefix: str _model_task: tp.Optional[str] = None _model_type: tp.Optional[str] = None def __init__( self, config: tp.Union[EasyDeLBaseConfig, _CP], dtype: jnp.dtype, param_dtype: jnp.dtype, precision: lax.PrecisionLike, rngs: nn.Rngs, ): """Initializes the EasyDeLBaseModule. Args: config (EasyDeLBaseConfig): The model configuration. dtype (jnp.dtype): The data type for computation. param_dtype (jnp.dtype): The data type for parameters. precision (jax.lax.PrecisionLike): The numerical precision. rngs (nn.Rngs): The random number generators. """ self.config: tp.Union[EasyDeLBaseConfig, _CP] = config self.dtype: jnp.dtype = dtype self.param_dtype: jnp.dtype = param_dtype self.precision: lax.PrecisionLike = precision self.rngs: nn.Rngs = rngs # these useless call's are just here to init values in graphdef _ = self.graphtree_shape _ = self.graphtree_params_shape _ = self.mesh _ = self.model_task _ = self.model_type @property def parameters(self) -> tp.Dict: from easydel.utils.graph_utils import iter_module_search parameters = {} for key, value in iter_module_search(self, nn.Param): parameters[key] = value.value return parameters @property def graphtree_params_shape(self) -> tp.Dict: """Evaluates the shape of the model's parameters and returns a dictionary.""" graphtree = nn.eval_shape(lambda: nn.split(self, nn.Param, ...)[1]) flattened_tree = flatten_dict(graphtree) param_shapes = {key: val.value for key, val in flattened_tree.items()} return unflatten_dict(param_shapes) @property def graphtree_shape(self) -> tp.Dict: """Evaluates the shape of the modeland returns a dictionary.""" graphtree = nn.eval_shape(lambda: nn.split(self)[1]) flattened_tree = flatten_dict(graphtree) param_shapes = {key: val.value for key, val in flattened_tree.items()} return unflatten_dict(param_shapes) @property def mesh(self) -> jax.sharding.Mesh: """Returns the mesh from the config.""" return self.config.mesh @property def model_task(self) -> tp.Optional[str]: """Returns the model task.""" return self._model_task @property def model_type(self) -> tp.Optional[str]: """Returns the model type.""" return self._model_type @property def params(self) -> tp.Dict: return nn.split(self)[-1] @cached_property def causal_mask(self) -> jnp.ndarray: """Returns a causal mask from the config.""" return self.config.get_basic_causal_mask() @cached_property def frequencies(self) -> jnp.ndarray: """Returns frequency values from the config.""" return self.config.get_basic_frequencies() @cached_property def static_arguments(self) -> tp.Tuple: return self.get_static_arguments() @cached_property def loss_function(self): if getattr(self.config, "loss_type", None) is not None: loss_type = self.config.loss_type elif getattr(self, "loss_type", None) is not None: loss_type = self.loss_type else: loss_type = self.__class__.__name__ if loss_type not in LOSS_MAPPING: loss_groups = f"({'|'.join(LOSS_MAPPING)})" loss_type = re.findall(loss_groups, self.__class__.__name__) if len(loss_type) > 0: loss_type = loss_type[0] else: loss_type = None if ( loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None ): warnings.warn( f"`loss_type={loss_type}` was set in the config but it is unrecognised." f"Using the default loss: `ForCausalLMLoss`.", stacklevel=1, ) loss_type = "ForCausalLM" return LOSS_MAPPING[loss_type] @property def module_dtype(self) -> jnp.dtype: params_state = nn.split(self, nn.Param, ...)[1].flat_state() return jax.tree_util.tree_leaves(params_state)[0].dtype
[docs] def to_dtype(self: SELF, dtype: jnp.dtype) -> SELF: """Applies sharding functions to the model's state.""" from easydel.utils.graph_utils import iter_module_search gdef, state, others = nn.split(self, nn.Param, ...) def _map(path, val: nn.VariableState): if val.value is not None: if not path[-1].startswith("quant_"): val.value = val.value.astype(dtype) return val state.update(state.map(_map)) self = nn.merge(gdef, state, others) for path, module in iter_module_search(self): if hasattr(module, "param_dtype"): module.param_dtype = dtype return self
[docs] def half(self: SELF, change_runtime_dtype: bool = True) -> SELF: if change_runtime_dtype: self = self._reformat_runtime_dtype(jnp.float16) return self._reformat_dtype(jnp.float16)
[docs] def float(self: SELF, change_runtime_dtype: bool = True) -> SELF: if change_runtime_dtype: self = self._reformat_runtime_dtype(jnp.float32) return self._reformat_dtype(jnp.float32)
def _reformat_runtime_dtype(self: SELF, dtype) -> SELF: from easydel.utils.graph_utils import iter_module_search for path, module in iter_module_search(self): if hasattr(module, "dtype"): if str(type(module.dtype)).endswith( "lax_numpy._ScalarMeta'>" ): # dont change numpy based dtypes module.dtype = dtype self.dtype = dtype return self def _reformat_dtype(self: SELF, dtype) -> SELF: from easydel.utils.graph_utils import iter_module_search gdef, gtree, others = nn.split(self, nn.Param, ...) def _map(array): if array.dtype in [ jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, jnp.float_, ]: array = array.astype(dtype) return array gtree = jax.tree_util.tree_map(_map, gtree) self = nn.merge(gdef, gtree, others) for path, module in iter_module_search(self): if hasattr(module, "param_dtype"): if isinstance(module.param_dtype, jnp.dtype): module.param_dtype = dtype self.param_dtype = dtype return self def _match_partition_rules(self, partition_rules: tp.Any = None): return match_partition_rules( rules=self._get_partition_rules(partition_rules), tree=self.graphtree_params_shape, ) @property def _specs_sharding(self): def _map(array): if hasattr(array, "sharding"): sharding = array.sharding if isinstance(sharding, NamedSharding): return sharding.spec return PartitionSpec() return nn.from_tree( jax.tree_util.tree_map( _map, nn.to_tree(self), ) ) @property def _shardings(self): return nn.from_tree( jax.tree_util.tree_map( lambda x: x.sharding if hasattr(x, "sharding") else PartitionSpec(), nn.to_tree(self), ) ) @property def _named_shardings(self): return nn.from_tree( jax.tree_util.tree_map( lambda x: x.sharding if hasattr(x, "sharding") else None, nn.to_tree(self), ) ) def _get_mesh(self, mesh: tp.Optional[Mesh] = None) -> Mesh: """Retrieves the mesh, either from the provided argument or the config.""" if mesh is None: if ( not hasattr(self, "config") or not hasattr(self.config, "mesh") or self.config.mesh is None ): raise ValueError( "A mesh must be provided, either as an argument or through the model config." ) return self.config.mesh return mesh def _get_partition_rules(self, partition_rules: PartitionLike) -> PartitionLike: """Retrieves the partition rules from input or the config""" if partition_rules is None: if not hasattr(self, "config"): raise ValueError( "Partition rules must be provided either as an argument or through the model config." ) return self.config.get_partition_rules(fully_sharded_data_parallel=True) return partition_rules def _apply_sharding_fns( self: SELF, sharding_fns: tp.Mapping[str, tp.Callable], ) -> SELF: """Applies sharding functions to the model's state.""" gdef, state, others = nn.split(self, nn.Param, ...) sharding_fns = flatten_dict(sharding_fns) _shard_keys = list(sharding_fns.keys()) def _map(path, val: nn.VariableState): if val.value is not None and path in _shard_keys: try: val.value = sharding_fns[path](val.value) except TypeError: path = map(str, path) warnings.warn(f"couldn't shard/gather {'.'.join(path)}", stacklevel=1) return val state.update(state.map(_map)) self = nn.merge(gdef, state, others) return self
[docs] def shard_model( self: SELF, partition_rules: PartitionLike = None, mesh: tp.Optional[Mesh] = None, overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None, ) -> SELF: """Shards the model's parameters using the specified partitioning rules and mesh. Args: partition_rules (PartitionLike, optional): Partitioning rules for sharding. mesh (jax.sharding.Mesh, optional): The mesh to shard across. overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]): Overlay functions to apply to the model. Returns: EasyDeLBaseModule: The sharded model. """ mesh = self._get_mesh(mesh) partition_rules = self._get_partition_rules(partition_rules) partition_specs = match_partition_rules( rules=partition_rules, tree=self.graphtree_params_shape, ) shard_fns, _ = make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, ) if overlay_fns is not None: shard_fns.update(overlay_fns) self = self._apply_sharding_fns(shard_fns) return self
[docs] def gather_model( self: SELF, partition_rules: PartitionLike = None, mesh: tp.Optional[Mesh] = None, overlay_fns: tp.Optional[tp.Mapping[str, tp.Callable]] = None, ) -> SELF: """Gathers the model's parameters based on the specified partitioning rules and mesh. Args: partition_rules (PartitionLike, optional): Partitioning rules for gathering. mesh (jax.sharding.Mesh, optional): The mesh to gather from. overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]): Overlay functions to apply to the model. Returns: EasyDeLBaseModule: The gathered model. """ mesh = self._get_mesh(mesh) partition_rules = self._get_partition_rules(partition_rules) partition_specs = match_partition_rules( rules=partition_rules, tree=self.graphtree_params_shape, ) _, gather_fns = make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, ) if overlay_fns is not None: gather_fns.update(overlay_fns) return self._apply_sharding_fns(gather_fns)
@property def _shard_fns(self): mesh = self._get_mesh(None) partition_specs = match_partition_rules( rules=self._get_partition_rules(None), tree=self.graphtree_params_shape, ) return make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, )[0] @property def _gather_fns(self): mesh = self._get_mesh(None) partition_specs = match_partition_rules( rules=self._get_partition_rules(None), tree=self.graphtree_params_shape, ) return make_shard_and_gather_fns( partition_specs=partition_specs, mesh=mesh, )[1]
[docs] def fully_shard(self: SELF, partition_rules: PartitionLike = None) -> SELF: class ShardState(flax.struct.PyTreeNode): graphdef: nn.GraphDef graphstate: nn.GraphState gdef, gstate = nn.split(self) mock = ShardState(graphdef=gdef, graphstate=gstate) shardings = jax.tree_util.tree_map( lambda x: NamedSharding(mesh=self.mesh, spec=x), match_partition_rules( self._get_partition_rules(partition_rules), nn.eval_shape(lambda: mock) ), ) @partial(jax.jit, out_shardings=shardings) def _call(cl): return cl mock = _call(mock) self = nn.merge(mock.graphdef, mock.graphstate) return self
[docs] def fully_gather(self: SELF) -> SELF: class ShardState(flax.struct.PyTreeNode): graphdef: nn.GraphDef graphstate: nn.GraphState gdef, gstate = nn.split(self) mock = ShardState(graphdef=gdef, graphstate=gstate) shardings = jax.tree_util.tree_map( lambda x: NamedSharding(mesh=self.mesh, spec=PartitionSpec()), match_partition_rules( self._get_partition_rules(None), nn.eval_shape(lambda: mock) ), ) @partial(jax.jit, out_shardings=shardings) def _call(cl): return cl mock = _call(mock) self = nn.merge(mock.graphdef, mock.graphstate) return self
[docs] def quantize( self: SELF, method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: tp.Optional[str] = None, quantize_tensors: bool = True, verbose: tp.Optional[bool] = None, ) -> SELF: """Quantizes the model's linear layers. Args: method (EasyDeLQuantizationMethods, optional): The quantization method to use. block_size (int, optional): The block size for quantization. quantization_pattern (str, optional): The quantization pattern to use. quantize_tensors (bool): whenever to quantize tensors or quantize Linear Layers.` verbose (bool, optional): Verbose quantizing process Returns: EasyDeLBaseModule: The quantized model. """ from easydel.layers.quantization.quantizers import EasyQuantizer quantizer = EasyQuantizer( quantization_method=method, block_size=block_size, quantization_pattern=quantization_pattern, ) if verbose is None: verbose = jax.process_index() == 0 if quantize_tensors: ... else: self = quantizer.quantize_linears( self, quantization_pattern=quantization_pattern, verbose=verbose, ) return self
[docs] def to_state(self) -> EasyDeLState: """converts current model to a EasyDeLState""" from easydel.infra.base_state import EasyDeLState return EasyDeLState.create(step=0, model=self)
[docs] def to_torch(self, **kwargs): from easydel.utils.parameters_transformation import module_to_huggingface_model hf_autoloader = self.get_torch_loader() model_class = hf_autoloader._model_mapping[type(self.config)] hf_model = module_to_huggingface_model( module=self, base_huggingface_module=model_class, config=self.config, dtype=self.param_dtype, **kwargs, ) return hf_model
[docs] def prepare_inputs_for_call(self, **kwargs): return kwargs
[docs] def get_static_arguments(self) -> tp.Tuple: return ()
[docs] @classmethod def lazy_init(cls: tp.Type[SELF], *args, **kwargs) -> SELF: return nn.eval_shape(lambda: cls(*args, **kwargs))
[docs] def merge_lora_params(self: SELF, pytree: tp.Dict) -> SELF: from easydel.infra.utils import merge_lora_params self = merge_lora_params(self, pytree) return self
[docs] def split_lora_params(self: SELF) -> tp.Dict: from easydel.infra.utils import split_lora_params pytree = split_lora_params(self) return pytree
[docs] def apply_lora_to_layers( self: SELF, lora_rank: int, lora_pattern: tp.Optional[str] = None, verbose: bool = False, rngs: tp.Optional[nn.Rngs] = None, ) -> SELF: from easydel.infra.utils import apply_lora_to_layers self = apply_lora_to_layers( self, lora_pattern=lora_pattern, lora_rank=lora_rank, rngs=rngs, verbose=verbose, ) return self
[docs] def unwrap_lora_to_layers(self: SELF, verbose: bool = False) -> SELF: from easydel.infra.utils import unwrap_lora_to_layers self = unwrap_lora_to_layers(self, verbose=verbose) return self
@property def transform_fn(self): from easydel.utils import graph_utils from easydel.utils.parameters_transformation import torch_dict_to_easydel_params embedding_path = [ pa[-1] for pa, _ in graph_utils.iter_module_search(self, nn.Embed) if not isinstance(pa[-1], int) ] layernorm_path = [ pa[-1] for pa, _ in graph_utils.iter_module_search(self, nn.LayerNorm) if not isinstance(pa[-1], int) ] return partial( torch_dict_to_easydel_params, embedding_layer_names=embedding_path, layernorm_names=layernorm_path, dtype=self.param_dtype, shard_fns=self._shard_fns, ) @property def params_sharding(self) -> tp.Dict: return jax.tree_util.tree_map( lambda x: x.sharding if hasattr(x, "sharding") else None, self.split_params_dict(), )
[docs] def merge_params(self, tree): """merge state to the current model""" gdef, _, gother = nn.split(self, nn.Param, ...) self = nn.merge(gdef, tree, gother) return self
[docs] def split_params(self): """split the model parameters""" return nn.split(self, nn.Param, ...)[1]
[docs] def split_params_dict( self, extract_fn: tp.Optional[tp.Callable] = None, remove_none: bool = True, ) -> tp.Dict: """Splits the model parameters and returns them as a dictionary, removing `VariableState` from the tree. Args: extract_fn (tp.Optional[tp.Callable], optional): Function to extract values from the parameters. remove_none (bool, optional): Whether to remove `None` values from the dictionary. Returns: tp.Dict: The dictionary of split parameters. """ flat_params = flatten_dict(self.split_params().to_pure_dict(extract_fn=extract_fn)) if remove_none: flat_params = { k: v.value if hasattr(v, "value") else v for k, v in flat_params.items() if (v.value if hasattr(v, "value") else v) is not None } else: flat_params = { k: v.value if hasattr(v, "value") else v for k, v in flat_params.items() } return unflatten_dict(flat_params)
[docs] def merge_params_dict(self: SELF, params_dict: tp.Dict) -> SELF: """Merges the model parameters from a dictionary into the current model. Args: params_dict (tp.Dict): A dictionary containing the parameters to merge. Returns: EasyDeLBaseModule: The model with merged parameters. """ current_state = self.split_params().flat_state() if not is_flatten(params_dict): params_dict = flatten_dict(params_dict) for key, value in params_dict.items(): if key in current_state: current_state[key].value = value else: raise KeyError(f"Parameter key {key} not found in the current model state.") self = self.merge_params(unflatten_dict(current_state)) return self
def _flop(self, *args, **kwargs) -> tp.Optional[float]: """Calculates the FLOP (Floating Point Operations) from JaxPr""" from .utils import count_flop_jaxpr return count_flop_jaxpr(jax.make_jaxpr(self.__call__)(*args, **kwargs)) @property def pure_transform_fn(self): from easydel.utils import graph_utils from easydel.utils.parameters_transformation import torch_dict_to_easydel_params embedding_path = [ pa[-1] for pa, _ in graph_utils.iter_module_search(self, nn.Embed) if not isinstance(pa[-1], int) ] layernorm_path = [ pa[-1] for pa, _ in graph_utils.iter_module_search(self, nn.LayerNorm) if not isinstance(pa[-1], int) ] return partial( torch_dict_to_easydel_params, embedding_layer_names=embedding_path, layernorm_names=layernorm_path, dtype=self.param_dtype, ) @property def _default_loss_config(self) -> tp.Optional[LossConfig]: return None @_default_loss_config.setter def _default_loss_config(self, val): return val
[docs] def compute_loss( self, *, labels: tp.Optional[chex.Array] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None, **batch, ) -> tp.Tuple[tp.Any, LossMetrics]: """basic `compute_loss` call""" if labels is None and self.loss_function.__name__ == ForCausalLMLoss.__name__: labels = batch.get("input_ids", None) if self.loss_function.__name__ == ForSequenceClassificationLoss.__name__: if loss_config is None: assert hasattr(self.config, "num_labels"), ( "in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach `num_labels` to model `config`" ) loss_config = LossConfig(num_labels=self.config.num_labels) assert labels is not None, "`labels` can not be `None` for computing loss." loss_kwargs = loss_kwargs or {} batch.pop("return_dict", None) outputs = self(**batch, return_dict=True) loss_output: LossMetrics = self.loss_function( labels=labels, config=loss_config, paxis=self.config.partition_axis, **loss_kwargs, **outputs, **batch, ) if hasattr(outputs, "aux_loss"): if outputs.aux_loss is not None: loss_output.loss = loss_output.loss + outputs.aux_loss outputs = outputs.replace(loss=loss_output.loss) return outputs, loss_output