easydel.trainers.ray_scaler.distributed_trainer#
Ray-based distributed trainer implementation for EasyDeL.
This module provides a distributed training implementation using Ray for scaling language model training across multiple GPUs and nodes. It integrates Ray’s distributed computing capabilities with EasyDeL’s training infrastructure to enable efficient large-scale model training.
The module includes:
- RayDistributedTrainer: Main class for distributed training with Ray
- Integration with Ray Train for distributed data loading and gradient synchronization
- Support for both data and model parallelism strategies
- Automatic resource management and fault tolerance
- Checkpointing and recovery mechanisms for long-running training jobs
Key Components:
- Automatic distribution of training data across workers
- Gradient synchronization using Ray’s collective communication
- Dynamic resource allocation and load balancing
- Integration with Ray Tune for hyperparameter optimization
- Support for heterogeneous hardware configurations
The trainer abstracts away the complexity of distributed training, allowing users to scale from single GPU to multi-node clusters with minimal code changes.
- class easydel.trainers.ray_scaler.distributed_trainer.RayDistributedConfig(*, pretrained_model_name_or_path: str, model_task: easydel.infra.factory.TaskType | None = None, model_type: str | None = None, offload_backend: str | None = None, config_scaling_variables: dict[str, int] | None = None, config_variables: dict[str, Any] | None = None)[source]#
Bases: BaseModel
Configuration for RayDistributedTrainer that can be persisted to JSON.
This class handles serialization and deserialization of distributed training configurations, with special handling for JAX dtypes and PartitionAxis objects.
- pretrained_model_name_or_path#
Path or identifier for the pretrained model
- Type
str
- model_task#
The task type for the model (e.g., CAUSAL_LM, SEQ2SEQ)
- Type
easydel.infra.factory.TaskType | None
- model_type#
The model architecture type (e.g., ‘llama’, ‘gpt2’)
- Type
str | None
- offload_backend#
Backend device for offloading (e.g., ‘cpu’, ‘gpu’)
- Type
str | None
- config_scaling_variables#
Variables to scale by scaling_index (e.g., hidden_size)
- Type
dict[str, int] | None
- config_variables#
Fixed configuration variables (e.g., dtype, precision)
- Type
dict[str, Any] | None
Notes
JAX dtype fields are converted to/from strings for JSON serialization
PartitionAxis objects are converted to/from dictionary representation
Use _saving_preprocess() before saving and _loading_postprocess() after loading
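The conversion described in the notes above can be sketched as a JSON round trip. This is a minimal illustration rather than the EasyDeL implementation: DTypeStandIn and AxisStandIn are hypothetical stand-ins for a JAX dtype and for PartitionAxis, and the two helper functions only mirror what _saving_preprocess() and _loading_postprocess() are documented to do. Matching on key names during loading is an assumption made for this sketch.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class DTypeStandIn:
    """Hypothetical stand-in for a JAX dtype such as jnp.bfloat16."""
    name: str


@dataclass(frozen=True)
class AxisStandIn:
    """Hypothetical stand-in for PartitionAxis, with only two of its fields."""
    batch_axis: tuple = ("fsdp", "dp")
    head_axis: str = "tp"


def saving_preprocess(variables: dict) -> dict:
    """Replace values JSON cannot encode with strings and dictionaries."""
    out = {}
    for key, value in variables.items():
        if isinstance(value, DTypeStandIn):
            out[key] = value.name          # dtype -> string
        elif isinstance(value, AxisStandIn):
            out[key] = asdict(value)       # axis object -> dictionary
        else:
            out[key] = value
    return out


def loading_postprocess(variables: dict) -> dict:
    """Invert saving_preprocess (key-name matching is an assumption)."""
    out = {}
    for key, value in variables.items():
        if key.endswith("dtype") and isinstance(value, str):
            out[key] = DTypeStandIn(value)
        elif key == "partition_axis" and isinstance(value, dict):
            # JSON has no tuple type, so lists are turned back into tuples.
            fields = {k: tuple(v) if isinstance(v, list) else v for k, v in value.items()}
            out[key] = AxisStandIn(**fields)
        else:
            out[key] = value
    return out


config_variables = {
    "dtype": DTypeStandIn("bfloat16"),
    "partition_axis": AxisStandIn(),
    "seed": 654,
}
payload = json.dumps(saving_preprocess(config_variables))
restored = loading_postprocess(json.loads(payload))
```

The point of the two-step design is that the on-disk JSON stays plain text while the in-memory configuration keeps its rich types.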
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- pretrained_model_name_or_path: str#
- class easydel.trainers.ray_scaler.distributed_trainer.RayDistributedTrainer(pretrained_model_name_or_path: str, bucket_path: str | None = None, model_task: easydel.infra.factory.TaskType | None = None, model_type: str | None = None, model_class: type[easydel.infra.base_module.EasyDeLBaseModule] | None = None, state_class: type[easydel.infra.base_state.EasyDeLState] | None = None, offload_backend: str | None = None, trainer_module: type[easydel.trainers.base_trainer.BaseTrainer | easydel.trainers.trainer.trainer.Trainer] | None = None, config_scaling_variables: dict[str, int] | None = None, config_variables: dict[str, Any] | None = None)[source]#
Bases: object
Distributed trainer for Ray-based training with EasyDeL models.
This class provides a lightweight wrapper for distributed training that:
- Manages model configuration and scaling for different nodes
- Handles model/state initialization and checkpoint loading
- Delegates actual training to the underlying Trainer implementation
The trainer supports:
- Dynamic model scaling based on scaling_index
- Automatic tokenizer/processor setup with padding configuration
- Flexible checkpoint loading from various sources
- Integration with Ray for distributed training orchestration
Key Design Principles:
- Resume logic is handled by BaseTrainer (set arguments.resume_if_possible=True)
- State sharding is deferred to the main Trainer according to partition rules
- Explicit checkpoint paths are used without automatic run-* resolution
- model_task#
The task type for the model (e.g., CAUSAL_LM)
- model_type#
The model architecture type (e.g., ‘llama’)
- Type
str
- model_class#
The EasyDeL model class to instantiate
- Type
type[easydel.infra.base_module.EasyDeLBaseModule]
- state_class#
The state class for model checkpointing
- Type
type[easydel.infra.base_state.EasyDeLState]
- offload_backend#
Backend for memory offloading
- Type
str
- trainer_module#
The trainer class to use for actual training
- CONFIG_SCALING_VARIABLES#
Variables that scale with scaling_index
- Type
ClassVar[dict[str, int]]
- CONFIG_VARIABLES#
Fixed configuration variables
- Type
ClassVar[dict[str, Any]]
- CONFIG_SCALING_VARIABLES: ClassVar[dict[str, int]] = {'hidden_size': 256, 'intermediate_size': 1024, 'moe_intermediate_size': 512, 'num_attention_heads': 2, 'num_key_value_heads': 1}#
- CONFIG_VARIABLES: ClassVar[dict[str, Any]] = {'attn_dtype': <class 'jax.numpy.bfloat16'>, 'attn_mechanism': 'auto', 'attn_softmax_dtype': <class 'jax.numpy.bfloat16'>, 'dtype': <class 'jax.numpy.bfloat16'>, 'gradient_checkpointing': EasyDeLGradientCheckPointers.NONE, 'initializer_range': 0.02, 'max_position_embeddings': 8192, 'param_dtype': <class 'jax.numpy.bfloat16'>, 'partition_axis': PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis='tp', decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), 'precision': Precision.DEFAULT, 'seed': 654, 'sharding_axis_dims': (1, -1, 1, 1, 1), 'sharding_axis_names': ('dp', 'fsdp', 'ep', 'tp', 'sp'), 'sharding_dcn_axis_dims': (1, -1, 1, 1, 1)}#
- convert_model_to_state(model: EasyDeLBaseModule) EasyDeLState[source]#
Convert a model module to a state object.
- Parameters
model – The model to convert
- Returns
State object for checkpointing
- Return type
EasyDeLState
Notes
Does NOT perform sharding (handled by trainer)
Uses the configured state_class for conversion
- create_config(scaling_index: int) EasyDeLBaseConfig[source]#
Create a model configuration with scaled dimensions.
- Parameters
scaling_index – Multiplier for scaling variables (e.g., hidden_size)
- Returns
Configuration with scaled and fixed variables
- Return type
EasyDeLBaseConfig
Notes
Scaling variables are multiplied by scaling_index
Fixed variables remain unchanged
Useful for creating different model sizes in distributed training
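The scaling rule in the notes above can be sketched in a few lines. This is an illustration under stated assumptions, not the library code: scaled_config_kwargs is a hypothetical helper, the base values are copied from the CONFIG_SCALING_VARIABLES default documented on this page, only two fixed variables are shown, and letting fixed variables take precedence on a name clash is an assumption.

```python
# Base values copied from the documented CONFIG_SCALING_VARIABLES default.
CONFIG_SCALING_VARIABLES = {
    "hidden_size": 256,
    "intermediate_size": 1024,
    "moe_intermediate_size": 512,
    "num_attention_heads": 2,
    "num_key_value_heads": 1,
}

# A small subset of the documented CONFIG_VARIABLES, kept short for the sketch.
FIXED_VARIABLES = {"max_position_embeddings": 8192, "initializer_range": 0.02}


def scaled_config_kwargs(scaling_index: int, scaling_variables: dict, fixed_variables: dict) -> dict:
    """Hypothetical helper: multiply scaling variables, pass fixed ones through."""
    scaled = {name: base * scaling_index for name, base in scaling_variables.items()}
    # Assumption: fixed variables win if a name appears in both dictionaries.
    return {**scaled, **fixed_variables}


kwargs = scaled_config_kwargs(4, CONFIG_SCALING_VARIABLES, FIXED_VARIABLES)
print(kwargs["hidden_size"], kwargs["num_attention_heads"], kwargs["max_position_embeddings"])
```

With scaling_index=4 the sketch yields a hidden size of 1024 and 8 attention heads, while max_position_embeddings stays at 8192.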
- create_model(config: easydel.infra.base_config.EasyDeLBaseConfig, dtype: numpy.dtype = <class 'jax.numpy.bfloat16'>, param_dtype: numpy.dtype = <class 'jax.numpy.bfloat16'>, precision: typing.Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], jax._src.lax.lax.DotAlgorithm, jax._src.lax.lax.DotAlgorithmPreset] = None, seed: int = 684, lazy: bool = False) EasyDeLBaseModule[source]#
Create a model instance from configuration.
- Parameters
config – Model configuration
dtype – Computation dtype
param_dtype – Parameter dtype
precision – JAX precision setting
seed – Random seed for initialization
lazy – Whether to use lazy initialization (memory efficient)
- Returns
Initialized model instance
- Return type
EasyDeLBaseModule
- create_model_from_config(scaling_index: int) EasyDeLBaseModule[source]#
Create a model with configuration scaled by the given index.
- Parameters
scaling_index – Multiplier for scaling variables
- Returns
Initialized model with scaled configuration
- Return type
EasyDeLBaseModule
- create_trainer(arguments: TrainingArguments, dataset_train: Dataset, dataset_eval: Dataset | None = None, data_collator: tp.Callable | None = None, state: EasyDeLState | None = None) BaseTrainer | Trainer[source]#
Create a trainer instance for model training.
- Parameters
arguments – Training configuration and hyperparameters
dataset_train – Training dataset
dataset_eval – Optional evaluation dataset
data_collator – Optional data collator for batching
state – Model state to train
- Returns
Configured trainer instance
- Return type
BaseTrainer | Trainer
- static extract_column_names(dataset: Dataset) list[str] | None[source]#
Extract column names from a dataset.
- Parameters
dataset – The dataset to extract column names from
- Returns
Column names if available, None otherwise
- Return type
list[str] | None
- classmethod from_config(path: str | os.PathLike, model_class: type[easydel.infra.base_module.EasyDeLBaseModule] | None = None, state_class: type[easydel.infra.base_state.EasyDeLState] | None = None, trainer_module: type[easydel.trainers.base_trainer.BaseTrainer | easydel.trainers.trainer.trainer.Trainer] | None = None)[source]#
Create a RayDistributedTrainer from a saved configuration file.
- Parameters
path – Path to the JSON configuration file
model_class – Optional model class override
state_class – Optional state class override
trainer_module – Optional trainer module override
- Returns
Initialized trainer instance
- Return type
RayDistributedTrainer
- load_processor() PreTrainedTokenizer[source]#
Load the tokenizer/processor for the model.
- Returns
Loaded tokenizer with padding configuration
- Return type
PreTrainedTokenizer
Notes
Automatically sets pad_token to eos_token if not defined
Logs a warning when falling back to eos_token for padding
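The padding fallback described in the notes above can be sketched as follows. The tokenizer here is a plain namespace object standing in for a PreTrainedTokenizer, and ensure_pad_token is a hypothetical helper name; the actual warning text and logger used by load_processor() are not documented here and may differ.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("pad_fallback_sketch")


def ensure_pad_token(tokenizer):
    """Hypothetical helper: reuse eos_token for padding when pad_token is unset."""
    if getattr(tokenizer, "pad_token", None) is None:
        logger.warning("pad_token is not set; falling back to eos_token %r", tokenizer.eos_token)
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# Stand-in tokenizers: one without a pad token, one that already defines it.
without_pad = ensure_pad_token(SimpleNamespace(pad_token=None, eos_token="</s>"))
with_pad = ensure_pad_token(SimpleNamespace(pad_token="<pad>", eos_token="</s>"))
```

A tokenizer that already defines pad_token is left untouched, so the fallback only affects models whose vocabulary ships without a dedicated padding token.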
- model_class: type[easydel.infra.base_module.EasyDeLBaseModule]#
- model_type: str#
- offload_backend: str#
- process_messages_data(messages: Any, max_length: int, padding_side: str = 'left') dict[str, jax.Array][source]#
Process chat messages using the tokenizer’s chat template.
- Parameters
messages – Chat messages to process
max_length – Maximum sequence length
padding_side – Side to pad sequences (‘left’ or ‘right’)
- Returns
Tokenized and padded inputs with flattened shapes
- Return type
dict[str, jax.Array]
- process_sample_data(sample: Any, max_length: int, padding_side: str = 'left') dict[str, jax.Array][source]#
Process a text sample into model inputs.
- Parameters
sample – Raw text sample to process
max_length – Maximum sequence length
padding_side – Side to pad sequences (‘left’ or ‘right’)
- Returns
Tokenized and padded inputs with flattened shapes
- Return type
dict[str, jax.Array]
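The effect of max_length and padding_side in process_messages_data and process_sample_data can be illustrated on plain token-id lists. This sketch makes several assumptions: pad_or_truncate is a hypothetical helper, truncation keeps the leading tokens, and the output uses Python lists with input_ids and attention_mask keys instead of the jax.Array values the real methods return.

```python
def pad_or_truncate(token_ids, max_length, pad_id, padding_side="left"):
    """Hypothetical helper: fit a token sequence to exactly max_length entries."""
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    ids = list(token_ids)[:max_length]  # assumption: keep the leading tokens
    mask = [1] * len(ids)               # 1 marks real tokens, 0 marks padding
    n_pad = max_length - len(ids)
    if padding_side == "left":
        return {"input_ids": [pad_id] * n_pad + ids, "attention_mask": [0] * n_pad + mask}
    return {"input_ids": ids + [pad_id] * n_pad, "attention_mask": mask + [0] * n_pad}


left = pad_or_truncate([5, 6, 7], max_length=5, pad_id=0, padding_side="left")
right = pad_or_truncate([5, 6, 7], max_length=5, pad_id=0, padding_side="right")
clipped = pad_or_truncate([1, 2, 3, 4, 5, 6], max_length=4, pad_id=0)
```

Left padding, the documented default, keeps the real tokens at the end of the sequence, which is the usual layout for decoder-only generation.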
- property processor: PreTrainedTokenizer#
Cached property for the tokenizer/processor.
- save_config(path: str | os.PathLike)[source]#
Save the current configuration to a JSON file.
- Parameters
path – Path where the configuration will be saved
- state_class: type[easydel.infra.base_state.EasyDeLState]#
- train(scaling_index: int, arguments: TrainingArguments, dataset_train: Dataset, dataset_eval: Dataset | None = None, data_collator: tp.Callable | None = None, model: EasyDeLBaseModule | None = None, state: EasyDeLState | None = None)[source]#
Execute distributed training with the configured model.
This method handles model/state initialization from various sources, in priority order:
1. Provided state (highest priority)
2. Provided model (converted to state)
3. Checkpoint from bucket_path
4. New model creation with scaling_index
- Parameters
scaling_index – Multiplier for model scaling (used if creating new model)
arguments – Training configuration
dataset_train – Training dataset
dataset_eval – Optional evaluation dataset
data_collator – Optional data collator
model – Optional pre-initialized model
state – Optional pre-initialized state
- Returns
Training results from the underlying trainer
Notes
- For automatic resume from interruptions, set:
    arguments.resume_if_possible = True
    arguments.save_directory = "path/to/checkpoints"
- State sharding is handled by the trainer based on partition rules
- Checkpoint loading respects the priority order above
- Raises
AssertionError – If no valid model state can be obtained
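The priority order that train() documents for obtaining a state can be sketched as a small resolver. This is a hedged illustration, not the method body: resolve_initial_state is a hypothetical function, the three callables are injected stand-ins named after the documented methods, and falling through to a new model when the checkpoint loader returns None is an assumption.

```python
from typing import Any, Callable, Optional


def resolve_initial_state(
    scaling_index: int,
    state: Optional[Any] = None,
    model: Optional[Any] = None,
    bucket_path: Optional[str] = None,
    *,
    convert_model_to_state: Callable[[Any], Any],
    load_checkpoint: Callable[[str], Optional[Any]],
    create_model_from_config: Callable[[int], Any],
) -> Any:
    """Hypothetical resolver following the documented priority order."""
    if state is not None:                      # 1. provided state
        return state
    if model is not None:                      # 2. provided model, converted
        return convert_model_to_state(model)
    if bucket_path is not None:                # 3. checkpoint from bucket_path
        loaded = load_checkpoint(bucket_path)
        if loaded is not None:
            return loaded
    # 4. new model built from the scaled configuration
    new_state = convert_model_to_state(create_model_from_config(scaling_index))
    assert new_state is not None, "no valid model state could be obtained"
    return new_state


# String-based stand-ins make the chosen branch visible in the result.
helpers = dict(
    convert_model_to_state=lambda m: f"state({m})",
    load_checkpoint=lambda path: f"checkpoint({path})",
    create_model_from_config=lambda idx: f"model(x{idx})",
)
from_state = resolve_initial_state(2, state="given-state", model="m", **helpers)
from_model = resolve_initial_state(2, model="m", bucket_path="bucket", **helpers)
from_bucket = resolve_initial_state(2, bucket_path="bucket", **helpers)
from_scratch = resolve_initial_state(2, **helpers)
```

Passing the helpers as keyword arguments keeps the sketch free of any EasyDeL import while still showing which source wins when several are supplied.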
- trainer_module: type[easydel.trainers.base_trainer.BaseTrainer | easydel.trainers.trainer.trainer.Trainer]#