# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import functools
import gc
import inspect
import os
import typing as tp
import warnings
import jax
import jax.extend
import numpy as np
from jax import dlpack
from jax import numpy as jnp
from tqdm.autonotebook import tqdm
from easydel.utils.helpers import check_bool_flag, get_logger
from .analyze_memory import SMPMemoryMonitor
from .traversals import flatten_dict, unflatten_dict
if tp.TYPE_CHECKING:
from transformers import PreTrainedModel
from easydel.infra.base_config import EasyDeLBaseConfig
from easydel.infra.base_module import EasyDeLBaseModule
# Module-wide memory monitor and logger.
mem_ops = SMPMemoryMonitor(5)
logger = get_logger(__name__)

# Index into `jax.devices(EASYDEL_PERFRED_HOST_COPY)` selecting the device on which
# `TensorConverter.jax_to_pytorch` stages a copy before a DLPack export (default 0).
EASYDEL_PERFRED_HOST_COPY_INDEX = int(os.getenv("EASYDEL_PERFRED_HOST_COPY_INDEX", "0"))

# Platform name used for that staging copy (default "cpu", lower-cased).
# The literal string "none" (any case) is stored as `None`.
EASYDEL_PERFRED_HOST_COPY = os.getenv("EASYDEL_PERFRED_HOST_COPY", "cpu")
EASYDEL_PERFRED_HOST_COPY = str(EASYDEL_PERFRED_HOST_COPY).lower()
if EASYDEL_PERFRED_HOST_COPY == "none":
    EASYDEL_PERFRED_HOST_COPY = None
class DtypeHandler:
    """Handles dtype conversions and operations."""

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def _dtype_aliases() -> dict[str, tp.Any]:
        """Return the cached table mapping lower-case dtype aliases to JAX dtypes."""
        return {
            "bf16": jnp.bfloat16,
            "bfloat16": jnp.bfloat16,
            "fp16": jnp.float16,
            "float16": jnp.float16,
            "fp32": jnp.float32,
            "float32": jnp.float32,
            "fp64": jnp.float64,
            "float64": jnp.float64,
            "fp8": jnp.float8_e5m2,
            "fp8_e4m3fn": jnp.float8_e4m3fn,
            "fp8_e4m3fnuz": jnp.float8_e4m3fnuz,
            "fp8_e4m3b11fnuz": jnp.float8_e4m3b11fnuz,
            "fp8_e5m2": jnp.float8_e5m2,
            "fp8_e5m2fnuz": jnp.float8_e5m2fnuz,
            "float8_e4m3fn": jnp.float8_e4m3fn,
            "float8_e4m3fnuz": jnp.float8_e4m3fnuz,
            "float8_e4m3b11fnuz": jnp.float8_e4m3b11fnuz,
            "float8_e5m2": jnp.float8_e5m2,
            "float8_e5m2fnuz": jnp.float8_e5m2fnuz,
        }

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def _float_dtypes() -> tuple[tp.Any, ...]:
        """Return the cached tuple of floating dtypes that `float_tensor_to_dtype` may cast."""
        return (
            jnp.bfloat16,
            jnp.float16,
            jnp.float32,
            jnp.float64,
            jnp.float8_e4m3fn,
            jnp.float8_e4m3fnuz,
            jnp.float8_e4m3b11fnuz,
            jnp.float8_e5m2,
            jnp.float8_e5m2fnuz,
        )

    @staticmethod
    def get_dtype(dtype: str | jnp.dtype) -> jnp.dtype:
        """Convert a string dtype alias to its JAX dtype.

        Args:
            dtype: A string alias such as ``"bf16"``, ``"float32"`` or ``"fp8_e4m3fn"``
                (matched case-insensitively, surrounding whitespace ignored), or any
                non-string value, which is returned unchanged.

        Returns:
            The JAX dtype for a string alias, otherwise ``dtype`` itself.

        Raises:
            KeyError: If ``dtype`` is a string that is not a known alias. The message
                lists the accepted aliases.
        """
        if not isinstance(dtype, str):
            return dtype
        aliases = DtypeHandler._dtype_aliases()
        try:
            return aliases[dtype.strip().lower()]
        except KeyError:
            raise KeyError(f"Unknown dtype string {dtype!r}; expected one of {sorted(aliases)}.") from None

    @staticmethod
    def float_tensor_to_dtype(tensor: tp.Any, dtype: str | jnp.dtype | None) -> tp.Any:
        """Cast ``tensor`` to ``dtype`` only when it currently holds floating-point data.

        Args:
            tensor: Any object. It is cast with ``.astype`` only if its ``dtype``
                attribute equals one of the JAX float dtypes in `_float_dtypes`.
            dtype: Target dtype or string alias (see `get_dtype`). ``None`` or ``""``
                disables the cast.

        Returns:
            The cast tensor, or ``tensor`` unchanged (integer/bool data, objects without
            a matching ``dtype``, or no target dtype).
        """
        if dtype is None or dtype == "":
            return tensor
        target = DtypeHandler.get_dtype(dtype)
        if getattr(tensor, "dtype", None) in DtypeHandler._float_dtypes():
            tensor = tensor.astype(target)
        return tensor
