# Source code for easydel.trainers.ray_scaler.distributed_trainer

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ray-based distributed trainer implementation for EasyDeL.

This module provides a distributed training implementation using Ray for scaling
language model training across multiple GPUs and nodes. It integrates Ray's
distributed computing capabilities with EasyDeL's training infrastructure to
enable efficient large-scale model training.

The module includes:
- RayDistributedTrainer: Main class for distributed training with Ray
- Integration with Ray Train for distributed data loading and gradient synchronization
- Support for both data and model parallelism strategies
- Automatic resource management and fault tolerance
- Checkpointing and recovery mechanisms for long-running training jobs

Key Components:
- Automatic distribution of training data across workers
- Gradient synchronization using Ray's collective communication
- Dynamic resource allocation and load balancing
- Integration with Ray Tune for hyperparameter optimization
- Support for heterogeneous hardware configurations

The trainer abstracts away the complexity of distributed training, allowing users
to scale from single GPU to multi-node clusters with minimal code changes.
"""

from __future__ import annotations

import copy
import json
import os
import typing as tp
from functools import cached_property

import jax
from eformer.escale import PartitionAxis
from eformer.loggings import get_logger
from eformer.mpric import DTYPE_TO_STRING_MAP, STRING_TO_DTYPE_MAP
from eformer.paths import ePath
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from pydantic import BaseModel
from transformers import AutoTokenizer, PreTrainedTokenizer

from easydel.infra import EasyDeLBaseConfig, EasyDeLBaseModule, EasyDeLState
from easydel.infra.etils import EasyDeLGradientCheckPointers
from easydel.infra.factory import TaskType
from easydel.modules.auto.auto_configuration import get_modules_by_type
from easydel.utils import Registry

from ..base_trainer import BaseTrainer
from ..trainer.trainer import Trainer
from ..training_configurations import TrainingArguments

if tp.TYPE_CHECKING:
    from datasets import Dataset

logger = get_logger("RayTrainer")


@Registry.register("trainer", "ray_dist")
class RayDistributedConfig(BaseModel):
    """
    Configuration for RayDistributedTrainer that can be persisted to JSON.

    This class handles serialization and deserialization of distributed training
    configurations, with special handling for JAX dtypes and PartitionAxis objects.

    Attributes:
        pretrained_model_name_or_path: Path or identifier for the pretrained model
        model_task: The task type for the model (e.g., CAUSAL_LM, SEQ2SEQ)
        model_type: The model architecture type (e.g., 'llama', 'gpt2')
        offload_backend: Backend device for offloading (e.g., 'cpu', 'gpu')
        config_scaling_variables: Variables to scale by scaling_index (e.g., hidden_size)
        config_variables: Fixed configuration variables (e.g., dtype, precision)

    Notes:
        - Dtype values found in ``STRING_TO_DTYPE_MAP.values()`` are converted to
          their string names for JSON, and converted back on load.
        - A ``PartitionAxis`` under ``config_variables["partition_axis"]`` is stored
          as a plain dict and rebuilt on load; a ``None`` value is kept as ``None``.
        - Call ``_saving_preprocess()`` before saving and ``_loading_postprocess()``
          after loading. Both mutate the variable dicts in place.
    """

    # NOTE(review): the decorator above registers this *config* model (not
    # RayDistributedTrainer) under ("trainer", "ray_dist"); confirm that is intended.
    # NOTE(review): ``lax.Precision`` values in ``config_variables`` are not converted
    # by the pre/post-processing below, so their JSON round trip relies on pydantic's
    # Enum serialization -- verify the loaded value is accepted by the model.

    pretrained_model_name_or_path: str
    model_task: TaskType | None = None
    model_type: str | None = None
    offload_backend: str | None = None
    config_scaling_variables: dict[str, int] | None = None
    config_variables: dict[str, tp.Any] | None = None

    def _saving_preprocess(self):
        """Convert dtypes and PartitionAxis to JSON-serializable formats before saving.

        Mutates ``config_variables`` and ``config_scaling_variables`` in place.
        """
        dtype_values = list(STRING_TO_DTYPE_MAP.values())
        for variables in (self.config_variables, self.config_scaling_variables):
            if not variables:
                continue
            for key, value in list(variables.items()):
                # Membership uses ``==`` comparisons, same as testing against dict values.
                if value in dtype_values:
                    variables[key] = DTYPE_TO_STRING_MAP[value]

        if self.config_variables and isinstance(self.config_variables.get("partition_axis"), PartitionAxis):
            # Shallow copy so the serialized dict does not alias the live object's __dict__.
            self.config_variables["partition_axis"] = dict(vars(self.config_variables["partition_axis"]))

    def _loading_postprocess(self):
        """Convert string representations back to dtypes and PartitionAxis after loading.

        Mutates ``config_variables`` and ``config_scaling_variables`` in place.
        A ``partition_axis`` of ``None`` is left untouched.
        """
        string_values = list(DTYPE_TO_STRING_MAP.values())
        for variables in (self.config_variables, self.config_scaling_variables):
            if not variables:
                continue
            for key, value in list(variables.items()):
                if value in string_values:
                    variables[key] = STRING_TO_DTYPE_MAP[value]

        if self.config_variables:
            pa = self.config_variables.get("partition_axis")
            # ``None`` survives the JSON round trip as ``null``; ``PartitionAxis(**None)``
            # would raise TypeError, so only rebuild from an actual mapping.
            if pa is not None and not isinstance(pa, PartitionAxis):
                self.config_variables["partition_axis"] = PartitionAxis(**pa)
class RayDistributedTrainer:
    """
    Distributed trainer for Ray-based training with EasyDeL models.

    This class provides a lightweight wrapper for distributed training that:
    - Manages model configuration and scaling for different nodes
    - Handles model/state initialization and checkpoint loading
    - Delegates actual training to the underlying Trainer implementation

    The trainer supports:
    - Dynamic model scaling based on scaling_index
    - Automatic tokenizer/processor setup with padding configuration
    - Flexible checkpoint loading from various sources
    - Integration with Ray for distributed training orchestration

    Key Design Principles:
    - Resume logic is handled by BaseTrainer (set arguments.resume_if_possible=True)
    - State sharding is deferred to the main Trainer according to partition rules
    - Explicit checkpoint paths are used without automatic run-* resolution

    Attributes:
        model_task: The task type for the model (e.g., CAUSAL_LM)
        model_type: The model architecture type (e.g., 'llama')
        model_class: The EasyDeL model class to instantiate
        state_class: The state class for model checkpointing
        offload_backend: Backend for memory offloading
        trainer_module: The trainer class to use for actual training
        CONFIG_SCALING_VARIABLES: Variables that scale with scaling_index
        CONFIG_VARIABLES: Fixed configuration variables
    """

    # Model identity
    model_task: TaskType
    model_type: str
    model_class: type[EasyDeLBaseModule]
    state_class: type[EasyDeLState]
    offload_backend: str
    trainer_module: type[BaseTrainer | Trainer]

    # Defaults multiplied by ``scaling_index`` in ``create_config``.
    CONFIG_SCALING_VARIABLES: tp.ClassVar[dict[str, int]] = {
        "hidden_size": 256,
        "intermediate_size": 256 * 4,
        "moe_intermediate_size": 256 * 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 1,
    }
    # Defaults used as-is (not scaled).
    CONFIG_VARIABLES: tp.ClassVar[dict[str, tp.Any]] = {
        "dtype": jnp.bfloat16,
        "param_dtype": jnp.bfloat16,
        "precision": lax.Precision.DEFAULT,
        "seed": 654,
        "max_position_embeddings": 2**13,
        "gradient_checkpointing": EasyDeLGradientCheckPointers.NONE,
        "initializer_range": 0.02,
        "partition_axis": PartitionAxis(),
        "attn_mechanism": "auto",
        "attn_dtype": jnp.bfloat16,
        "attn_softmax_dtype": jnp.bfloat16,
        "sharding_axis_names": ("dp", "fsdp", "ep", "tp", "sp"),
        "sharding_axis_dims": (1, -1, 1, 1, 1),
        "sharding_dcn_axis_dims": (1, -1, 1, 1, 1),
    }

    _processor_loader_class: type[PreTrainedTokenizer] = AutoTokenizer

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        bucket_path: str | None = None,
        model_task: TaskType | None = None,
        model_type: str | None = None,
        model_class: type[EasyDeLBaseModule] | None = None,
        state_class: type[EasyDeLState] | None = None,
        offload_backend: str | None = None,
        trainer_module: type[BaseTrainer | Trainer] | None = None,
        config_scaling_variables: dict[str, int] | None = None,
        config_variables: dict[str, tp.Any] | None = None,
    ):
        """
        Initialize the RayDistributedTrainer.

        Args:
            pretrained_model_name_or_path: Path or identifier for the pretrained model
            bucket_path: Optional path to load checkpoints from cloud storage
            model_task: Task type (inferred from model_class if not provided)
            model_type: Model architecture type (inferred from model_class if not provided)
            model_class: EasyDeL model class to use (auto-resolved if not provided)
            state_class: State class for checkpointing (defaults to EasyDeLState)
            offload_backend: Backend for memory offloading (defaults to 'cpu')
            trainer_module: Trainer class to use (defaults to Trainer)
            config_scaling_variables: Overrides merged into CONFIG_SCALING_VARIABLES
            config_variables: Overrides merged into CONFIG_VARIABLES

        Raises:
            AssertionError: If model class cannot be resolved or parameters are inconsistent
        """
        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        if model_task is None or model_type is None:
            assert model_task is None and model_type is None, (
                "If one of model_task or model_type is None, both must be None."
            )
            assert model_class is not None, "model_class must be provided when model_task/model_type are omitted."
            model_type = model_class._model_type
            model_task = model_class._model_task
        elif model_class is not None:
            logger.warning(
                "Both model_class and model_type/model_task provided. Using model_class and inferring type/task from it."
            )
            model_type = model_class._model_type
            model_task = model_class._model_task

        if model_class is None:
            assert model_type is not None and model_task is not None, (
                "model_type and model_task must be provided if model_class is not specified."
            )
            _, resolved_class = get_modules_by_type(model_type=model_type, task_type=model_task)
            assert resolved_class is not None, f"Could not resolve model class for {model_type}/{model_task}"
            self.model_class = resolved_class
        else:
            self.model_class = model_class

        # Deep-copy the class-level defaults so per-instance overrides never leak
        # into CONFIG_SCALING_VARIABLES / CONFIG_VARIABLES shared by all instances.
        self.config_scaling_variables = copy.deepcopy(self.CONFIG_SCALING_VARIABLES)
        self.config_variables = copy.deepcopy(self.CONFIG_VARIABLES)
        if config_scaling_variables is not None:
            self.config_scaling_variables.update(config_scaling_variables)
        if config_variables is not None:
            self.config_variables.update(config_variables)

        self.bucket_path = bucket_path
        self.model_task = model_task
        self.model_type = model_type
        self.offload_backend = offload_backend if offload_backend is not None else "cpu"
        self.state_class = state_class if state_class is not None else EasyDeLState
        self.trainer_module = trainer_module if trainer_module is not None else Trainer

    @classmethod
    def from_config(
        cls,
        path: str | os.PathLike,
        model_class: type[EasyDeLBaseModule] | None = None,
        state_class: type[EasyDeLState] | None = None,
        trainer_module: type[BaseTrainer | Trainer] | None = None,
    ):
        """
        Create a RayDistributedTrainer from a saved configuration file.

        Args:
            path: Path to the JSON configuration file
            model_class: Optional model class override
            state_class: Optional state class override
            trainer_module: Optional trainer module override

        Returns:
            RayDistributedTrainer: Initialized trainer instance
        """
        cfg = RayDistributedConfig(**json.loads(ePath(path).read_text()))
        cfg._loading_postprocess()
        return cls(
            pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
            model_task=cfg.model_task,
            model_type=cfg.model_type,
            config_scaling_variables=cfg.config_scaling_variables,
            config_variables=cfg.config_variables,
            offload_backend=cfg.offload_backend,
            trainer_module=trainer_module,
            state_class=state_class,
            model_class=model_class,
        )

    def save_config(self, path: str | os.PathLike):
        """
        Save the current configuration to a JSON file.

        The trainer's own variable dicts are left untouched: the config receives
        shallow copies because ``_saving_preprocess`` rewrites values in place
        (e.g. dtypes -> strings).

        Args:
            path: Path where the configuration will be saved
        """
        cfg = RayDistributedConfig(
            pretrained_model_name_or_path=self.pretrained_model_name_or_path,
            model_task=self.model_task,
            model_type=self.model_type,
            offload_backend=self.offload_backend,
            config_scaling_variables=dict(self.config_scaling_variables),
            config_variables=dict(self.config_variables),
        )
        cfg._saving_preprocess()
        ePath(path).write_text(cfg.model_dump_json(indent=2))

    def load_processor(self) -> PreTrainedTokenizer:
        """
        Load the tokenizer/processor for the model.

        Returns:
            PreTrainedTokenizer: Loaded tokenizer with padding configuration

        Notes:
            - If the tokenizer has no pad token but does have an EOS token id,
              pad_token_id is set to eos_token_id and a warning is logged.
            - If neither is defined, the tokenizer is returned unchanged.
        """
        tokenizer = self._processor_loader_class.from_pretrained(self.pretrained_model_name_or_path)
        # HF tokenizers expose ``eos_token_id`` even when it is None, so check the value.
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
        if getattr(tokenizer, "pad_token_id", None) is None and eos_token_id is not None:
            logger.warning("Tokenizer has no pad_token. Falling back to eos_token for padding.")
            tokenizer.pad_token_id = eos_token_id
        return tokenizer

    @cached_property
    def processor(self) -> PreTrainedTokenizer:
        """Cached property for the tokenizer/processor (loaded on first access)."""
        return self.load_processor()

    @staticmethod
    def extract_column_names(dataset: Dataset) -> list[str] | None:
        """
        Extract column names from a dataset.

        Args:
            dataset: The dataset to extract column names from

        Returns:
            list[str] | None: ``dataset.column_names`` if non-empty, otherwise the
            keys of the first sample, or None for an empty dataset.
        """
        if hasattr(dataset, "column_names") and dataset.column_names:
            return list(dataset.column_names)
        # Fall back to inspecting the first sample (e.g. iterable datasets).
        for sample in dataset:
            return list(sample.keys())
        return None

    def process_sample_data(
        self,
        sample: tp.Any,
        max_length: int,
        padding_side: str = "left",
    ) -> dict[str, jax.Array]:
        """
        Process a text sample into model inputs.

        Args:
            sample: Raw text sample to process
            max_length: Maximum sequence length
            padding_side: Side to pad sequences ('left' or 'right')

        Returns:
            dict[str, jax.Array]: Tokenized and padded inputs; array values are flattened to 1-D
        """
        out = self.processor(
            sample,
            padding="max_length",
            max_length=max_length,
            return_tensors="jax",
            padding_side=padding_side,
            return_attention_mask=True,
            truncation=True,
        )
        return {k: (v.reshape(-1) if hasattr(v, "shape") else v) for k, v in out.items()}

    def process_messages_data(
        self,
        messages: tp.Any,
        max_length: int,
        padding_side: str = "left",
    ) -> dict[str, jax.Array]:
        """
        Process chat messages using the tokenizer's chat template.

        Args:
            messages: Chat messages to process
            max_length: Maximum sequence length
            padding_side: Side to pad sequences ('left' or 'right')

        Returns:
            dict[str, jax.Array]: Tokenized and padded inputs; array values are flattened to 1-D
        """
        out = self.processor.apply_chat_template(
            messages,
            padding="max_length",
            max_length=max_length,
            return_tensors="jax",
            return_dict=True,
            truncation=True,
            # Extra kwargs to apply_chat_template go to the template renderer, not the
            # tokenizer; padding_side must be forwarded via tokenizer_kwargs to take effect.
            tokenizer_kwargs={"padding_side": padding_side},
        )
        return {k: (v.reshape(-1) if hasattr(v, "shape") else v) for k, v in out.items()}

    def create_config(self, scaling_index: int) -> EasyDeLBaseConfig:
        """
        Create a model configuration with scaled dimensions.

        Args:
            scaling_index: Multiplier for scaling variables (e.g., hidden_size)

        Returns:
            EasyDeLBaseConfig: Configuration with scaled and fixed variables

        Notes:
            - Scaling variables are multiplied by scaling_index
            - Fixed variables are passed unchanged, except precision/dtype/param_dtype,
              which are given to the model constructor instead of the config
        """
        not_allowed = frozenset(("precision", "dtype", "param_dtype"))
        scaled = {k: v * scaling_index for k, v in self.config_scaling_variables.items()}
        config_kwargs = {**{k: v for k, v in self.config_variables.items() if k not in not_allowed}, **scaled}
        config_class = getattr(self.model_class, "config_class", None)
        if config_class is None:
            config_class, _ = get_modules_by_type(model_type=self.model_type, task_type=self.model_task)
        return config_class(**config_kwargs)

    def _get_offload_device(self):
        """
        Get the device for memory offloading.

        Returns:
            Device: First local device of ``offload_backend`` if any, otherwise the
            first global device of that backend.
        """
        try:
            devs = jax.local_devices(backend=self.offload_backend)
            if len(devs) > 0:
                return devs[0]
        except Exception:
            # Deliberate best-effort: fall through to the global device list.
            pass
        return jax.devices(self.offload_backend)[0]

    def create_model(
        self,
        config: EasyDeLBaseConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: lax.PrecisionLike | None = None,
        seed: int = 684,
        lazy: bool = False,
    ) -> EasyDeLBaseModule:
        """
        Create a model instance from configuration.

        Args:
            config: Model configuration
            dtype: Computation dtype
            param_dtype: Parameter dtype
            precision: JAX precision setting (defaults to lax.Precision.DEFAULT)
            seed: Random seed for initialization
            lazy: Use ``lazy_init`` instead of ``sequential_init``

        Returns:
            EasyDeLBaseModule: Initialized model instance
        """
        if precision is None:
            precision = lax.Precision.DEFAULT
        init_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=nn.Rngs(seed),
        )
        if lazy:
            return self.model_class.lazy_init(**init_kwargs)
        return self.model_class.sequential_init(**init_kwargs)

    def convert_model_to_state(self, model: EasyDeLBaseModule) -> EasyDeLState:
        """
        Convert a model module to a state object.

        Args:
            model: The model to convert

        Returns:
            EasyDeLState: State object for checkpointing

        Notes:
            - Does NOT perform sharding (handled by trainer)
            - Uses the configured state_class for conversion
        """
        return model.to_state(self.state_class)

    def create_model_from_config(self, scaling_index: int) -> EasyDeLBaseModule:
        """
        Create a model with configuration scaled by the given index.

        Args:
            scaling_index: Multiplier for scaling variables

        Returns:
            EasyDeLBaseModule: Initialized model with scaled configuration
        """
        return self.create_model(
            config=self.create_config(scaling_index=scaling_index),
            dtype=self.config_variables["dtype"],
            param_dtype=self.config_variables["param_dtype"],
            precision=self.config_variables["precision"],
            seed=self.config_variables["seed"],
        )

    def create_trainer(
        self,
        arguments: TrainingArguments,
        dataset_train: Dataset,
        dataset_eval: Dataset | None = None,
        data_collator: tp.Callable | None = None,
        state: EasyDeLState | None = None,
    ) -> BaseTrainer | Trainer:
        """
        Create a trainer instance for model training.

        Args:
            arguments: Training configuration and hyperparameters
            dataset_train: Training dataset
            dataset_eval: Optional evaluation dataset
            data_collator: Optional data collator for batching
            state: Model state to train

        Returns:
            BaseTrainer | Trainer: Configured trainer instance
        """
        return self.trainer_module(
            arguments=arguments,
            dataset_train=dataset_train,
            dataset_eval=dataset_eval,
            data_collator=data_collator,
            model_state=state,
        )

    def train(
        self,
        scaling_index: int,
        arguments: TrainingArguments,
        dataset_train: Dataset,
        dataset_eval: Dataset | None = None,
        data_collator: tp.Callable | None = None,
        model: EasyDeLBaseModule | None = None,
        state: EasyDeLState | None = None,
    ):
        """
        Execute distributed training with the configured model.

        This method handles model/state initialization from various sources:
        1. Provided state (highest priority; a provided model is then ignored)
        2. Provided model (converted to state)
        3. Checkpoint from bucket_path
        4. New model creation with scaling_index

        Args:
            scaling_index: Multiplier for model scaling (used if creating new model)
            arguments: Training configuration
            dataset_train: Training dataset
            dataset_eval: Optional evaluation dataset
            data_collator: Optional data collator
            model: Optional pre-initialized model
            state: Optional pre-initialized state

        Returns:
            Training results from the underlying trainer

        Notes:
            - For automatic resume from interruptions, set:
              - arguments.resume_if_possible = True
              - arguments.save_directory = "path/to/checkpoints"
            - State sharding is handled by the trainer based on partition rules

        Raises:
            AssertionError: If no valid model state can be obtained
        """
        if state is not None and model is not None:
            logger.warning("Both model and state were provided. Using state and ignoring model.")

        if state is None and model is None:
            if self.bucket_path is not None:
                import easydel as ed

                state = self.state_class.load_state(
                    load_directory=self.bucket_path,
                    dtype=self.config_variables["dtype"],
                    param_dtype=self.config_variables["param_dtype"],
                    precision=self.config_variables["precision"],
                    auto_shard_model=True,
                    sharding_axis_names=self.config_variables["sharding_axis_names"],
                    sharding_axis_dims=self.config_variables["sharding_axis_dims"],
                    sharding_dcn_axis_dims=self.config_variables["sharding_dcn_axis_dims"],
                    config_kwargs=ed.EasyDeLBaseConfigDict(
                        freq_max_position_embeddings=self.config_variables["max_position_embeddings"],
                        mask_max_position_embeddings=self.config_variables["max_position_embeddings"],
                        attn_mechanism=self.config_variables["attn_mechanism"],
                        attn_dtype=self.config_variables["attn_dtype"],
                        attn_softmax_dtype=self.config_variables["attn_softmax_dtype"],
                        gradient_checkpointing=self.config_variables["gradient_checkpointing"],
                    ),
                    partition_axis=self.config_variables["partition_axis"],
                )
            else:
                logger.info(f"No model/state/checkpoint. Creating a new model (scaling_index={scaling_index}).")
                model = self.create_model_from_config(scaling_index=scaling_index)
                state = self.convert_model_to_state(model)
        elif model is not None and state is None:
            state = self.convert_model_to_state(model)

        assert state is not None, "Unable to obtain a valid model state."

        return self.create_trainer(
            arguments=arguments,
            dataset_train=dataset_train,
            dataset_eval=dataset_eval,
            data_collator=data_collator,
            state=state,
        ).train()

    def __repr__(self):
        """Render public instance attributes, truncating long values."""
        cls_name = self.__class__.__name__
        items = []
        for k, v in self.__dict__.items():
            if k.startswith("_"):
                continue
            try:
                s = str(v).replace("\n", "\n ")
                if len(s) > 200:
                    s = f"{v.__class__.__name__}(...)"
                items.append(f" {k} : {s}")
            except Exception:
                # A repr must never raise, whatever an attribute's __str__ does.
                items.append(f" {k} : <unrepresentable>")
        return f"{cls_name}(\n" + "\n".join(items) + "\n)"

    __str__ = __repr__