# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""eLargeModel - Easy Large Models master class for EasyDeL.
This module provides a unified interface for working with large language models
in the EasyDeL framework, combining configuration management, model building,
training orchestration, and inference engine initialization.
Key Features:
- Unified configuration management for models, training, and inference
- Automatic model and tokenizer initialization from HuggingFace or local paths
- Support for multiple training paradigms (SFT, DPO, ORPO, GRPO, distillation)
- Integration with the eSurge inference engine
- Built-in evaluation with lm-evaluation-harness
- Flexible dataset mixture configuration
- Model sharding and quantization support
"""
from __future__ import annotations
import json
import os
import pprint
import typing
from collections.abc import Mapping
from typing import Any, NotRequired, Unpack
from eformer.loggings import get_logger
from eformer.paths import ePath, ePathLike
from transformers import AutoTokenizer
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.base_state import EasyDeLState
from easydel.infra.factory import TaskType
from easydel.trainers.training_configurations import TrainingArguments
from .builders import (
build_dataset,
build_esurge,
build_model,
build_sharded_source,
to_data_mixture_kwargs,
to_esurge_kwargs,
to_from_pretrained_kwargs,
)
from .normalizer import materialize_base_config, normalize, resolve_task, validate
from .trainer_types import get_trainer_class, get_training_arguments_class, normalize_trainer_config
from .types import ELMConfig
from .utils import load_elm_config, save_elm_config
if typing.TYPE_CHECKING:
from datasets import Dataset
from easydel.data.core.protocols import ShardedDataSource
from easydel.inference import eSurge
from easydel.trainers import Trainer
logger = get_logger("eLargeModel")
[docs]class BuildTrainerKws(typing.TypedDict, total=False):
"""Type hints for optional keyword arguments when building trainers.
Attributes:
data_collator: Custom data collator for batching examples
formatting_func: Function to format examples for SFT training
reward_processing_classes: Processing classes for reward models in GRPO
data_tokenize_fn: Custom tokenization function for data preprocessing
reference_model: Reference model for DPO/preference optimization
reward_model: Reward model for GRPO training
teacher_model: Teacher model for distillation training
reward_funcs: Custom reward functions for GRPO
"""
data_collator: NotRequired[typing.Callable]
formatting_func: NotRequired[typing.Callable]
reward_processing_classes: NotRequired[list[typing.Callable]]
data_tokenize_fn: NotRequired[typing.Callable]
reference_model: NotRequired[EasyDeLBaseModule | None]
reward_model: NotRequired[EasyDeLBaseModule | None]
teacher_model: NotRequired[EasyDeLBaseModule | None]
reward_funcs: NotRequired[Any | None]
[docs]class eLargeModel:
"""Master class for Easy Large Models (ELM) in EasyDeL.
This class provides a unified interface for:
- Configuration management (load, save, create)
- Model building and initialization (including teacher/reference models)
- Training orchestration with multiple paradigms (SFT, DPO, ORPO, etc.)
- eSurge inference engine integration
- Tokenizer management
- Dataset mixture configuration
- Model evaluation with lm-evaluation-harness
Attributes:
config: The normalized ELM configuration dictionary
model_name: The model name or path from configuration
task: The resolved task type (auto-detected or specified)
teacher_model_name: Teacher model name for distillation (if configured)
reference_model_name: Reference model name for DPO/ORPO (if configured)
Example:
Basic model loading:
>>> elm = eLargeModel({"model": {"name_or_path": "meta-llama/Llama-2-7b"}})
>>> model = elm.build_model()
From pretrained with configuration:
>>> elm = eLargeModel.from_pretrained(
... "meta-llama/Llama-2-7b",
... task="causal-lm"
... )
>>> elm.set_dtype("bf16")
>>> elm.set_sharding(axis_dims=(1, 2, 1, -1))
Loading from JSON configuration:
>>> elm = eLargeModel.from_json("config.json")
>>> esurge_engine = elm.build_esurge()
Training with SFT:
>>> elm.set_trainer("sft", learning_rate=2e-5, num_train_epochs=3)
>>> elm.add_dataset("train.json", dataset_type="json", content_field="text")
>>> results = elm.train()
Evaluation:
>>> results = elm.eval(["hellaswag", "mmlu"], engine="esurge")
"""
def __init__(self, config: ELMConfig | Mapping[str, Any] | str | os.PathLike | ePathLike | None = None):
    """Create an eLargeModel from a configuration source.

    Args:
        config: One of:
            - an ``ELMConfig`` / mapping, which is passed through ``normalize``;
            - a string or path-like object (anything exposing ``__fspath__``),
              which is read with ``load_elm_config``;
            - ``None``, which yields a normalized configuration whose model
              ``name_or_path`` is the empty string.
    """
    # Nothing is built eagerly; model and tokenizer are created on demand.
    self._model = None
    self._tokenizer = None

    is_path_like = isinstance(config, (str, os.PathLike)) or hasattr(config, "__fspath__")
    if is_path_like:
        resolved = load_elm_config(config)
    else:
        seed = {"model": {"name_or_path": ""}} if config is None else config
        resolved = normalize(seed)
    self._config = resolved
@classmethod
def from_json(cls, json_path: str | os.PathLike | ePathLike) -> eLargeModel:
    """Alternate constructor that reads the configuration from a JSON file.

    Args:
        json_path: Location of the JSON configuration file.

    Returns:
        A new instance built from the configuration returned by
        ``load_elm_config``.
    """
    loaded_config = load_elm_config(json_path)
    return cls(loaded_config)
@classmethod
def from_pretrained(
    cls,
    model_name_or_path: str,
    task: TaskType | str | None = None,
    **kwargs,
) -> eLargeModel:
    """Alternate constructor starting from a model name or path.

    Args:
        model_name_or_path: HuggingFace model ID or local path.
        task: Task type. When it is ``None``, ``TaskType.AUTO_BIND`` or the
            string ``"auto-bind"``, the task is inferred with
            ``infer_task_from_hf_config``; otherwise it is passed through
            ``normalize_task``. In both cases the original value is kept
            when the helper returns ``None``.
        **kwargs: Extra top-level configuration sections, merged next to
            the ``"model"`` section.

    Returns:
        A new instance holding the assembled configuration.
    """
    from .utils import infer_task_from_hf_config, normalize_task

    autodetect_requested = task is None or task == TaskType.AUTO_BIND or task == "auto-bind"
    if autodetect_requested:
        detected = infer_task_from_hf_config(model_name_or_path)
        if detected is not None:
            task = detected
            logger.info(f"Auto-detected task type: {task.value}")
    else:
        canonical = normalize_task(task)
        if canonical is not None:
            task = canonical

    model_section = {"name_or_path": model_name_or_path}
    if task:
        # Only record a task when there is a truthy one to record.
        model_section["task"] = task

    return cls({"model": model_section, **kwargs})
@property
def config(self) -> ELMConfig:
"""Get the normalized configuration dictionary.
Returns:
The full ELM configuration including model, loader, sharding,
quantization, training, and inference settings.
"""
return self._config
@property
def model_name(self) -> str:
"""Get the model name or path.
Returns:
The HuggingFace model ID or local path to the model.
"""
return self._config["model"]["name_or_path"]
@property
def task(self) -> TaskType:
    """Task type resolved from the current configuration.

    Returns:
        Whatever ``resolve_task`` computes for the stored configuration.
    """
    resolved = resolve_task(self._config)
    return resolved
@property
def teacher_model_name(self) -> str | None:
"""Get the teacher model name or path for distillation.
Returns:
The teacher model path if configured, None otherwise.
"""
return self._config.get("teacher_model", {}).get("name_or_path")
@property
def reference_model_name(self) -> str | None:
"""Get the reference model name or path for DPO/ORPO.
Returns:
The reference model path if configured, None otherwise.
"""
return self._config.get("reference_model", {}).get("name_or_path")
def update_config(self, updates: Mapping[str, Any]) -> eLargeModel:
    """Merge ``updates`` into the configuration and re-normalize it.

    The merge is delegated to ``deep_merge`` and the merged result is run
    through ``normalize`` before it replaces the stored configuration.

    Args:
        updates: Possibly nested mapping of new values, e.g.
            ``{"loader": {"dtype": "bf16"}, "esurge": {"max_model_len": 4096}}``.

    Returns:
        Self for method chaining.

    Example:
        >>> elm.update_config({
        ...     "loader": {"dtype": "bf16"},
        ...     "esurge": {"max_model_len": 4096},
        ... })
    """
    from .utils import deep_merge

    merged = deep_merge(self._config, updates)
    self._config = normalize(merged)
    return self
[docs] def set_model(self, model_name_or_path: str) -> eLargeModel:
"""Set the model name or path.
Updates the primary model configuration. This will clear any cached
model instance to ensure the new model is loaded on next build.
Args:
model_name_or_path: HuggingFace model ID (e.g., "meta-llama/Llama-2-7b")
or local path to model directory
Returns:
Self for method chaining
Example:
>>> elm.set_model("meta-llama/Llama-2-7b-hf")
>>> elm.set_model("/path/to/local/model")
"""
self._config["model"]["name_or_path"] = model_name_or_path
return self
[docs] def set_teacher_model(self, model_name_or_path: str) -> eLargeModel:
"""Set the teacher model name or path for distillation training.
Configures a teacher model used for knowledge distillation. The teacher
model is typically a larger, more capable model that guides the training
of the student (primary) model.
Args:
model_name_or_path: HuggingFace model ID or local path for teacher model.
Should be a model compatible with the student model's architecture.
Returns:
Self for method chaining
Example:
>>> elm.set_model("meta-llama/Llama-2-7b") # Student model
>>> elm.set_teacher_model("meta-llama/Llama-2-13b") # Teacher model
>>> elm.set_trainer("distillation", temperature=3.0)
"""
if "teacher_model" not in self._config:
self._config["teacher_model"] = {}
self._config["teacher_model"]["name_or_path"] = model_name_or_path
return self
[docs] def set_reference_model(self, model_name_or_path: str) -> eLargeModel:
"""Set the reference model name or path for preference optimization.
Configures a reference model used in DPO (Direct Preference Optimization)
and similar preference-based training methods. The reference model provides
a baseline for computing preference losses.
Args:
model_name_or_path: HuggingFace model ID or local path for reference model.
Often the same as the base model before fine-tuning.
Returns:
Self for method chaining
Example:
>>> elm.set_model("meta-llama/Llama-2-7b-hf") # Model to train
>>> elm.set_reference_model("meta-llama/Llama-2-7b-hf") # Reference
>>> elm.set_trainer("dpo", beta=0.1)
"""
if "reference_model" not in self._config:
self._config["reference_model"] = {}
self._config["reference_model"]["name_or_path"] = model_name_or_path
return self
[docs] def set_dtype(self, dtype: str) -> eLargeModel:
"""Set the data type for model loading.
Configures both the computation dtype and parameter dtype for the model.
This affects memory usage and computation speed.
Args:
dtype: Data type string. Supported values:
- "bf16": BFloat16 (recommended for TPU, modern GPUs)
- "fp16": Float16 (good for older GPUs)
- "fp32": Float32 (highest precision, most memory)
- "fp8": Float8 (experimental, requires compatible hardware)
Returns:
Self for method chaining
Example:
>>> elm.set_dtype("bf16") # Use bfloat16 for training/inference
"""
self._config.setdefault("loader", {})["dtype"] = dtype
self._config["loader"]["param_dtype"] = dtype
return self
[docs] def set_sharding(
self,
axis_dims: tuple[int, ...] | None = None,
axis_names: tuple[str, ...] | None = None,
**kwargs,
) -> eLargeModel:
"""Configure model sharding for distributed training/inference.
Sets up model parallelism by specifying how to shard model parameters
and computations across devices. Essential for training large models
that don't fit on a single device.
Args:
axis_dims: Sharding axis dimensions as a tuple. Common patterns:
- (1, 1, 1, -1): Data parallel only
- (2, 1, 1, -1): 2-way tensor parallel
- (1, 2, 1, -1): 2-way pipeline parallel
- (2, 2, 1, -1): 2-way tensor + 2-way pipeline parallel
axis_names: Sharding axis names (e.g., ("dp", "tp", "pp", "sp"))
- dp: Data parallel
- tp: Tensor parallel
- pp: Pipeline parallel
- sp: Sequence parallel
**kwargs: Additional sharding options.
Returns:
Self for method chaining
Example:
>>> # 2-way tensor parallel, 2-way data parallel
>>> elm.set_sharding(
... axis_dims=(2, 2, 1, -1),
... axis_names=("dp", "tp", "pp", "sp")
... )
"""
sharding = self._config.setdefault("sharding", {})
if axis_dims is not None:
sharding["axis_dims"] = axis_dims
if axis_names is not None:
sharding["axis_names"] = axis_names
sharding.update(kwargs)
return self
[docs] def set_quantization(
self,
method: str | None = None,
block_size: int = 128,
**kwargs,
) -> eLargeModel:
"""Configure quantization settings.
Enables model quantization to reduce memory usage and potentially
improve inference speed at the cost of some accuracy.
Args:
method: Quantization method.
block_size: Quantization block size (default: 128).
Smaller blocks = better accuracy but more overhead.
**kwargs: Additional quantization options:
- platform: Target platform ("cpu", "cuda", "tpu")
- compute_dtype: Dtype for computation (e.g., "fp16")
- double_quant: Enable double quantization for 4bit
Returns:
Self for method chaining
Example:
>>> elm.set_quantization("nf4", block_size=64)
>>> elm.set_quantization("a8bit")
"""
quant = self._config.setdefault("quantization", {})
if method is not None:
quant["method"] = method
quant["block_size"] = block_size
quant.update(kwargs)
return self
[docs] def set_operation_configs(
self,
configs: Mapping[str, Any] | None = None,
**kwargs,
) -> eLargeModel:
"""Configure ejkernel operation overrides.
Allows overriding ejkernel's autotune behavior for specific attention
operations by providing explicit configuration objects. When a config
is provided, it's passed directly to the operation instead of using
ejkernel's autotune.
Args:
configs: Dictionary mapping operation names to config objects.
Valid operation names (must match OperationRegistry):
- "flash_attn2": Flash attention 2
- "ring": Ring attention
- "blocksparse": Block sparse attention
- "ragged_page_attention_v2": Ragged page attention v2
- "ragged_page_attention_v3": Ragged page attention v3
- "sdpa": Scaled dot product attention
- "vanilla": Vanilla attention
**kwargs: Individual operation configs as keyword arguments.
These are merged with the configs dict.
Returns:
Self for method chaining
Example:
>>> from easydel import FlashAttentionConfig, RingAttentionConfig
>>>
>>> # Using dict
>>> elm.set_operation_configs({
... "flash_attn2": FlashAttentionConfig(platform="triton"),
... "ring": RingAttentionConfig(),
... })
>>>
>>> # Using kwargs
>>> elm.set_operation_configs(
... flash_attn2=FlashAttentionConfig(platform="pallas"),
... )
"""
base_cfg = self._config.setdefault("base_config", {})
op_configs = base_cfg.setdefault("operation_configs", {})
if configs is not None:
op_configs.update(configs)
op_configs.update(kwargs)
return self
[docs] def set_esurge(
self,
max_model_len: int | None = None,
max_num_seqs: int = 16,
hbm_utilization: float = 0.85,
**kwargs,
) -> eLargeModel:
"""Configure eSurge inference settings.
eSurge is a high-performance batch inference engine optimized for
throughput. It uses PagedAttention for efficient memory management.
Args:
max_model_len: Maximum model sequence length (input + output tokens).
If None, uses model's default max position embeddings.
max_num_seqs: Maximum number of sequences to process concurrently.
Higher values increase throughput but require more memory.
hbm_utilization: HBM memory utilization ratio (0.0-1.0).
Controls how much device memory to use for KV cache.
**kwargs: Additional eSurge options:
- page_size: PagedAttention page size (default: 128)
- enable_prefix_caching: Enable prefix caching optimization
- kv_cache_dtype: Dtype for KV cache (None = auto)
- decoding_engine: "ring" or "triton" (default: auto)
Returns:
Self for method chaining
Example:
>>> elm.set_esurge(
... max_model_len=8192,
... max_num_seqs=64,
... hbm_utilization=0.9,
... enable_prefix_caching=True
... )
"""
esurge = self._config.setdefault("esurge", {})
if max_model_len is not None:
esurge["max_model_len"] = max_model_len
esurge["max_num_seqs"] = max_num_seqs
esurge["hbm_utilization"] = hbm_utilization
esurge.update(kwargs)
return self
[docs] def set_mixture(
self,
informs: list[dict] | None = None,
batch_size: int = 32,
streaming: bool = True,
use_fast_loader: bool = True,
**kwargs,
) -> eLargeModel:
"""Configure data mixture settings for training/evaluation.
Sets up a mixture of datasets that can be combined and sampled from
during training. Supports multiple data sources and formats.
Args:
informs: List of dataset configurations. Each dict should contain:
- type: Dataset type ("json", "parquet", "csv", "text", or HF dataset ID)
- data_files: Path or pattern to data files
- content_field: Field name containing the text content
- split: Dataset split to use (default: "train")
- weight: Sampling weight for this dataset (optional)
batch_size: Batch size for data loading (default: 32)
streaming: Use streaming mode for large datasets (default: True).
Reduces memory usage but may be slower.
use_fast_loader: Enable fast loading with fsspec (default: True).
Provides optimized loading for remote/cloud storage.
**kwargs: Additional mixture options:
- max_samples: Maximum samples per dataset
- shuffle: Whether to shuffle data
- seed: Random seed for shuffling
Returns:
Self for method chaining
Example:
>>> elm.set_mixture(
... informs=[
... {"type": "json", "data_files": "train.json", "content_field": "text", "weight": 0.7},
... {"type": "parquet", "data_files": "valid/*.parquet", "content_field": "content", "weight": 0.3}
... ],
... batch_size=32,
... streaming=True,
... shuffle=True,
... seed=42
... )
"""
mixture = self._config.setdefault("mixture", {})
if informs is not None:
mixture["informs"] = informs
mixture["batch_size"] = batch_size
mixture["streaming"] = streaming
mixture["use_fast_loader"] = use_fast_loader
mixture.update(kwargs)
return self
[docs] def add_dataset(
self,
data_files: str | list[str],
dataset_type: str | None = None,
content_field: str = "content",
split: str = "train",
**kwargs,
) -> eLargeModel:
"""Add a dataset to the mixture configuration.
Appends a new dataset to the existing mixture. Multiple datasets
can be added and will be combined during training.
Args:
data_files: Path(s) to data files. Can be:
- Single file: "data.json"
- Multiple files: ["data1.json", "data2.json"]
- Glob pattern: "data/*.parquet"
- Remote URL: "https://example.com/data.json"
dataset_type: Dataset type or format. Options:
- File formats: "json", "jsonl", "parquet", "csv", "text"
- HuggingFace dataset ID: "imdb", "squad", etc.
- None: Auto-detect from file extension
content_field: Field name containing the text content (default: "content").
For chat data, might be "messages" or "conversations".
split: Dataset split to use (default: "train").
Common values: "train", "validation", "test".
**kwargs: Additional dataset options:
- weight: Sampling weight for this dataset
- max_samples: Maximum samples to use
- filter_fn: Function to filter samples
- map_fn: Function to transform samples
Returns:
Self for method chaining
Example:
>>> # Add a JSON dataset
>>> elm.add_dataset("train.json", dataset_type="json", content_field="text")
>>>
>>> # Add a HuggingFace dataset
>>> elm.add_dataset("imdb", dataset_type="imdb", split="train")
>>>
>>> # Add multiple Parquet files with sampling weight
>>> elm.add_dataset(
... "data/*.parquet",
... dataset_type="parquet",
... content_field="content",
... weight=0.5
... )
"""
mixture = self._config.setdefault("mixture", {})
informs = mixture.setdefault("informs", [])
inform = {
"data_files": data_files,
"content_field": content_field,
"split": split,
**kwargs,
}
if dataset_type:
inform["type"] = dataset_type
informs.append(inform)
return self
[docs] def set_eval(
self,
max_new_tokens: int = 2048,
temperature: float = 0.0,
top_p: float = 0.95,
batch_size: int | None = None,
use_tqdm: bool = True,
**kwargs,
) -> eLargeModel:
"""Configure evaluation settings for lm-evaluation-harness.
Sets default parameters for model evaluation on standard benchmarks.
These settings apply when using the eval() method.
Args:
max_new_tokens: Maximum tokens to generate per evaluation sample
(default: 2048). Lower values speed up evaluation.
temperature: Sampling temperature (default: 0.0 for greedy decoding).
0.0 = deterministic/greedy, higher = more random.
top_p: Top-p (nucleus) sampling parameter (default: 0.95).
Only used when temperature > 0.
batch_size: Evaluation batch size (default: engine-specific).
Higher values increase throughput but use more memory.
use_tqdm: Show progress bar during evaluation (default: True)
**kwargs: Additional evaluation options:
- top_k: Top-k sampling parameter
- repetition_penalty: Penalty for repeated tokens
- num_beams: Beam search width (1 = greedy)
- do_sample: Whether to use sampling
- early_stopping: Stop generation at first EOS
Returns:
Self for method chaining
Example:
>>> # Configure for deterministic evaluation
>>> elm.set_eval(
... max_new_tokens=512,
... temperature=0.0,
... batch_size=64
... )
>>>
>>> # Configure for sampling-based evaluation
>>> elm.set_eval(
... temperature=0.7,
... top_p=0.9,
... top_k=50
... )
"""
eval_cfg = self._config.setdefault("eval", {})
eval_cfg["max_new_tokens"] = max_new_tokens
eval_cfg["temperature"] = temperature
eval_cfg["top_p"] = top_p
if batch_size is not None:
eval_cfg["batch_size"] = batch_size
eval_cfg["use_tqdm"] = use_tqdm
eval_cfg.update(kwargs)
return self
def validate(self) -> None:
    """Check the stored configuration with the module-level ``validate`` helper.

    Any exception raised by that helper propagates to the caller.

    Raises:
        ValueError: If the configuration is invalid (for example a missing
            model name or an invalid dtype).
    """
    current_config = self._config
    # Inside a method body the bare name resolves to the imported module-level
    # function, not to this method.
    validate(current_config)
def to_json(self, json_path: str | os.PathLike | ePathLike) -> None:
    """Write the configuration to a JSON file via ``save_elm_config``.

    The resulting file can be loaded again with ``from_json()``.

    Args:
        json_path: Destination of the JSON configuration file.

    Example:
        >>> elm.to_json("config.json")
        >>> elm2 = eLargeModel.from_json("config.json")
    """
    current_config = self._config
    save_elm_config(current_config, json_path)
[docs] def to_dict(self) -> dict[str, Any]:
"""Get configuration as dictionary.
Returns a copy of the full configuration dictionary that can be
modified without affecting the eLargeModel instance.
Returns:
Configuration dictionary with all settings
Example:
>>> config_dict = elm.to_dict()
>>> print(config_dict["model"]["name_or_path"])
>>> # Modify the dict without affecting elm
>>> config_dict["model"]["dtype"] = "fp16"
"""
return dict(self._config)
def get_from_pretrained_kwargs(self) -> dict[str, Any]:
    """Translate the configuration into ``from_pretrained()`` keyword arguments.

    The conversion itself is done by ``to_from_pretrained_kwargs``.

    Returns:
        Dictionary that can be unpacked into a model's ``from_pretrained``.

    Example:
        >>> kwargs = elm.get_from_pretrained_kwargs()
        >>> model = LlamaForCausalLM.from_pretrained(
        ...     "meta-llama/Llama-2-7b",
        ...     **kwargs,
        ... )
    """
    loader_kwargs = to_from_pretrained_kwargs(self._config)
    return loader_kwargs
def get_esurge_kwargs(self) -> dict[str, Any]:
    """Translate the configuration into eSurge constructor keyword arguments.

    The conversion itself is done by ``to_esurge_kwargs``.

    Returns:
        Dictionary that can be unpacked into ``eSurge(model, **kwargs)``.

    Example:
        >>> kwargs = elm.get_esurge_kwargs()
        >>> from easydel.inference import eSurge
        >>> engine = eSurge(model, **kwargs)
    """
    engine_kwargs = to_esurge_kwargs(self._config)
    return engine_kwargs
def get_base_config(self, prefer: str = "base") -> dict[str, Any]:
    """Return the materialized base configuration.

    Delegates to ``materialize_base_config`` with the stored configuration
    and the given preference.

    Args:
        prefer: Which side wins when values conflict: "base" for the base
            configuration (default) or "section" for the specific sections.

    Returns:
        Base configuration dictionary with resolved values.

    Example:
        >>> base_config = elm.get_base_config(prefer="base")
    """
    materialized = materialize_base_config(self._config, prefer)
    return materialized
[docs] def build_model(self, force_rebuild: bool = False) -> EasyDeLBaseModule:
"""Build the EasyDeL model from configuration.
Loads the model using the configured settings including dtype,
sharding, and quantization. The model is cached after first build
unless force_rebuild is True.
Args:
force_rebuild: Force rebuilding even if model is already cached.
Useful when configuration has changed.
Returns:
EasyDeLBaseModule instance ready for training or inference
Raises:
ValueError: If model name/path is not set
RuntimeError: If model loading fails
Example:
>>> elm = eLargeModel.from_pretrained("meta-llama/Llama-2-7b")
>>> elm.set_dtype("bf16")
>>> model = elm.build_model()
"""
if self._model is None or force_rebuild:
if not self.model_name:
raise ValueError("Model name/path must be set before building")
self._model = build_model(self._config)
return self._model
[docs] def build_tokenizer(self, force_rebuild: bool = False) -> AutoTokenizer:
"""Build or get the tokenizer for the model.
Loads the tokenizer from the model path or a separately specified
tokenizer path. The tokenizer is cached after first build.
Args:
force_rebuild: Force rebuilding even if tokenizer is already cached.
Useful when switching between different tokenizers.
Returns:
AutoTokenizer instance configured for the model
Raises:
ValueError: If tokenizer path cannot be determined
Example:
>>> tokenizer = elm.build_tokenizer()
>>> tokens = tokenizer("Hello world", return_tensors="np")
"""
if self._tokenizer is None or force_rebuild:
tok_path = self._config["model"].get("tokenizer", self.model_name)
if not tok_path:
raise ValueError("Tokenizer path must be set")
self._tokenizer = AutoTokenizer.from_pretrained(tok_path)
return self._tokenizer
def build_esurge(self) -> "eSurge":
    """Create an eSurge inference engine for the configured model.

    The model is built first (or taken from the cache) and then handed,
    together with the configuration, to the module-level ``build_esurge``.

    Returns:
        The engine returned by the builder.

    Example:
        >>> elm.set_esurge(max_num_seqs=32, hbm_utilization=0.9)
        >>> engine = elm.build_esurge()
        >>> results = engine.generate(prompts, max_tokens=100)
    """
    # Called for its side effect: it populates ``self._model`` when empty.
    self.build_model()
    engine = build_esurge(self._config, self._model)
    return engine
[docs] def build_teacher_model(self) -> EasyDeLBaseModule | None:
"""Build the teacher model for distillation training.
Loads the teacher model using the same loader configuration as the
student model (dtype, sharding, etc.) but with the teacher model path.
Returns:
EasyDeLBaseModule instance for the teacher model, or None if no
teacher model is configured
Example:
>>> elm.set_teacher_model("meta-llama/Llama-2-13b")
>>> teacher = elm.build_teacher_model()
>>> # Teacher model will be used automatically in distillation training
"""
if "teacher_model" not in self._config:
return None
teacher_config = dict(self._config)
teacher_config["model"] = self._config["teacher_model"]
return build_model(teacher_config)
[docs] def build_reference_model(self) -> EasyDeLBaseModule | None:
"""Build the reference model for preference optimization (DPO, etc.).
Loads the reference model using the same loader configuration as the
primary model. The reference model provides a baseline for computing
preference losses in DPO, ORPO, and similar methods.
Returns:
EasyDeLBaseModule instance for the reference model, or None if no
reference model is configured
Example:
>>> elm.set_reference_model("meta-llama/Llama-2-7b-hf")
>>> reference = elm.build_reference_model()
>>> # Reference model will be used automatically in DPO training
"""
if "reference_model" not in self._config:
return None
reference_config = dict(self._config)
reference_config["model"] = self._config["reference_model"]
return build_model(reference_config)
def build_dataset(self):
    """Build a dataset from the mixture configuration.

    Delegates to the module-level ``build_dataset`` with the stored
    configuration.

    Returns:
        The dataset produced by the builder, or ``None`` if no mixture is
        configured.

    Example:
        >>> elm = eLargeModel()
        >>> elm.add_dataset("train.json", dataset_type="json", content_field="text")
        >>> elm.add_dataset("valid/*.parquet", dataset_type="parquet", content_field="content")
        >>> dataset = elm.build_dataset()
    """
    # Inside the method the bare name is the imported builder, not this method.
    dataset = build_dataset(self._config)
    return dataset
def build_sharded_source(self) -> "ShardedDataSource | None":
    """Build the mixture as a ``ShardedDataSource``.

    Delegates to the module-level ``build_sharded_source`` with the stored
    configuration.

    Returns:
        The data source produced by the builder, or ``None`` if no mixture is
        configured.

    Example:
        >>> elm = eLargeModel()
        >>> elm.add_dataset("train.json", dataset_type="json", content_field="text")
        >>> elm.set_mixture(use_sharded_source=True)
        >>> source = elm.build_sharded_source()
        >>> for batch in source.open_shard(source.shard_names[0]):
        ...     process(batch)
    """
    # Inside the method the bare name is the imported builder, not this method.
    source = build_sharded_source(self._config)
    return source
def get_train_source(self) -> "ShardedDataSource | Dataset | None":
    """Return the training data in the format selected by the mixture config.

    The choice is driven by ``config["mixture"]["use_sharded_source"]``. When
    the key (or the whole ``"mixture"`` section) is absent it counts as
    ``True``.

    Returns:
        The result of ``build_sharded_source()`` when the flag is truthy,
        otherwise the result of ``build_dataset()``.

    Example:
        >>> elm = eLargeModel()
        >>> elm.add_dataset("train.json", dataset_type="json")
        >>> elm.set_mixture(use_sharded_source=True)
        >>> data = elm.get_train_source()  # ShardedDataSource
        >>> elm.set_mixture(use_sharded_source=False)
        >>> data = elm.get_train_source()  # HF Dataset
    """
    mixture_section = self._config.get("mixture", {})
    wants_sharded = mixture_section.get("use_sharded_source", True)
    builder = self.build_sharded_source if wants_sharded else self.build_dataset
    return builder()
def get_data_mixture_kwargs(self) -> dict[str, Any]:
    """Translate the configuration into ``DatasetMixture`` keyword arguments.

    The conversion itself is done by ``to_data_mixture_kwargs``.

    Returns:
        Dictionary of ``DatasetMixture`` arguments.
    """
    mixture_kwargs = to_data_mixture_kwargs(self._config)
    return mixture_kwargs
[docs] def clear_cache(self) -> None:
"""Clear cached model, tokenizer, and inference engine instances.
This is useful when you want to reload models with different
configurations or free memory after model operations.
"""
self._model = None
self._tokenizer = None
def set_trainer(self, trainer_type: str, **kwargs) -> eLargeModel:
    """Select the training paradigm and record its hyperparameters.

    The values are merged into the ``trainer`` section of the configuration
    (created on first use); settings stored by earlier calls are kept unless
    overwritten here.

    Args:
        trainer_type: Name of the trainer to use:
            - "sft": Supervised Fine-Tuning
            - "dpo": Direct Preference Optimization
            - "orpo": Odds Ratio Preference Optimization
            - "grpo": Group Relative Policy Optimization
            - "reward": Reward model training
            - "distillation": Knowledge distillation
            - "base": Basic trainer for custom training loops
        **kwargs: Trainer options stored verbatim, for example:
            Common:
                - learning_rate: Learning rate
                - num_train_epochs: Number of training epochs
                - per_device_train_batch_size: Batch size per device
                - gradient_accumulation_steps: Gradient accumulation steps
                - warmup_steps: Number of warmup steps
                - output_dir: Directory to save checkpoints
            DPO:
                - beta: KL regularization coefficient
                - loss_type: "sigmoid", "ipo", "hinge"
            Distillation:
                - temperature: Distillation temperature
                - alpha: Weight for distillation loss

    Returns:
        This instance, so calls can be chained.

    Example:
        >>> # SFT training
        >>> elm.set_trainer(
        ...     "sft",
        ...     learning_rate=2e-5,
        ...     num_train_epochs=3,
        ...     per_device_train_batch_size=4,
        ... )
        >>>
        >>> # DPO training
        >>> elm.set_trainer("dpo", beta=0.1, learning_rate=1e-6)
    """
    section = self._config.setdefault("trainer", {})
    section.update({"trainer_type": trainer_type, **kwargs})
    return self
def get_trainer_config(self) -> dict[str, Any]:
    """Return the trainer configuration after normalization.

    The raw ``trainer`` section (an empty dict when none has been set) is
    handed to ``normalize_trainer_config``, which is responsible for any
    defaults and normalization.

    Returns:
        The dictionary produced by ``normalize_trainer_config``.
    """
    return normalize_trainer_config(self._config.get("trainer", {}))
def train(
    self,
    train_dataset: Dataset | ShardedDataSource | None = None,
    eval_dataset: Dataset | ShardedDataSource | None = None,
    base_state_class: type[EasyDeLState] | None = None,
    args_class: type[TrainingArguments] | None = None,
    trainer_class: type[Trainer] | None = None,
    **build_kwargs: Unpack[BuildTrainerKws],
):
    """Run a full training pass with the configured settings.

    High-level convenience entry point. It validates the configuration,
    delegates all construction work (model, tokenizer, datasets, auxiliary
    models, trainer) to ``build_trainer``, runs the trainer, and finally
    stores the trainer's model back on this instance.

    Args:
        train_dataset: Optional training data (Dataset or ShardedDataSource).
            Forwarded to ``build_trainer``, which builds it from the mixture
            configuration when it is None.
        eval_dataset: Optional evaluation data used during training.
        base_state_class: Optional custom EasyDeLState class used for model
            state management.
        args_class: Optional custom TrainingArguments class. When None it is
            selected from the trainer type.
        trainer_class: Optional custom Trainer class. When None it is
            selected from the trainer type.
        **build_kwargs: Extra keyword arguments forwarded to
            ``build_trainer``, such as:
            - data_collator: Custom data collator function
            - formatting_func: Function to format examples (SFT)
            - reward_processing_classes: Processing classes for rewards (GRPO)
            - data_tokenize_fn: Custom tokenization function
            - reference_model: Override reference model
            - reward_model: Override reward model
            - teacher_model: Override teacher model
            - reward_funcs: Custom reward functions

    Returns:
        Whatever the trainer's ``train()`` method returns.

    Example:
        Basic SFT training:
        >>> elm = eLargeModel.from_pretrained("meta-llama/Llama-2-7b")
        >>> elm.add_dataset("train.json", dataset_type="json")
        >>> elm.set_trainer("sft", learning_rate=2e-5, num_train_epochs=3)
        >>> results = elm.train()

        DPO training with custom datasets:
        >>> train_data = load_dataset("preference_data", split="train")
        >>> eval_data = load_dataset("preference_data", split="test")
        >>> elm.set_trainer("dpo", beta=0.1)
        >>> elm.set_reference_model("meta-llama/Llama-2-7b")
        >>> results = elm.train(train_dataset=train_data, eval_dataset=eval_data)

        Custom formatting function:
        >>> def format_fn(examples):
        ...     return [f"Question: {q}\\nAnswer: {a}"
        ...             for q, a in zip(examples["question"], examples["answer"])]
        >>> results = elm.train(formatting_func=format_fn)
    """
    self.validate()

    trainer = self.build_trainer(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        base_state_class=base_state_class,
        args_class=args_class,
        trainer_class=trainer_class,
        **build_kwargs,
    )

    trainer_settings = pprint.pformat(self._config.get("trainer", {}))
    banner = (
        "Starting training with configuration:",
        f" Model: {self.model_name}",
        f" Trainer Arguments:\n{trainer_settings}",
        "Beginning training...",
    )
    for message in banner:
        logger.info(message)

    results = trainer.train()
    logger.info("Training completed successfully")

    # Keep the trained model so later calls reuse it.
    self._model = trainer.model
    return results
def build_training_arguments(self, args_class: type[TrainingArguments] | None = None, **overrides):
    """Build the TrainingArguments object for the configured trainer.

    The normalized trainer configuration is merged with ``overrides`` and
    passed to ``args_class`` as keyword arguments. ``trainer_type`` is never
    forwarded. Keys the constructor does not declare are silently dropped,
    unless the constructor takes ``**kwargs``, in which case everything is
    passed through.

    Args:
        args_class: Optional TrainingArguments class (not an instance). When
            None, the class is looked up from ``trainer_type`` via
            ``get_training_arguments_class`` ("sft" if no type is set).
        **overrides: Values that take precedence over the stored trainer
            configuration for this call only.

    Returns:
        An instance of ``args_class`` built from the filtered configuration.

    Raises:
        TypeError: Propagated unchanged when the constructor itself rejects
            the (already filtered) arguments.
    """
    import inspect

    # Merge into a fresh dict so the mapping returned by get_trainer_config()
    # is never mutated by per-call overrides.
    trainer_cfg = {**self.get_trainer_config(), **overrides}

    if args_class is None:
        args_class = get_training_arguments_class(trainer_cfg.get("trainer_type", "sft"))

    config_for_args = {k: v for k, v in trainer_cfg.items() if k != "trainer_type"}

    # Decide up front which keys the constructor accepts instead of calling
    # it, catching TypeError and retrying: that approach constructs the
    # object twice and hides TypeErrors raised for unrelated reasons.
    try:
        parameters = inspect.signature(args_class.__init__).parameters
    except (TypeError, ValueError):
        # No introspectable signature: pass everything and let the
        # constructor decide.
        return args_class(**config_for_args)

    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
        return args_class(**config_for_args)

    accepted = set(parameters) - {"self"}
    return args_class(**{k: v for k, v in config_for_args.items() if k in accepted})
[docs] def build_trainer(
self,
train_dataset: Dataset | ShardedDataSource | None = None,
eval_dataset: Dataset | ShardedDataSource | None = None,
reference_model: EasyDeLBaseModule | None = None,
reward_model: EasyDeLBaseModule | None = None,
teacher_model: EasyDeLBaseModule | None = None,
reward_funcs: Any | None = None,
base_state_class: type[EasyDeLState] | None = None,
args_class: type[TrainingArguments] | None = None,
trainer_class: type[Trainer] | None = None,
**kwargs,
) -> Trainer:
"""Build a trainer instance with the configured settings.
Creates and configures a trainer based on the trainer_type setting.
Automatically builds required models and datasets if not provided.
Args:
train_dataset: Training dataset (Dataset or ShardedDataSource).
If None, builds from mixture config using get_train_source().
eval_dataset: Evaluation dataset for validation metrics.
reference_model: Reference model for DPO/ORPO. If None, builds from
reference_model configuration if present.
reward_model: Reward model for GRPO. If None, builds from config.
teacher_model: Teacher model for distillation. If None, builds from
teacher_model configuration if present.
reward_funcs: Custom reward functions for GRPO. Alternative to reward_model.
base_state_class: Custom EasyDeLState class for model state management.
args_class: Custom TrainingArguments class. Auto-selected if None.
trainer_class: Custom Trainer class. Auto-selected if None.
**kwargs: Additional trainer configuration overrides
Returns:
Configured trainer instance ready for training
Raises:
ValueError: If required models or datasets are not configured
Example:
>>> # Build trainer with auto-configuration
>>> trainer = elm.build_trainer()
>>>
>>> # Build trainer with custom dataset
>>> custom_data = load_dataset("custom_data")
>>> trainer = elm.build_trainer(train_dataset=custom_data)
>>>
>>> # Build DPO trainer with custom reference model
>>> ref_model = elm.build_reference_model()
>>> trainer = elm.build_trainer(
... trainer_type="dpo",
... reference_model=ref_model
... )
"""
from easydel.infra.base_state import EasyDeLState
trainer_cfg = self.get_trainer_config()
trainer_type = trainer_cfg.get("trainer_type", "sft")
if self._model is None:
self.build_model()
if self._tokenizer is None:
self.build_tokenizer()
if trainer_class is None:
trainer_class = get_trainer_class(trainer_type)
training_args = self.build_training_arguments(args_class=args_class, **kwargs)
if train_dataset is None and "mixture" in self._config:
# Use new get_train_source() which auto-selects based on use_sharded_source
train_dataset = self.get_train_source()
trainer_kwargs = {}
model = self._model
if base_state_class is not None:
model = model.to_state(base_state_class)
if trainer_type == "base":
trainer_kwargs["arguments"] = training_args
if isinstance(model, EasyDeLState):
trainer_kwargs["model_state"] = model
else:
trainer_kwargs["model"] = model
trainer_kwargs["dataset_train"] = train_dataset
trainer_kwargs["dataset_eval"] = eval_dataset
trainer_kwargs["data_collator"] = kwargs.get("data_collator", None)
elif trainer_type == "dpo":
if reference_model is None:
reference_model = self.build_reference_model()
if reference_model is not None and base_state_class is not None:
reference_model = reference_model.to_state(base_state_class)
trainer_kwargs["arguments"] = training_args
trainer_kwargs["model"] = model
trainer_kwargs["reference_model"] = reference_model
trainer_kwargs["processing_class"] = self._tokenizer
trainer_kwargs["train_dataset"] = train_dataset
trainer_kwargs["eval_dataset"] = eval_dataset
trainer_kwargs["data_collator"] = kwargs.get("data_collator", None)
elif trainer_type == "orpo":
trainer_kwargs["arguments"] = training_args
trainer_kwargs["model"] = model
trainer_kwargs["processing_class"] = self._tokenizer
trainer_kwargs["train_dataset"] = train_dataset
trainer_kwargs["eval_dataset"] = eval_dataset
trainer_kwargs["data_collator"] = kwargs.get("data_collator", None)
elif trainer_type == "grpo":
trainer_kwargs["arguments"] = training_args
trainer_kwargs["model"] = model
trainer_kwargs["reward_funcs"] = reward_funcs if reward_funcs is not None else reward_model
trainer_kwargs["train_dataset"] = train_dataset
trainer_kwargs["eval_dataset"] = eval_dataset
trainer_kwargs["processing_class"] = self._tokenizer
trainer_kwargs["reward_processing_classes"] = kwargs.get("reward_processing_classes", None)
trainer_kwargs["data_tokenize_fn"] = kwargs.get("data_tokenize_fn", None)
elif trainer_type == "sft":
trainer_kwargs["arguments"] = training_args
trainer_kwargs["processing_class"] = self._tokenizer
trainer_kwargs["model"] = model
trainer_kwargs["train_dataset"] = train_dataset
trainer_kwargs["eval_dataset"] = eval_dataset
trainer_kwargs["formatting_func"] = kwargs.get("formatting_func", None)
trainer_kwargs["data_collator"] = kwargs.get("data_collator", None)
elif trainer_type == "reward":
trainer_kwargs["arguments"] = training_args
trainer_kwargs["model"] = model
trainer_kwargs["processing_class"] = self._tokenizer
trainer_kwargs["train_dataset"] = train_dataset
trainer_kwargs["eval_dataset"] = eval_dataset
trainer_kwargs["data_collator"] = kwargs.get("data_collator", None)
elif trainer_type == "distillation":
if teacher_model is None:
teacher_model = self.build_teacher_model()
if teacher_model is not None and base_state_class is not None:
teacher_model = teacher_model.to_state(base_state_class)
trainer_kwargs["arguments"] = training_args
trainer_kwargs["student_model"] = model
trainer_kwargs["teacher_model"] = teacher_model
trainer_kwargs["processing_class"] = self._tokenizer
trainer_kwargs["train_dataset"] = train_dataset
trainer_kwargs["eval_dataset"] = eval_dataset
trainer_kwargs["data_collator"] = kwargs.get("data_collator", None)
else:
trainer_kwargs["arguments"] = training_args
trainer_kwargs["model"] = model
trainer_kwargs["processing_class"] = self._tokenizer
trainer_kwargs["train_dataset"] = train_dataset
trainer_kwargs["eval_dataset"] = eval_dataset
trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}
for key, value in kwargs.items():
if key not in trainer_kwargs and value is not None:
if key not in ["data_collator", "formatting_func", "reward_processing_classes", "data_tokenize_fn"]:
trainer_kwargs[key] = value
return trainer_class(**trainer_kwargs)
def eval(
    self,
    tasks: str | list[str],
    engine: typing.Literal["esurge", "auto"] | Any = "auto",
    num_fewshot: int = 0,
    output_path: str | None = None,
) -> dict[str, Any]:
    """Run evaluation on the given tasks using lm-evaluation-harness.

    An ``eSurgeLMEvalAdapter`` is wrapped around an eSurge engine (built
    from the configuration, or supplied by the caller) and handed to
    ``lm_eval.evaluator.simple_evaluate``. The adapter is stopped when the
    evaluation finishes or fails.

    Args:
        tasks: Task name or list of task names, e.g.:
            - Language understanding: "hellaswag", "winogrande", "piqa",
              "arc_easy", "arc_challenge"
            - Math: "gsm8k", "math", "minerva_math"
            - Knowledge: "mmlu", "triviaqa", "naturalquestions"
            - Reasoning: "bbh", "boolq", "copa"
            - Truthfulness: "truthfulqa_mc1", "truthfulqa_mc2"
            - Coding: "humaneval", "mbpp"
            Full list: https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks
        engine: Inference engine to use:
            - "esurge": build an eSurge engine via ``build_esurge()``
            - "auto": currently the same as "esurge" (default)
            - an existing engine object whose type name contains "eSurge"
        num_fewshot: Number of few-shot examples (default 0, zero-shot).
            Typical values: MMLU 5-shot, GSM8K 8-shot, HellaSwag 0-shot.
        output_path: Optional path where the results are written as JSON.
            Values the JSON encoder cannot handle natively are converted
            (``tolist()`` when available, sets to lists, otherwise
            ``str``), so a finished evaluation is not lost at the
            serialization step.

    Returns:
        The results object returned by ``simple_evaluate``, typically:
        {
            "results": {task_name: {metric_name: value, ...}, ...},
            "versions": {task_name: version_string, ...},
            "config": {"model": ..., "num_fewshot": ..., ...}
        }

    Raises:
        ImportError: If lm-eval is not installed (pip install lm-eval).
        ValueError: If ``engine`` is a string other than "auto"/"esurge",
            or an object whose type name does not contain "eSurge".

    Example:
        Basic zero-shot evaluation:
        >>> elm = eLargeModel.from_pretrained("meta-llama/Llama-2-7b")
        >>> results = elm.eval("hellaswag")
        >>> print(f"HellaSwag accuracy: {results['results']['hellaswag']['acc']:.2%}")

        Few-shot evaluation with multiple tasks:
        >>> elm.set_esurge(max_num_seqs=64, hbm_utilization=0.9)
        >>> results = elm.eval(
        ...     ["gsm8k", "mmlu", "truthfulqa_mc1"],
        ...     engine="esurge",
        ...     num_fewshot=5,
        ...     output_path="eval_results.json",
        ... )
        >>> for task, metrics in results["results"].items():
        ...     print(f"{task}: {metrics.get('acc', metrics.get('exact_match')):.2%}")

        Evaluation with custom settings:
        >>> elm.set_eval(max_new_tokens=512, temperature=0.0, batch_size=32)
        >>> results = elm.eval(["humaneval", "mbpp"])

    Note:
        Generation parameters come from the ``eval`` section of the
        configuration. The defaults used here are max_new_tokens=2048,
        temperature=0.0 and top_p=0.95. Any other keys in that section are
        forwarded to ``simple_evaluate``.
    """
    try:
        from lm_eval import evaluator  # type:ignore
    except ImportError as e:
        raise ImportError("lm-eval is required for evaluation. Install with: pip install lm-eval") from e

    def _json_fallback(obj: Any) -> Any:
        """Convert a value the JSON encoder cannot handle natively."""
        if hasattr(obj, "tolist"):  # array-like scalars and arrays
            return obj.tolist()
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        return str(obj)

    if isinstance(tasks, str):
        tasks = [tasks]
    if self._tokenizer is None:
        self.build_tokenizer()

    # Work on a copy: the generation settings are popped, the rest is
    # forwarded to simple_evaluate().
    eval_config = self._config.get("eval", {}).copy()
    batch_size = eval_config.pop("batch_size", None)
    max_new_tokens = eval_config.pop("max_new_tokens", 2048)
    temperature = eval_config.pop("temperature", 0.0)
    top_p = eval_config.pop("top_p", 0.95)

    # Validate the engine selection before importing or building anything.
    build_engine = isinstance(engine, str)
    if build_engine:
        engine_label = "esurge" if engine == "auto" else engine
        if engine_label != "esurge":
            raise ValueError(f"Unknown engine type: {engine}")
    else:
        engine_label = type(engine).__name__
        if "eSurge" not in engine_label:
            raise ValueError(f"Unknown engine instance type: {engine_label}")

    from easydel.inference.evaluations import eSurgeLMEvalAdapter

    if build_engine:
        surge = self.build_esurge()
        esurge_cfg = self._config.get("esurge", {})
        default_batch_size = esurge_cfg.get("max_num_seqs", 32)
        max_length = esurge_cfg.get("max_model_len", 8192)
    else:
        surge = engine
        default_batch_size = getattr(engine, "max_num_seqs", 32)
        max_length = getattr(engine, "max_model_len", 8192)

    if batch_size is None:
        batch_size = default_batch_size

    eval_adapter = eSurgeLMEvalAdapter(
        surge=surge,
        processor=self._tokenizer,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        batch_size=batch_size,
        temperature=temperature,
        top_p=top_p,
    )

    try:
        logger.info(f"Starting evaluation on tasks: {tasks}")
        logger.info(f"Using {engine_label} engine")
        logger.info(f"Batch size: {batch_size}, Few-shot: {num_fewshot}")
        results = evaluator.simple_evaluate(
            model=eval_adapter,
            tasks=tasks,
            num_fewshot=num_fewshot,
            batch_size=batch_size,
            device="cpu",
            **eval_config,
        )
        if output_path:
            ePath(output_path).write_text(json.dumps(results, indent=2, default=_json_fallback))
            logger.info(f"eval results saved to: {output_path}")
        logger.info("evaluation summary:")
        for task, metrics in results.get("results", {}).items():
            logger.info(f"{task}:")
            for metric, value in metrics.items():
                if isinstance(value, int | float):
                    logger.info(f" {metric}: {value:.4f}" if isinstance(value, float) else f" {metric}: {value}")
        return results
    finally:
        # Always release the adapter, even when the evaluation raised.
        if hasattr(eval_adapter, "stop"):
            eval_adapter.stop()
def __repr__(self) -> str:
    """Developer-friendly one-line representation of eLargeModel.

    Shows the model name, task and loader dtype. The quantization method is
    appended when one is configured. Sharding dims are appended when they
    are set and differ from ``(1, 1, 1, -1, 1)`` (presumably the default
    layout); list and tuple forms of the dims are treated alike.

    Returns:
        A string of the form
        ``eLargeModel(model=..., task=..., dtype=...[, quantization=...][, sharding=...])``.
    """
    task_str = f"TaskType.{self.task.name}" if self.task else "None"
    dtype_str = repr(self._config.get("loader", {}).get("dtype", "default"))

    quant = self._config.get("quantization", {})
    quant_str = f", quantization={quant['method']!r}" if quant.get("method") else ""

    axis_dims = self._config.get("sharding", {}).get("axis_dims")
    # A list never compares equal to a tuple, so normalise before comparing;
    # otherwise dims stored as a list (e.g. after a JSON round trip) would
    # always be printed, even for the layout that is meant to be hidden.
    is_hidden_layout = isinstance(axis_dims, (list, tuple)) and tuple(axis_dims) == (1, 1, 1, -1, 1)
    shard_str = f", sharding={axis_dims}" if axis_dims and not is_hidden_layout else ""

    return f"eLargeModel(model={self.model_name!r}, task={task_str}, dtype={dtype_str}{quant_str}{shard_str})"
def __str__(self) -> str:
    """Human-readable, boxed multi-line summary of the configuration.

    Sections are emitted only when the corresponding configuration is
    present: loader/task, sharding, selected base-config values,
    quantization, eSurge settings, trainer settings and finally the load
    status of model and tokenizer.

    Values that reach a format spec are coerced first, so non-string dtypes
    or model names, ``None`` eSurge limits and non-numeric learning rates
    are rendered instead of raising.

    Returns:
        The formatted summary, lines joined with newlines.
    """
    w = 53  # inner width (content area)

    def _fmt(val: Any) -> str:
        """Format a value for display, handling enums and special types."""
        if val is None:
            return "none"
        if hasattr(val, "value"):
            return str(val.value) if val.value else "none"
        if hasattr(val, "name"):
            return val.name.lower()
        return str(val)

    def _line(content: str) -> str:
        """Create a padded line within the box."""
        return f"║ {content:<{w}}║"

    def _sep() -> str:
        """Create a separator line."""
        return f"╟{'─' * (w + 1)}╢"

    # str() because alignment specs are rejected by objects that do not
    # implement __format__ (e.g. path-like model names).
    title = str(self.model_name or "not set")
    lines = [
        f"╔{'═' * (w + 1)}╗",
        f"║{'eLargeModel':^{w + 1}}║",
        f"║{title:^{w + 1}}║",
        f"╠{'═' * (w + 1)}╣",
    ]

    # Loader & Task
    loader = self._config.get("loader", {})
    dtype_str = str(loader.get("dtype", "default"))  # may be a dtype object
    prec_str = loader.get("precision", "default")
    task_str = self.task.name.lower() if self.task else "auto"
    lines.append(_line(f"▸ dtype: {dtype_str:<12} ▸ task: {task_str}"))
    lines.append(_line(f"▸ precision: {prec_str}"))

    # Sharding
    sharding = self._config.get("sharding", {})
    if sharding.get("axis_dims"):
        dims = sharding["axis_dims"]
        auto = "auto" if sharding.get("auto_shard_model") else "manual"
        dims_str = ",".join(str(d) for d in dims)
        lines.append(_line(f"▸ shard: ({dims_str}) {auto}"))
    lines.append(_sep())

    # Config section
    base_cfg = self._config.get("base_config", {}).get("values", {})
    if base_cfg:
        if "attn_mechanism" in base_cfg:
            lines.append(_line(f"▸ attn: {_fmt(base_cfg['attn_mechanism'])}"))
        if "moe_method" in base_cfg:
            lines.append(_line(f"▸ moe: {_fmt(base_cfg['moe_method'])}"))
        if "gradient_checkpointing" in base_cfg:
            gc_str = _fmt(base_cfg["gradient_checkpointing"]) or "disabled"
            lines.append(_line(f"▸ grad_ckpt: {gc_str}"))

    # Quantization
    quant = self._config.get("quantization", {})
    if quant.get("method"):
        lines.append(_line(f"▸ quant: {quant['method']} (block:{quant.get('block_size', 128)})"))

    # eSurge
    esurge = self._config.get("esurge", {})
    if esurge and any(esurge.get(k) for k in ["max_model_len", "max_num_seqs", "hbm_utilization"]):
        lines.append(_sep())
        lines.append(_line("eSurge"))
        # "or 0" also covers keys that are present but None, which a plain
        # .get(key, 0) would hand to the numeric format specs below.
        ctx = esurge.get("max_model_len") or 0
        seqs = esurge.get("max_num_seqs") or 0
        hbm = esurge.get("hbm_utilization") or 0
        lines.append(_line(f"▸ {ctx:,} context x {seqs} sequences"))
        parts = []
        if hbm:
            parts.append(f"{hbm:.0%} HBM")
        if esurge.get("page_size"):
            parts.append(f"page:{esurge['page_size']}")
        if esurge.get("enable_prefix_caching"):
            parts.append("prefix_cache")
        if esurge.get("min_input_pad"):
            parts.append(f"pad:{esurge['min_input_pad']}")
        if parts:
            lines.append(_line(f"▸ {' │ '.join(parts)}"))

    # Training
    trainer = self._config.get("trainer", {})
    if trainer.get("trainer_type"):
        lines.append(_sep())
        lines.append(_line(f"Training: {trainer['trainer_type'].upper()}"))
        parts = []
        learning_rate = trainer.get("learning_rate")
        if learning_rate:
            try:
                lr_text = f"{learning_rate:.2e}"
            except (TypeError, ValueError):
                # e.g. a string from a config file; show it as-is
                lr_text = str(learning_rate)
            parts.append(f"lr:{lr_text}")
        if trainer.get("num_train_epochs"):
            parts.append(f"epochs:{trainer['num_train_epochs']}")
        if trainer.get("total_batch_size"):
            parts.append(f"batch:{trainer['total_batch_size']}")
        if parts:
            lines.append(_line(f"▸ {' │ '.join(parts)}"))

    # Status
    lines.append(_sep())
    model_loaded = self._model is not None
    tok_loaded = self._tokenizer is not None
    model_icon = "●" if model_loaded else "○"
    tok_icon = "●" if tok_loaded else "○"
    model_status = "loaded" if model_loaded else "not loaded"
    tok_status = "loaded" if tok_loaded else "not loaded"
    lines.append(_line(f"{model_icon} model: {model_status} {tok_icon} tokenizer: {tok_status}"))
    lines.append(f"╚{'═' * (w + 1)}╝")
    return "\n".join(lines)