class TensorConverter:
    """Handles tensor conversions between PyTorch and JAX."""

    # Host-side NumPy dtypes that JAX arrays may carry but `torch.from_numpy` cannot
    # ingest, mapped to a same-width integer container. The raw bits are handed to
    # torch in that container and then reinterpreted with `Tensor.view(dtype)`.
    _TORCH_BITCAST_CONTAINERS: tp.ClassVar[dict[str, tp.Any]] = {
        "bfloat16": np.int16,
        "float8_e4m3fn": np.int8,
        "float8_e5m2": np.int8,
    }

    @staticmethod
    def convert_pytorch_to_jnp(tensor: tp.Any, dtype: jnp.dtype) -> jnp.ndarray:
        """Convert a PyTorch tensor to a JAX array of ``dtype``.

        Torch cannot export bfloat16/float8 tensors through ``.numpy()``, so those are
        upcast to float32 on the torch side first. ``dtype`` then fixes the final
        JAX dtype.
        """
        dtype_name = str(tensor.dtype)
        if "bfloat16" in dtype_name or "float8" in dtype_name:
            tensor = tensor.float()
        host_array = tensor.cpu().detach().numpy()
        return jnp.array(host_array, dtype=dtype)

    @staticmethod
    @functools.lru_cache
    def get_torch():
        """Import and return the ``torch`` module (cached after the first call)."""
        import torch

        return torch

    @staticmethod
    def _numpy_to_torch(array: np.ndarray) -> tp.Any:
        """Wrap a host NumPy array as a torch tensor, bit-casting dtypes torch cannot import."""
        torch = TensorConverter.get_torch()
        dtype_name = array.dtype.name
        container = TensorConverter._TORCH_BITCAST_CONTAINERS.get(dtype_name)
        torch_dtype = getattr(torch, dtype_name, None)
        if container is not None and torch_dtype is not None:
            return torch.from_numpy(array.view(container)).view(torch_dtype)
        return torch.from_numpy(array)

    @staticmethod
    def jax_to_pytorch(x: jax.Array) -> tp.Any:
        """Convert a JAX array to a PyTorch tensor with the same dtype.

        When the ``EASY_SAFE_TRANSFER`` flag is true (the default), the array is
        gathered into host memory and wrapped with ``torch.from_numpy``.
        bfloat16/float8 data is bit-cast so the dtype survives.

        When the flag is false, the tensor is exchanged through DLPack:

        - directly from the array's device when the default JAX backend is CPU or
          GPU, CUDA is available to torch, and ``EASYDEL_FORCE_TORCH_USE_CPU`` is
          not set;
        - otherwise after staging a copy on
          ``jax.devices(EASYDEL_PERFRED_HOST_COPY)[EASYDEL_PERFRED_HOST_COPY_INDEX]``.
        """
        if check_bool_flag("EASY_SAFE_TRANSFER", True):
            # `np.array` produces a writable host copy; `torch.from_numpy` shares its memory.
            return TensorConverter._numpy_to_torch(np.array(jax.device_get(x)))

        from torch import cuda
        from torch.utils import dlpack as dlpack_pt

        # Compare the platform *name* ("cpu" / "gpu" / "tpu"), not a backend client object.
        platform = jax.default_backend()
        cpu_force = not cuda.is_available()
        if platform in ("cpu", "gpu") and not cpu_force and not check_bool_flag("EASYDEL_FORCE_TORCH_USE_CPU", False):
            dl_pack_jax = dlpack.to_dlpack(
                x,
                stream=True if platform == "gpu" else None,
                src_device=next(iter(x.devices())),
            )
        else:
            staging_device = jax.devices(EASYDEL_PERFRED_HOST_COPY)[EASYDEL_PERFRED_HOST_COPY_INDEX]
            dl_pack_jax = dlpack.to_dlpack(
                jax.device_put(jax.device_get(x), staging_device),
                stream=None,
            )
        return dlpack_pt.from_dlpack(dl_pack_jax)

    @staticmethod
    def pytorch_to_jax(x: tp.Any) -> jnp.ndarray:
        """Convert a PyTorch tensor to a JAX array, preserving its dtype.

        bfloat16/float8 tensors cannot go through ``.numpy()``. They are upcast to
        float32, which is exact, and cast back to the matching JAX dtype when JAX
        defines one.
        """
        x = x.detach().cpu()
        dtype_name = str(x.dtype).rsplit(".", 1)[-1]
        if "bfloat16" in dtype_name or "float8" in dtype_name:
            upcast = jnp.asarray(x.float().numpy())
            target = getattr(jnp, dtype_name, None)
            return upcast if target is None else upcast.astype(target)
        return jnp.asarray(x.numpy())
