# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for ELM configuration handling.
This module provides utility functions for configuration manipulation,
type coercion, and file I/O operations for the ELM system.
"""
from __future__ import annotations
import json
import os
from collections.abc import Mapping
from dataclasses import asdict, is_dataclass
from typing import Any, cast
import jax
from eformer.loggings import get_logger
from eformer.paths import ePath, ePathLike
from jax import numpy as jnp
from .types import DTypeLike, ELMConfig, PrecisionLike, TaskType
# Module logger from eformer; infer_task_from_hf_config uses it to emit warnings
# when a task type cannot be determined.
logger = get_logger(__name__)
def prune_nones(obj: Any) -> Any:
    """Recursively remove None-valued entries from dicts in a nested structure.

    Dict entries whose value is None are dropped. Lists and tuples (including
    namedtuples) are rebuilt with the same type and their elements pruned
    recursively. None elements inside lists and tuples are kept; only dict
    entries are removed. Any other value is returned unchanged.

    Args:
        obj: Object to prune (dict, list, tuple, namedtuple, or any value).

    Returns:
        A pruned copy of ``obj`` with the same container types.

    Example:
        >>> data = {"a": 1, "b": None, "c": {"d": 2, "e": None}}
        >>> prune_nones(data)
        {'a': 1, 'c': {'d': 2}}
    """
    if isinstance(obj, dict):
        return {k: prune_nones(v) for k, v in obj.items() if v is not None}
    if isinstance(obj, (list, tuple)):
        items = [prune_nones(v) for v in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable (dataclasses.asdict can emit namedtuples).
            return type(obj)(*items)
        return type(obj)(items)
    return obj
def as_map(cfg: Any) -> dict[str, Any]:
    """Return a plain-dict view of a configuration object.

    Dataclasses are converted with ``dataclasses.asdict`` and then passed
    through ``prune_nones``, so None-valued fields disappear. Mappings are
    copied into a new ``dict`` unchanged (None values are kept). The dataclass
    check runs first, so a dataclass that is also a Mapping is pruned.

    Args:
        cfg: A dataclass instance or any ``Mapping``.

    Returns:
        A new ``dict`` with the configuration contents.

    Raises:
        TypeError: If ``cfg`` is neither a dataclass nor a Mapping.

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     value: int = 1
        ...     optional: str | None = None
        >>> as_map(Config())
        {'value': 1}
    """
    if not is_dataclass(cfg):
        if isinstance(cfg, Mapping):
            return dict(cfg)
        raise TypeError(f"Unsupported config type: {type(cfg)!r}")
    pruned = prune_nones(asdict(cfg))
    return cast(dict[str, Any], pruned)
def deep_merge(base: dict[str, Any], overlay: dict[str, Any]) -> dict[str, Any]:
    """Merge ``overlay`` into ``base`` recursively and return the result.

    When both sides hold a dict under the same key, the two dicts are merged
    recursively. Otherwise the overlay value wins. Neither input is modified.
    The copy is shallow: values that are not merged (including untouched
    nested dicts) are shared with ``base``/``overlay``.

    Args:
        base: Dictionary providing default values.
        overlay: Dictionary whose values take precedence.

    Returns:
        A new dictionary holding the merged values. Keys from ``base`` come
        first, followed by keys that only appear in ``overlay``.

    Example:
        >>> base = {"a": 1, "b": {"c": 2, "d": 3}}
        >>> overlay = {"b": {"c": 20, "e": 4}, "f": 5}
        >>> deep_merge(base, overlay)
        {'a': 1, 'b': {'c': 20, 'd': 3, 'e': 4}, 'f': 5}
    """
    merged = {**base}
    for key, incoming in overlay.items():
        current = merged.get(key)
        both_dicts = isinstance(incoming, dict) and isinstance(current, dict)
        merged[key] = deep_merge(current, incoming) if both_dicts else incoming
    return merged
# Lowercase dtype spellings mapped to the *name* of the matching ``jax.numpy``
# scalar type. Names are resolved lazily with getattr because several FP8 types
# only exist in newer JAX releases; touching them eagerly would break every
# lookup on older versions.
_DTYPE_ALIASES: dict[str, str] = {
    "fp8": "float8_e5m2",
    "float8": "float8_e5m2",
    "fp8_e5m2": "float8_e5m2",
    "float8_e5m2": "float8_e5m2",
    "fp8_e5m2fnuz": "float8_e5m2fnuz",
    "float8_e5m2fnuz": "float8_e5m2fnuz",
    "fp8_e4m3": "float8_e4m3",
    "float8_e4m3": "float8_e4m3",
    "fp8_e4m3fn": "float8_e4m3fn",
    "float8_e4m3fn": "float8_e4m3fn",
    "fp8_e4m3fnuz": "float8_e4m3fnuz",
    "float8_e4m3fnuz": "float8_e4m3fnuz",
    "fp8_e4m3b11fnuz": "float8_e4m3b11fnuz",
    "float8_e4m3b11fnuz": "float8_e4m3b11fnuz",
    "fp8_e3m4": "float8_e3m4",
    "float8_e3m4": "float8_e3m4",
    "fp8_e8m0fnu": "float8_e8m0fnu",
    "float8_e8m0fnu": "float8_e8m0fnu",
    "bf16": "bfloat16",
    "bfloat16": "bfloat16",
    "fp16": "float16",
    "float16": "float16",
    "f16": "float16",
    "fp32": "float32",
    "float32": "float32",
    "f32": "float32",
    "fp64": "float64",
    "float64": "float64",
    "f64": "float64",
}


def _resolve_dtype_alias(name: str) -> jnp.dtype | None:
    """Return the dtype for a known alias in ``_DTYPE_ALIASES``, or None if ``name`` is not an alias."""
    attr = _DTYPE_ALIASES.get(name.strip().lower())
    if attr is None:
        return None
    scalar_type = getattr(jnp, attr, None)
    if scalar_type is None:
        raise ValueError(f"dtype {name!r} maps to jax.numpy.{attr}, which the installed JAX version does not provide")
    return jnp.dtype(scalar_type)


def coerce_dtype(x: DTypeLike | None) -> jnp.dtype:
    """Convert a dtype-like value to a ``jnp.dtype``.

    Accepts shorthand strings (e.g. "bf16", "fp16", "f32", "fp8_e4m3"), which
    are matched case-insensitively with surrounding whitespace ignored, plus
    anything ``jnp.dtype`` understands (dtype objects, JAX scalar types, numpy
    type strings). None and unrecognized values fall back to float32.

    Args:
        x: Dtype specification (string, dtype, scalar type, or None).

    Returns:
        A ``jnp.dtype`` instance.

    Raises:
        ValueError: If ``x`` names a known FP8 alias that the installed JAX
            version does not provide.

    Example:
        >>> coerce_dtype("bf16") == jnp.bfloat16
        True
        >>> coerce_dtype("f16")
        dtype('float16')
        >>> coerce_dtype(None)
        dtype('float32')
    """
    if x is None:
        return jnp.dtype(jnp.float32)
    if isinstance(x, str):
        # Check aliases before numpy parsing: numpy reads "f16" as a 16-byte
        # float (float128 on x86-64 Linux), not as float16.
        aliased = _resolve_dtype_alias(x)
        if aliased is not None:
            return aliased
    try:
        return jnp.dtype(x)
    except Exception:
        # Deliberate best effort: try an alias match on the text form, then float32.
        aliased = _resolve_dtype_alias(str(x))
        return jnp.dtype(jnp.float32) if aliased is None else aliased
def coerce_precision(p: PrecisionLike) -> jax.lax.Precision | None:
    """Turn a precision-like value into a ``jax.lax.Precision`` member.

    None and existing ``jax.lax.Precision`` members are returned as-is. Any
    other value is converted with ``str`` and upper-cased, then matched
    against "DEFAULT", "HIGH" and "HIGHEST". Unmatched values map to
    ``Precision.DEFAULT``.

    Args:
        p: Precision specification (string, jax.lax.Precision, or None).

    Returns:
        A ``jax.lax.Precision`` member, or None when ``p`` is None.

    Example:
        >>> coerce_precision("high")
        <Precision.HIGH: 1>
        >>> coerce_precision(None) is None
        True
    """
    precision_enum = jax.lax.Precision
    if p is None or isinstance(p, precision_enum):
        return p
    by_name = {
        "DEFAULT": precision_enum.DEFAULT,
        "HIGH": precision_enum.HIGH,
        "HIGHEST": precision_enum.HIGHEST,
    }
    key = str(p).upper()
    return by_name.get(key, precision_enum.DEFAULT)
# Accepted string spellings for each task. normalize_task strips, lowercases,
# and replaces "-" with "_" before looking its input up here. Entries are
# grouped per task; key order matches the flat listing
# (causal_lm, lm, seq2seq, ..., base).
TASK_ALIASES: dict[str, TaskType] = {
    alias: task
    for task, spellings in (
        (TaskType.CAUSAL_LM, ("causal_lm", "lm")),
        (TaskType.SEQUENCE_TO_SEQUENCE, ("seq2seq", "sequence_to_sequence")),
        (TaskType.SPEECH_SEQUENCE_TO_SEQUENCE, ("speech_seq2seq",)),
        (TaskType.IMAGE_TEXT_TO_TEXT, ("image_text_to_text",)),
        (TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION, ("zero_shot_image_classification",)),
        (TaskType.DIFFUSION_LM, ("diffusion_lm",)),
        (TaskType.BASE_MODULE, ("base",)),
    )
    for alias in spellings
}
def normalize_task(t: TaskType | str | None) -> TaskType | None:
    """Normalize a task type specification to a TaskType member.

    Strings are compared with surrounding whitespace ignored, case-insensitively,
    and with "-" treated as "_". They are resolved in this order:

    1. ``TASK_ALIASES`` (e.g. "lm", "seq2seq", "base");
    2. TaskType member names (e.g. "sequence_classification" resolves to
       ``TaskType.SEQUENCE_CLASSIFICATION``);
    3. TaskType values, tried as given (stripped) and lowercased.

    Args:
        t: Task type specification (TaskType, string, or None).

    Returns:
        The matching TaskType, or None if ``t`` is None or not recognized.

    Example:
        >>> normalize_task("causal-lm") is TaskType.CAUSAL_LM
        True
        >>> normalize_task("LM") is TaskType.CAUSAL_LM
        True
    """
    if t is None:
        return None
    if isinstance(t, TaskType):
        return t
    raw = str(t).strip()
    key = raw.lower().replace("-", "_")
    aliased = TASK_ALIASES.get(key)
    if aliased is not None:
        return aliased
    # Accept member names so every TaskType (including ones without an alias,
    # such as those infer_task_from_hf_config can return) round-trips from text.
    member = TaskType.__members__.get(key.upper())
    if member is not None:
        return member
    for candidate in (raw, raw.lower()):
        try:
            return TaskType(candidate)
        except ValueError:
            continue
    return None
def _load_hf_config(model_name_or_path: str) -> dict[str, Any] | None:
    """Return the parsed ``config.json`` for a local directory or a Hub repo id.

    A local directory is read directly. Anything else is treated as a Hub repo
    id: ``huggingface_hub.hf_hub_download`` is tried first (it uses the cached
    login, so gated repos work), then a plain HTTP fetch of the public raw URL.
    Returns None, after logging a warning, when no config can be obtained.
    """
    local_path = ePath(model_name_or_path)
    if local_path.is_dir():
        config_file = local_path / "config.json"
        if not config_file.exists():
            logger.warning(
                f"No config.json found in local path: {model_name_or_path}. Task type will fallback to CAUSAL_LM."
            )
            return None
        return json.loads(config_file.read_text())
    try:
        from huggingface_hub import hf_hub_download

        config_path = hf_hub_download(
            repo_id=model_name_or_path,
            filename="config.json",
            repo_type="model",
        )
        return json.loads(ePath(config_path).read_text())
    except Exception as hf_error:
        # Covers a missing huggingface_hub install as well as download, auth,
        # and JSON-parse failures; all of them fall back to a plain HTTP fetch.
        return _fetch_hf_config_over_http(model_name_or_path, hf_error)


def _fetch_hf_config_over_http(model_name_or_path: str, hf_error: Exception) -> dict[str, Any] | None:
    """Fetch ``config.json`` from the public Hub URL after ``hf_hub_download`` failed.

    ``hf_error`` is the exception from the huggingface_hub attempt. It is
    included in warnings and used to detect gated repos. Returns None, after
    logging a warning, if ``requests`` is unavailable or the request fails.
    """
    try:
        import requests
    except ImportError:
        logger.warning(
            f"Cannot fetch config for {model_name_or_path}: "
            f"huggingface_hub failed ({hf_error}) and the requests library is not available. "
            f"Task type will fallback to CAUSAL_LM."
        )
        return None
    config_url = f"https://huggingface.co/{model_name_or_path}/raw/main/config.json"
    try:
        response = requests.get(config_url, timeout=10)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as req_error:
        # The unauthenticated URL cannot read gated repos either, so the
        # huggingface_hub error is the better signal that auth is the problem.
        if "401" in str(hf_error) or "gated" in str(hf_error).lower():
            logger.warning(
                f"Cannot access config for {model_name_or_path}: Model is gated and requires authentication. "
                f"Run 'huggingface-cli login' to authenticate. Task type will fallback to CAUSAL_LM."
            )
        else:
            logger.warning(
                f"Failed to fetch config for {model_name_or_path}. "
                f"Task type will fallback to CAUSAL_LM. Error: {req_error}"
            )
        return None


def _task_from_hf_config_dict(config: Mapping[str, Any], model_name_or_path: str) -> TaskType | None:
    """Map a parsed Hugging Face config to a TaskType.

    The mapping uses ``architectures[0]`` first and then ``model_type``.
    Returns None, after logging a warning, if no rule matches.
    """
    architectures = config.get("architectures") or []
    # `or ""` also covers an explicit JSON null, which would crash `.lower()`.
    model_type = str(config.get("model_type") or "").lower()
    if not architectures:
        logger.warning(
            f"No architectures found in config for {model_name_or_path}. Task type will fallback to CAUSAL_LM."
        )
        return None
    arch = architectures[0]
    if "ForCausalLM" in arch:
        return TaskType.CAUSAL_LM
    if "ForConditionalGeneration" in arch:
        if any(name in model_type for name in ("whisper", "speech2text")):
            return TaskType.SPEECH_SEQUENCE_TO_SEQUENCE
        # Text encoder-decoders (T5, BART, Marian, ...) also use
        # *ForConditionalGeneration. Without a vision tower they are plain seq2seq.
        if config.get("is_encoder_decoder") and "vision_config" not in config:
            return TaskType.SEQUENCE_TO_SEQUENCE
        return TaskType.IMAGE_TEXT_TO_TEXT
    if "ForSequenceClassification" in arch:
        return TaskType.SEQUENCE_CLASSIFICATION
    if "ForAudioClassification" in arch:
        return TaskType.AUDIO_CLASSIFICATION
    if "ForImageClassification" in arch:
        return TaskType.IMAGE_CLASSIFICATION
    if any(marker in arch for marker in ("ForSpeechSeq2Seq", "Whisper")):
        return TaskType.SPEECH_SEQUENCE_TO_SEQUENCE
    if "ForZeroShotImageClassification" in arch:
        return TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION
    if "vision" in model_type or "clip" in model_type:
        return TaskType.BASE_VISION
    if "diffusion" in model_type:
        return TaskType.DIFFUSION_LM
    logger.warning(
        f"Could not map architecture '{arch}' to a TaskType for {model_name_or_path}. "
        f"Task type will fallback to CAUSAL_LM."
    )
    return None


def infer_task_from_hf_config(model_name_or_path: str) -> TaskType | None:
    """Infer a task type from a Hugging Face model config without downloading weights.

    Only ``config.json`` is read. It comes from a local directory when
    ``model_name_or_path`` is one; otherwise it is fetched from the Hub, via
    ``huggingface_hub`` (which supports authenticated, gated models) with a
    ``requests`` fallback. The task is derived from the first entry of
    ``architectures`` and from ``model_type``.

    Args:
        model_name_or_path: Hugging Face model ID or local checkpoint directory.

    Returns:
        The inferred TaskType, or None if it cannot be determined. Every
        failure is logged as a warning rather than raised. Per the warning
        text, callers fall back to CAUSAL_LM.

    Example (requires network access):
        >>> infer_task_from_hf_config("meta-llama/Llama-2-7b") is TaskType.CAUSAL_LM
        True
        >>> infer_task_from_hf_config("Qwen/Qwen2-VL-7B") is TaskType.IMAGE_TEXT_TO_TEXT
        True
    """
    try:
        config = _load_hf_config(model_name_or_path)
        if config is None:
            return None
        return _task_from_hf_config_dict(config, model_name_or_path)
    except Exception as e:
        logger.warning(
            f"Unexpected error inferring task for {model_name_or_path}: {e}. Task type will fallback to CAUSAL_LM."
        )
        return None
def save_elm_config(config: ELMConfig | Mapping[str, Any], json_file_path: str | os.PathLike | ePathLike) -> None:
    """Normalize an ELM config and write it to a JSON file.

    The config is first passed through ``normalizer.normalize``. Missing
    parent directories are created. The file is written as UTF-8 JSON with a
    2-space indent, and non-ASCII characters are kept as-is.

    Args:
        config: The ELMConfig or config mapping to save.
        json_file_path: Destination path of the JSON file.

    Example:
        >>> config = {"model": {"name_or_path": "meta-llama/Llama-2-7b"}}
        >>> save_elm_config(config, "my_config.json")
    """
    from .normalizer import normalize

    cfg = normalize(config)
    json_path = ePath(json_file_path)
    json_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False may emit non-ASCII text. Pin UTF-8 (the encoding
    # load_elm_config reads with) instead of the platform default encoding.
    json_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False), encoding="utf-8")
def load_elm_config(json_file_path: str | os.PathLike | ePathLike) -> ELMConfig:
    """Read an ELM config from a UTF-8 JSON file and normalize it.

    Args:
        json_file_path: Path of the JSON file to read.

    Returns:
        ELMConfig: The parsed configuration after ``normalizer.normalize``.

    Raises:
        FileNotFoundError: If no file exists at ``json_file_path``.

    Example:
        >>> config = load_elm_config("my_config.json")
        >>> model = build_model(config)
    """
    from .normalizer import normalize

    source = ePath(json_file_path)
    if not source.exists():
        raise FileNotFoundError(f"Config file not found: {json_file_path}")
    text = source.read_text(encoding="utf-8")
    parsed = json.loads(text)
    return normalize(parsed)