class StateDictConverter:
    """Handles conversion between PyTorch and EasyDeL state dictionaries."""

    # Axis orders applied to a PyTorch ``weight`` of the given rank to obtain the
    # EasyDeL ``kernel`` layout on import, and the orders used on export.
    # Each entry of the second table is the inverse of the matching entry in the first.
    # Ranks that are not listed (e.g. 1-D weights) are copied without permuting.
    _TORCH_TO_EASYDEL_AXES: tp.ClassVar[dict[int, tuple[int, ...]]] = {
        2: (1, 0),
        3: (2, 1, 0),
        4: (2, 3, 1, 0),
        5: (2, 3, 4, 1, 0),
        6: (4, 5, 3, 2, 1, 0),
    }
    _EASYDEL_TO_TORCH_AXES: tp.ClassVar[dict[int, tuple[int, ...]]] = {
        2: (1, 0),
        3: (2, 1, 0),
        4: (3, 2, 0, 1),
        5: (4, 3, 0, 1, 2),
        6: (5, 4, 3, 2, 0, 1),
    }

    @staticmethod
    def match_keywords(string: str, required: list[str], forbidden: list[str]) -> bool:
        """Return True if ``string`` contains every ``required`` and none of the ``forbidden`` substrings."""
        return all(t in string for t in required) and not any(n in string for n in forbidden)

    @staticmethod
    def _match_dotted(key: str, target: str, anchor_to_end: bool) -> tuple[str, str] | None:
        """Match the FIRST occurrence of ``target`` in ``key`` on ``.`` boundaries.

        Returns ``(before, after)``, the text surrounding the match, when:

        - the match is preceded by nothing or by text ending in ``.``; and
        - it is followed by nothing or by text starting with ``.``.

        With ``anchor_to_end`` nothing may follow the match. Otherwise returns ``None``.
        """
        index = key.find(target)
        if index == -1:
            return None
        after = key[index + len(target) :]
        if after and (anchor_to_end or not after.startswith(".")):
            return None
        before = key[:index]
        if before and not before.endswith("."):
            return None
        return before, after

    @staticmethod
    def _device_bytes_in_use() -> dict[int, int]:
        """Return ``bytes_in_use`` per local device index (0 when the backend reports no stats).

        ``Device.memory_stats()`` returns ``None`` on backends without allocator
        statistics (e.g. CPU), so it must not be indexed directly.
        """
        usage = {}
        for index, local_device in enumerate(jax.local_devices()):
            stats = local_device.memory_stats() or {}
            usage[index] = stats.get("bytes_in_use", 0)
        return usage

    @staticmethod
    def process_tensor(key: str, tensor: tp.Any, config: dict[str, tp.Any]) -> list[tuple[tuple, jnp.ndarray]] | None:
        """Convert one PyTorch state-dict entry into EasyDeL ``(key_tuple, array)`` pairs.

        Args:
            key: Dotted PyTorch parameter name.
            tensor: PyTorch tensor stored under ``key``.
            config: Settings built by ``_base_huggingface_to_easydel``:
                ``embedding_layer_names``, ``layernorm_names``,
                ``uses_tie_word_embedding``, ``lm_head_name``, ``dtype``, and
                optionally ``consolidated_moe_keys`` and ``reform_param``.
                ``reform_param`` maps a name pattern (a trailing ``$`` anchors it
                to the key's end) to ``{"splits": [{"name", "spliter"}, ...]}``.

        Returns:
            One pair per produced parameter. A ``reform_param`` rule can split a
            tensor into several pairs. Returns ``None`` when the key belongs to the
            LM head and ``uses_tie_word_embedding`` is set.
        """
        reform_param = config.get("reform_param", None)
        if reform_param:
            # Longer patterns are tried first; the first one that matches wins.
            for key_check, value in sorted(reform_param.items(), key=lambda item: len(item[0]), reverse=True):
                anchor_to_end = key_check.endswith("$")
                match_target = key_check[:-1] if anchor_to_end else key_check
                match = StateDictConverter._match_dotted(key, match_target, anchor_to_end)
                if match is None:
                    continue
                before_match, after_match = match
                # Each piece is processed with reform rules disabled so it is not split again.
                sub_config = {**config, "reform_param": {}}
                results = []
                for split in value["splits"]:
                    sub_results = StateDictConverter.process_tensor(
                        f"{before_match}{split['name']}{after_match}",
                        split["spliter"](tensor),
                        sub_config,
                    )
                    if sub_results:
                        results.extend(sub_results)
                return results

        new_key = key
        if any(name in key for name in config["embedding_layer_names"]):
            # NOTE(review): assumes embedding keys end with ".weight".
            new_key = f"{key[: -len('.weight')]}.embedding"
        elif any(name in key for name in config["layernorm_names"]):
            new_key = key.replace(".weight", ".scale")
        elif "weight" in key:
            ndim = len(tensor.shape)
            if key in config.get("consolidated_moe_keys", set()):
                # Stacked expert weights keep axis 0 (presumably the expert axis); only the last two swap.
                if ndim == 3:
                    tensor = tensor.permute(0, 2, 1)
            else:
                axes = StateDictConverter._TORCH_TO_EASYDEL_AXES.get(ndim)
                if axes is not None:
                    tensor = tensor.permute(*axes)
            new_key = key.replace(".weight", ".kernel")

        key_tuple = tuple(int(part) if part.isdigit() else part for part in new_key.split("."))
        if config["uses_tie_word_embedding"] and config["lm_head_name"] and key_tuple[0] == config["lm_head_name"]:
            # Tied LM head: presumably shares storage with the embedding, so it is skipped.
            return None
        return [(key_tuple, TensorConverter.convert_pytorch_to_jnp(tensor, config["dtype"]))]

    @staticmethod
    def _base_huggingface_to_easydel(
        state_dict: dict[str, tp.Any],
        *,
        device: jax.Device | None = None,  # type:ignore
        embedding_layer_names: list[str] | None = None,
        layernorm_names: list[str] | None = None,
        moe_block_names: list[str] | None = None,
        moe_names: list[str] | None = None,
        shard_fns: tp.Mapping[tuple, tp.Callable] | None = None,
        dtype: jnp.dtype = jnp.float16,
        verbose: bool = True,
        callback: tp.Callable[[jax.Array, tuple], jax.Array] | None = None,
        remove_state_dict: bool = False,
        lm_head_name: str | None = None,
        uses_tie_word_embedding: bool = False,
        consolidated_moe_keys: set[str] | None = None,
        reform_param: dict | None = None,
        **kwargs,
    ) -> dict[str, tp.Any]:
        """Convert a PyTorch state dict to a nested EasyDeL parameter tree.

        Keys are processed in sorted order with `process_tensor`. Each produced
        array is passed through ``shard_fns[key_tuple]`` (when present) and then
        ``callback``. Device memory deltas are logged at debug level.

        A key whose conversion raises is logged with ``logger.error`` and skipped.
        ``device`` is used as the JAX default device only when ``shard_fns`` is not
        given. Extra ``kwargs`` are accepted and ignored.

        Returns:
            The converted parameters, unflattened with ``unflatten_dict``.
        """
        try:
            import torch

            _clear = torch.cuda.empty_cache if torch.cuda.is_available() else gc.collect
        except ModuleNotFoundError:
            _clear = gc.collect

        config = {
            "embedding_layer_names": set(embedding_layer_names or []),
            "layernorm_names": set(layernorm_names or []),
            "moe_block_names": set(moe_block_names or []),
            "moe_names": set(moe_names or []),
            "lm_head_name": lm_head_name,
            "uses_tie_word_embedding": uses_tie_word_embedding,
            "dtype": dtype,
            "consolidated_moe_keys": consolidated_moe_keys or set(),
            "reform_param": reform_param,
        }
        device_context = (
            jax.default_device(device) if device is not None and shard_fns is None else contextlib.nullcontext()
        )
        bytes_per_gb = 1024**3
        with device_context:
            flax_dict = {}
            with tqdm(total=len(state_dict), disable=not verbose, desc="Converting Model") as pbar:
                for key in sorted(state_dict.keys()):
                    tensor = state_dict.get(key)
                    try:
                        bytes_before = StateDictConverter._device_bytes_in_use()
                        results = StateDictConverter.process_tensor(key, tensor, config)
                        for key_tuple, jax_array in results or ():
                            if shard_fns and key_tuple in shard_fns:
                                jax_array = shard_fns[key_tuple](jax_array)
                            if callback is not None:
                                jax_array = callback(jax_array, key_tuple)
                            bytes_now = StateDictConverter._device_bytes_in_use()
                            change_gb = {i: round((bytes_now[i] - bytes_before[i]) / bytes_per_gb, 4) for i in bytes_now}
                            usage_gb = {i: round(bytes_now[i] / bytes_per_gb, 4) for i in bytes_now}
                            logger.debug(
                                f"Sharding {'.'.join(map(str, key_tuple))} change_gb: {change_gb} current_gb: {usage_gb}"
                            )
                            flax_dict[key_tuple] = jax_array
                    except Exception as e:
                        # Best-effort: a failing key is reported and skipped, not fatal.
                        logger.error(f"Error processing key {key}: {e!s}")
                    pbar.update(1)
            if remove_state_dict:
                # Only drops this function's reference; the caller's dict is not cleared.
                del state_dict
                _clear()
            return unflatten_dict(flax_dict)

    @staticmethod
    def huggingface_to_easydel(
        state_dict: dict[str, tp.Any],
        *,
        device: jax.Device | None = None,  # type:ignore
        embedding_layer_names: list[str] | None = None,
        layernorm_names: list[str] | None = None,
        moe_block_names: list[str] | None = None,
        moe_names: list[str] | None = None,
        moe_block_path: list[str] | None = None,
        moe_path: list[str] | None = None,
        shard_fns: tp.Mapping[tuple, tp.Callable] | None = None,
        dtype: jnp.dtype = jnp.float16,
        verbose: bool = True,
        callback: tp.Callable[[jax.Array, tuple], jax.Array] | None = None,
        remove_state_dict: bool = False,
        lm_head_name: str | None = None,
        uses_tie_word_embedding: bool = False,
        reform_param: dict | None = None,
        **kwargs,
    ) -> dict[str, tp.Any]:
        """Convert a PyTorch state dict to EasyDeL format, consolidating MoE experts first.

        When both ``moe_block_names`` and ``moe_names`` are given, the state dict is
        first rewritten by ``apply_moe_transformations``. The keys it reports as
        consolidated are permuted as stacked experts by `process_tensor`. Everything
        else is delegated to `_base_huggingface_to_easydel`.
        """
        consolidated_moe_keys = set()
        if moe_block_names is not None and moe_names is not None:
            # NOTE(review): `apply_moe_transformations` is not defined in this class as
            # shown; presumably provided elsewhere — confirm.
            state_dict, consolidated_moe_keys = StateDictConverter.apply_moe_transformations(
                state_dict=state_dict,
                moe_names=moe_names,
                moe_path=moe_path,
                moe_block_names=moe_block_names,
                moe_block_path=moe_block_path,
            )
        return StateDictConverter._base_huggingface_to_easydel(
            state_dict,
            device=device,
            embedding_layer_names=embedding_layer_names,
            layernorm_names=layernorm_names,
            moe_names=moe_names,
            moe_path=moe_path,
            moe_block_names=moe_block_names,
            moe_block_path=moe_block_path,
            shard_fns=shard_fns,
            dtype=dtype,
            verbose=verbose,
            callback=callback,
            remove_state_dict=remove_state_dict,
            lm_head_name=lm_head_name,
            uses_tie_word_embedding=uses_tie_word_embedding,
            consolidated_moe_keys=consolidated_moe_keys,
            reform_param=reform_param,
            **kwargs,
        )

    @staticmethod
    def easydel_to_torch(module: EasyDeLBaseModule, dtype: jnp.dtype = jnp.float16, **kwargs) -> dict[str, tp.Any]:
        """Convert an EasyDeL module's parameters to a PyTorch state dict.

        Args:
            module: Module whose ``parameters`` are exported.
            dtype: Target dtype or alias for every tensor. ``None`` uses
                ``module.param_dtype``.
            **kwargs: ``reform_param`` (optional) is used to merge split parameters
                back through each rule's ``inverse_spliter``.

        Returns:
            Mapping of dotted torch names to tensors. Steps:

            - ``.kernel`` tensors are permuted back to torch layout;
            - ``.kernel``/``.embedding``/``.scale`` are renamed to ``.weight``;
            - MoE experts are un-stacked via ``apply_moe_transformations_reverse``.
        """
        if dtype is None:
            dtype = module.param_dtype
        target_dtype = DtypeHandler.get_dtype(dtype)
        graphtree = unflatten_dict(module.parameters)
        model_parameters = flatten_dict(graphtree, sep=".")

        from easydel.layers.moe import BaseMoeModule, ParallelMoELinear
        from easydel.utils import traversals

        moe_path = [".".join(map(str, path)) for path, _ in traversals.iter_module_search(module, ParallelMoELinear)]
        moe_block_path = [".".join(map(str, path)) for path, _ in traversals.iter_module_search(module, BaseMoeModule)]
        moe_names = list({name.split(".")[-1] for name in moe_path}) if moe_path else None
        moe_block_names = list({name.split(".")[-1] for name in moe_block_path}) if moe_block_path else None

        # Expert kernels stored stacked under `<block>.experts.<name>.kernel`.
        stacked_moe_keys = set()
        if moe_block_names and moe_names and moe_block_path:
            for block_path in moe_block_path:
                for moe_name in moe_names:
                    candidate = f"{block_path}.experts.{moe_name}.kernel"
                    if candidate in model_parameters:
                        stacked_moe_keys.add(candidate)

        torch_state_dict = {}
        with tqdm(model_parameters.items(), desc=f"Converting {module.__class__.__name__} to torch") as pbar:
            for key, tensor in pbar:
                if tensor is None:
                    continue
                if hasattr(tensor, "materialize"):
                    tensor = tensor.materialize()
                if hasattr(tensor, "value") and hasattr(tensor.value, "materialize"):
                    tensor = tensor.value.materialize()
                if tensor.dtype != target_dtype:
                    tensor = tensor.astype(target_dtype)
                tensor = TensorConverter.jax_to_pytorch(jax.block_until_ready(tensor))
                if key.endswith(".kernel"):
                    if key in stacked_moe_keys:
                        if tensor.ndim == 3:
                            tensor = tensor.permute(0, 2, 1)
                    else:
                        axes = StateDictConverter._EASYDEL_TO_TORCH_AXES.get(tensor.ndim)
                        if axes is not None:
                            tensor = tensor.permute(*axes)
                key = key.replace(".kernel", ".weight").replace(".embedding", ".weight").replace(".scale", ".weight")
                torch_state_dict[key] = tensor

        if moe_block_names and moe_names and moe_block_path and moe_path:
            # NOTE(review): `apply_moe_transformations_reverse` is not defined in this class
            # as shown; presumably provided elsewhere — confirm.
            torch_state_dict = StateDictConverter.apply_moe_transformations_reverse(
                state_dict=torch_state_dict,
                moe_names=moe_names,
                moe_path=moe_path,
                moe_block_names=moe_block_names,
                moe_block_path=moe_block_path,
            )

        reform_param = kwargs.get("reform_param", None)
        if reform_param:
            torch_state_dict = StateDictConverter._merge_reform_splits(torch_state_dict, reform_param)
        return torch_state_dict

    @staticmethod
    def _call_inverse_spliter(inverse_spliter: tp.Callable, tensors: list[tp.Any]) -> tp.Any:
        """Call ``inverse_spliter`` on ``tensors``.

        The torch module is passed first when the callable declares more positional
        parameters than there are tensors and its first one is named ``torch``.
        """
        positional_params = [
            p
            for p in inspect.signature(inverse_spliter).parameters.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        if len(positional_params) > len(tensors) and positional_params[0].name == "torch":
            return inverse_spliter(TensorConverter.get_torch(), *tensors)
        return inverse_spliter(*tensors)

    @staticmethod
    def _merge_reform_splits(torch_state_dict: dict[str, tp.Any], reform_param: dict) -> dict[str, tp.Any]:
        """Undo ``reform_param`` splits in place and return ``torch_state_dict``.

        For every rule with an ``inverse_spliter``:

        - split keys are grouped by the text around the split name;
        - each group containing all splits is merged (in ``splits`` order) into the
          original key;
        - the split keys are removed.
        """
        for key_check, value_check in reform_param.items():
            inverse_spliter = value_check.get("inverse_spliter", None)
            if not inverse_spliter:
                continue
            anchor_to_end = key_check.endswith("$")
            match_target = key_check[:-1] if anchor_to_end else key_check
            splits = value_check["splits"]
            split_names = [split["name"] for split in splits]

            # (prefix, suffix) -> {split_name: key in torch_state_dict}
            candidates: dict[tuple[str, str], dict[str, str]] = {}
            for key in torch_state_dict:
                for split_name in split_names:
                    match = StateDictConverter._match_dotted(key, split_name, anchor_to_end)
                    if match is None:
                        continue
                    prefix, suffix = match
                    if f"{prefix}{match_target}{suffix}".replace(match_target, split_name) != key:
                        continue
                    candidates.setdefault((prefix, suffix), {})[split_name] = key

            keys_to_remove = []
            for (prefix, suffix), found_splits in candidates.items():
                if len(found_splits) != len(split_names):
                    continue
                tensors_to_merge = []
                for split in splits:
                    split_key = found_splits[split["name"]]
                    tensors_to_merge.append(torch_state_dict[split_key])
                    keys_to_remove.append(split_key)
                merged = StateDictConverter._call_inverse_spliter(inverse_spliter, tensors_to_merge)
                torch_state_dict[f"{prefix}{match_target}{suffix}"] = merged
            for key in keys_to_remove:
                # `pop` tolerates a key queued twice by overlapping groups.
                torch_state_dict.pop(key, None)
        return torch_state_dict
class ModelConverter:
    """Handles model conversions between EasyDeL and HuggingFace formats."""

    @staticmethod
    def easydel_to_huggingface(
        module: EasyDeLBaseModule,
        config: EasyDeLBaseConfig,
        base_huggingface_module: PreTrainedModel,
        base_huggingface_module_kwarguments: dict | None = None,
        dtype: jnp.dtype = jnp.float16,
        use_meta_torch: bool = True,
        reform_param: dict | None = None,
        **kw,
    ) -> tp.Any:
        """Convert an EasyDeL module into an instance of a HuggingFace model class.

        Args:
            module: EasyDeL module whose parameters are exported.
            config: EasyDeL config; its ``to_dict()`` output builds the HF config via
                ``base_huggingface_module.config_class.from_dict``.
            base_huggingface_module: HuggingFace model class to instantiate.
            base_huggingface_module_kwarguments: Extra constructor kwargs for that class.
            dtype: Target dtype passed to ``StateDictConverter.easydel_to_torch``.
            use_meta_torch: Build the HF model under ``torch.device("meta")`` so no
                storage is allocated before the weights are assigned.
            reform_param: Forwarded to ``easydel_to_torch`` to merge split parameters.
            **kw: Accepted and ignored.

        Returns:
            The HF model with the converted weights loaded
            (``load_state_dict(assign=True, strict=True)``).

        Before loading, a warning is emitted when the key counts differ, for each
        expected key missing from the converted dict, and for each shape mismatch.
        ``load_state_dict(strict=True)`` then raises on missing or unexpected keys.
        """
        import torch

        if base_huggingface_module_kwarguments is None:
            base_huggingface_module_kwarguments = {}
        state_dict = StateDictConverter.easydel_to_torch(module=module, dtype=dtype, reform_param=reform_param)
        base_config = base_huggingface_module.config_class.from_dict(config.to_dict())
        with torch.device("meta") if use_meta_torch else contextlib.nullcontext():
            model: torch.nn.Module = base_huggingface_module(config=base_config, **base_huggingface_module_kwarguments)

        expected_shapes = {k: v.shape for k, v in model.state_dict().items() if hasattr(v, "shape")}
        if len(expected_shapes) != len(state_dict):
            warnings.warn("There might be an issue with converted `state_dict`.", stacklevel=1)
        for key, shape in expected_shapes.items():
            converted = state_dict.get(key)
            if converted is None:
                warnings.warn(f"Missing key {key} in converted `state_dict`.", stacklevel=1)
            elif converted.shape != shape:
                warnings.warn(f"Shape conflict at {key}.", stacklevel=1)
        model.load_state_dict(state_dict, assign=True, strict=True)
        return model