easydel.trainers.ray_scaler.distributed_trainer


Ray-based distributed trainer implementation for EasyDeL.

This module provides a distributed training implementation using Ray for scaling language model training across multiple GPUs and nodes. It integrates Ray’s distributed computing capabilities with EasyDeL’s training infrastructure to enable efficient large-scale model training.

The module includes:

  • RayDistributedTrainer: Main class for distributed training with Ray

  • Integration with Ray Train for distributed data loading and gradient synchronization

  • Support for both data and model parallelism strategies

  • Automatic resource management and fault tolerance

  • Checkpointing and recovery mechanisms for long-running training jobs

Key Components:

  • Automatic distribution of training data across workers

  • Gradient synchronization using Ray’s collective communication

  • Dynamic resource allocation and load balancing

  • Integration with Ray Tune for hyperparameter optimization

  • Support for heterogeneous hardware configurations

The trainer abstracts away the complexity of distributed training, allowing users to scale from single GPU to multi-node clusters with minimal code changes.

class easydel.trainers.ray_scaler.distributed_trainer.RayDistributedConfig(*, pretrained_model_name_or_path: str, model_task: easydel.infra.factory.TaskType | None = None, model_type: str | None = None, offload_backend: str | None = None, config_scaling_variables: dict[str, int] | None = None, config_variables: dict[str, Any] | None = None)[source]#

Bases: BaseModel

Configuration for RayDistributedTrainer that can be persisted to JSON.

This class handles serialization and deserialization of distributed training configurations, with special handling for JAX dtypes and PartitionAxis objects.

pretrained_model_name_or_path#

Path or identifier for the pretrained model

Type

str

model_task#

The task type for the model (e.g., CAUSAL_LM, SEQ2SEQ)

Type

easydel.infra.factory.TaskType | None

model_type#

The model architecture type (e.g., ‘llama’, ‘gpt2’)

Type

str | None

offload_backend#

Backend device for offloading (e.g., ‘cpu’, ‘gpu’)

Type

str | None

config_scaling_variables#

Variables to scale by scaling_index (e.g., hidden_size)

Type

dict[str, int] | None

config_variables#

Fixed configuration variables (e.g., dtype, precision)

Type

dict[str, Any] | None

Notes

  • JAX dtype fields are converted to/from strings for JSON serialization

  • PartitionAxis objects are converted to/from dictionary representation

  • Use _saving_preprocess() before saving and _loading_postprocess() after loading
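
The round trip described in these notes can be sketched with the standard library alone. Everything here is an assumption made for illustration: AxisSpec is a hypothetical stand-in for PartitionAxis with only two fields, and saving_preprocess / loading_postprocess are illustrative helpers, not the library's _saving_preprocess() and _loading_postprocess(). The sketch highlights one detail that any such round trip has to handle: JSON has no tuple type, so tuple-valued fields such as batch_axis come back as lists and must be converted back. A dtype could be handled in the same spirit by storing its name as a string.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class AxisSpec:
    """Hypothetical stand-in for PartitionAxis (two representative fields only)."""

    batch_axis: tuple = ("fsdp", "dp")
    head_axis: str = "tp"


def saving_preprocess(variables):
    """Replace AxisSpec values with a tagged, JSON-friendly dictionary."""
    out = {}
    for key, value in variables.items():
        if isinstance(value, AxisSpec):
            out[key] = {"__axis_spec__": asdict(value)}
        else:
            out[key] = value
    return out


def loading_postprocess(variables):
    """Rebuild AxisSpec values from their tagged dictionary form."""
    out = {}
    for key, value in variables.items():
        if isinstance(value, dict) and "__axis_spec__" in value:
            fields = {
                # JSON turns tuples into lists, so convert them back.
                name: tuple(item) if isinstance(item, list) else item
                for name, item in value["__axis_spec__"].items()
            }
            out[key] = AxisSpec(**fields)
        else:
            out[key] = value
    return out


original = {"partition_axis": AxisSpec(), "seed": 654}
as_json = json.dumps(saving_preprocess(original))
restored = loading_postprocess(json.loads(as_json))
```

After the round trip, restored compares equal to original, including the tuple type of batch_axis.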

config_scaling_variables: dict[str, int] | None#
config_variables: dict[str, tp.Any] | None#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

model_task: TaskType | None#
model_type: str | None#
offload_backend: str | None#
pretrained_model_name_or_path: str#
class easydel.trainers.ray_scaler.distributed_trainer.RayDistributedTrainer(pretrained_model_name_or_path: str, bucket_path: str | None = None, model_task: easydel.infra.factory.TaskType | None = None, model_type: str | None = None, model_class: type[easydel.infra.base_module.EasyDeLBaseModule] | None = None, state_class: type[easydel.infra.base_state.EasyDeLState] | None = None, offload_backend: str | None = None, trainer_module: type[easydel.trainers.base_trainer.BaseTrainer | easydel.trainers.trainer.trainer.Trainer] | None = None, config_scaling_variables: dict[str, int] | None = None, config_variables: dict[str, Any] | None = None)[source]#

Bases: object

Distributed trainer for Ray-based training with EasyDeL models.

This class provides a lightweight wrapper for distributed training that:

  • Manages model configuration and scaling for different nodes

  • Handles model/state initialization and checkpoint loading

  • Delegates actual training to the underlying Trainer implementation

The trainer supports:

  • Dynamic model scaling based on scaling_index

  • Automatic tokenizer/processor setup with padding configuration

  • Flexible checkpoint loading from various sources

  • Integration with Ray for distributed training orchestration

Key Design Principles:

  • Resume logic is handled by BaseTrainer (set arguments.resume_if_possible=True)

  • State sharding is deferred to the main Trainer according to partition rules

  • Explicit checkpoint paths are used without automatic run-* resolution

model_task#

The task type for the model (e.g., CAUSAL_LM)

Type

easydel.infra.factory.TaskType

model_type#

The model architecture type (e.g., ‘llama’)

Type

str

model_class#

The EasyDeL model class to instantiate

Type

type[easydel.infra.base_module.EasyDeLBaseModule]

state_class#

The state class for model checkpointing

Type

type[easydel.infra.base_state.EasyDeLState]

offload_backend#

Backend for memory offloading

Type

str

trainer_module#

The trainer class to use for actual training

Type

type[easydel.trainers.base_trainer.BaseTrainer | easydel.trainers.trainer.trainer.Trainer]

CONFIG_SCALING_VARIABLES#

Variables that scale with scaling_index

Type

ClassVar[dict[str, int]]

CONFIG_VARIABLES#

Fixed configuration variables

Type

ClassVar[dict[str, Any]]

CONFIG_SCALING_VARIABLES: ClassVar[dict[str, int]] = {'hidden_size': 256, 'intermediate_size': 1024, 'moe_intermediate_size': 512, 'num_attention_heads': 2, 'num_key_value_heads': 1}#
CONFIG_VARIABLES: ClassVar[dict[str, Any]] = {'attn_dtype': <class 'jax.numpy.bfloat16'>, 'attn_mechanism': 'auto', 'attn_softmax_dtype': <class 'jax.numpy.bfloat16'>, 'dtype': <class 'jax.numpy.bfloat16'>, 'gradient_checkpointing': EasyDeLGradientCheckPointers.NONE, 'initializer_range': 0.02, 'max_position_embeddings': 8192, 'param_dtype': <class 'jax.numpy.bfloat16'>, 'partition_axis': PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis='tp', decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), 'precision': Precision.DEFAULT, 'seed': 654, 'sharding_axis_dims': (1, -1, 1, 1, 1), 'sharding_axis_names': ('dp', 'fsdp', 'ep', 'tp', 'sp'), 'sharding_dcn_axis_dims': (1, -1, 1, 1, 1)}#
convert_model_to_state(model: EasyDeLBaseModule) EasyDeLState[source]#

Convert a model module to a state object.

Parameters

model – The model to convert

Returns

State object for checkpointing

Return type

EasyDeLState

Notes

  • Does NOT perform sharding (handled by trainer)

  • Uses the configured state_class for conversion

create_config(scaling_index: int) EasyDeLBaseConfig[source]#

Create a model configuration with scaled dimensions.

Parameters

scaling_index – Multiplier for scaling variables (e.g., hidden_size)

Returns

Configuration with scaled and fixed variables

Return type

EasyDeLBaseConfig

Notes

  • Scaling variables are multiplied by scaling_index

  • Fixed variables remain unchanged

  • Useful for creating different model sizes in distributed training
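
As a rough illustration of the scaling rule in these notes, the sketch below multiplies each scaling variable by scaling_index and passes fixed variables through unchanged. The base values are copied from CONFIG_SCALING_VARIABLES as listed in this reference; scaled_config_kwargs and the two-entry FIXED_VARIABLES subset are hypothetical and only show the arithmetic, not how create_config builds an EasyDeLBaseConfig.

```python
# Base values copied from CONFIG_SCALING_VARIABLES in this reference.
CONFIG_SCALING_VARIABLES = {
    "hidden_size": 256,
    "intermediate_size": 1024,
    "moe_intermediate_size": 512,
    "num_attention_heads": 2,
    "num_key_value_heads": 1,
}

# Illustrative subset of fixed variables (not the full CONFIG_VARIABLES).
FIXED_VARIABLES = {"max_position_embeddings": 8192, "initializer_range": 0.02}


def scaled_config_kwargs(scaling_index, scaling_variables, fixed_variables):
    """Hypothetical helper: scale the scaling variables, keep the fixed ones."""
    kwargs = {name: base * scaling_index for name, base in scaling_variables.items()}
    kwargs.update(fixed_variables)  # fixed variables are not multiplied
    return kwargs


kwargs = scaled_config_kwargs(4, CONFIG_SCALING_VARIABLES, FIXED_VARIABLES)
for name in CONFIG_SCALING_VARIABLES:
    print(name, kwargs[name])
```

With scaling_index=4 this yields hidden_size=1024, intermediate_size=4096, moe_intermediate_size=2048, num_attention_heads=8 and num_key_value_heads=4, while max_position_embeddings stays at 8192.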

create_model(config: ~easydel.infra.base_config.EasyDeLBaseConfig, dtype: ~numpy.dtype = <class 'jax.numpy.bfloat16'>, param_dtype: ~numpy.dtype = <class 'jax.numpy.bfloat16'>, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = None, seed: int = 684, lazy: bool = False) EasyDeLBaseModule[source]#

Create a model instance from configuration.

Parameters
  • config – Model configuration

  • dtype – Computation dtype

  • param_dtype – Parameter dtype

  • precision – JAX precision setting

  • seed – Random seed for initialization

  • lazy – Whether to use lazy initialization (memory efficient)

Returns

Initialized model instance

Return type

EasyDeLBaseModule

create_model_from_config(scaling_index: int) EasyDeLBaseModule[source]#

Create a model with configuration scaled by the given index.

Parameters

scaling_index – Multiplier for scaling variables

Returns

Initialized model with scaled configuration

Return type

EasyDeLBaseModule

create_trainer(arguments: TrainingArguments, dataset_train: Dataset, dataset_eval: Dataset | None = None, data_collator: tp.Callable | None = None, state: EasyDeLState | None = None) BaseTrainer | Trainer[source]#

Create a trainer instance for model training.

Parameters
  • arguments – Training configuration and hyperparameters

  • dataset_train – Training dataset

  • dataset_eval – Optional evaluation dataset

  • data_collator – Optional data collator for batching

  • state – Model state to train

Returns

Configured trainer instance

Return type

BaseTrainer | Trainer

static extract_column_names(dataset: Dataset) list[str] | None[source]#

Extract column names from a dataset.

Parameters

dataset – The dataset to extract column names from

Returns

Column names if available, None otherwise

Return type

list[str] | None
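
One plausible way to get the "names if available, None otherwise" behavior is a guarded attribute lookup. This is a guess at the idea rather than the library's implementation: the column_names attribute name and the stand-in objects below are assumptions.

```python
from types import SimpleNamespace


def column_names_or_none(dataset):
    """Hypothetical helper: return column names as a list, or None if absent."""
    names = getattr(dataset, "column_names", None)
    return list(names) if names is not None else None


# Stand-in objects; a real dataset would come from a dataset library.
with_columns = SimpleNamespace(column_names=["input_ids", "attention_mask"])
without_columns = object()

found = column_names_or_none(with_columns)
missing = column_names_or_none(without_columns)
```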

classmethod from_config(path: str | os.PathLike, model_class: type[easydel.infra.base_module.EasyDeLBaseModule] | None = None, state_class: type[easydel.infra.base_state.EasyDeLState] | None = None, trainer_module: type[easydel.trainers.base_trainer.BaseTrainer | easydel.trainers.trainer.trainer.Trainer] | None = None)[source]#

Create a RayDistributedTrainer from a saved configuration file.

Parameters
  • path – Path to the JSON configuration file

  • model_class – Optional model class override

  • state_class – Optional state class override

  • trainer_module – Optional trainer module override

Returns

Initialized trainer instance

Return type

RayDistributedTrainer

load_processor() PreTrainedTokenizer[source]#

Load the tokenizer/processor for the model.

Returns

Loaded tokenizer with padding configuration

Return type

PreTrainedTokenizer

Notes

  • Automatically sets pad_token to eos_token if not defined

  • Logs a warning when falling back to eos_token for padding
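
The padding fallback in these notes amounts to a small guard. The sketch below uses a SimpleNamespace as a stand-in tokenizer so that it runs without downloading anything; ensure_pad_token and the logger name are hypothetical, and the real method may differ in how it logs or detects a missing pad token.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("pad_fallback_sketch")  # illustrative logger name


def ensure_pad_token(tokenizer):
    """Set pad_token to eos_token when no pad token is defined."""
    if getattr(tokenizer, "pad_token", None) is None:
        logger.warning(
            "pad_token is not set; falling back to eos_token (%r)",
            tokenizer.eos_token,
        )
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# SimpleNamespace stands in for a real tokenizer object.
needs_fallback = ensure_pad_token(SimpleNamespace(pad_token=None, eos_token="</s>"))
already_set = ensure_pad_token(SimpleNamespace(pad_token="<pad>", eos_token="</s>"))
print(needs_fallback.pad_token, already_set.pad_token)
```

A tokenizer that already defines a pad token is returned unchanged, so the guard is safe to apply unconditionally.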

model_class: type[easydel.infra.base_module.EasyDeLBaseModule]#
model_task: TaskType#
model_type: str#
offload_backend: str#
process_messages_data(messages: Any, max_length: int, padding_side: str = 'left') dict[str, jax.Array][source]#

Process chat messages using the tokenizer’s chat template.

Parameters
  • messages – Chat messages to process

  • max_length – Maximum sequence length

  • padding_side – Side to pad sequences (‘left’ or ‘right’)

Returns

Tokenized and padded inputs with flattened shapes

Return type

dict[str, jax.Array]

process_sample_data(sample: Any, max_length: int, padding_side: str = 'left') dict[str, jax.Array][source]#

Process a text sample into model inputs.

Parameters
  • sample – Raw text sample to process

  • max_length – Maximum sequence length

  • padding_side – Side to pad sequences (‘left’ or ‘right’)

Returns

Tokenized and padded inputs with flattened shapes

Return type

dict[str, jax.Array]
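
To make the effect of padding_side concrete for both process_messages_data and process_sample_data, here is a plain-Python sketch of padding a token sequence to max_length. It is not the library's code: the real methods go through the tokenizer and return jax.Array values, whereas pad_or_truncate is a hypothetical helper that returns lists, and both the pad id of 0 and truncation from the right are assumptions.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0, padding_side="left"):
    """Hypothetical helper: fit token_ids to exactly max_length entries."""
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")
    ids = list(token_ids)[:max_length]  # assumption: overflow is cut on the right
    n_pad = max_length - len(ids)
    mask = [1] * len(ids)  # 1 marks real tokens, 0 marks padding
    if padding_side == "left":
        return {
            "input_ids": [pad_id] * n_pad + ids,
            "attention_mask": [0] * n_pad + mask,
        }
    return {
        "input_ids": ids + [pad_id] * n_pad,
        "attention_mask": mask + [0] * n_pad,
    }


left = pad_or_truncate([5, 6, 7], max_length=5)
right = pad_or_truncate([5, 6, 7], max_length=5, padding_side="right")
```

Here left["input_ids"] is [0, 0, 5, 6, 7] and right["input_ids"] is [5, 6, 7, 0, 0]. Left padding keeps the last real token at the end of the sequence, which is the usual reason it is preferred for decoder-only generation.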

property processor: PreTrainedTokenizer#

Cached property for the tokenizer/processor.

save_config(path: str | os.PathLike)[source]#

Save the current configuration to a JSON file.

Parameters

path – Path where the configuration will be saved

state_class: type[easydel.infra.base_state.EasyDeLState]#
train(scaling_index: int, arguments: TrainingArguments, dataset_train: Dataset, dataset_eval: Dataset | None = None, data_collator: tp.Callable | None = None, model: EasyDeLBaseModule | None = None, state: EasyDeLState | None = None)[source]#

Execute distributed training with the configured model.

This method handles model/state initialization from various sources, in priority order:

  1. Provided state (highest priority)

  2. Provided model (converted to state)

  3. Checkpoint from bucket_path

  4. New model creation with scaling_index

Parameters
  • scaling_index – Multiplier for model scaling (used if creating new model)

  • arguments – Training configuration

  • dataset_train – Training dataset

  • dataset_eval – Optional evaluation dataset

  • data_collator – Optional data collator

  • model – Optional pre-initialized model

  • state – Optional pre-initialized state

Returns

Training results from the underlying trainer

Notes

  • For automatic resume from interruptions, set:
    • arguments.resume_if_possible = True

    • arguments.save_directory = "path/to/checkpoints"

  • State sharding is handled by the trainer based on partition rules

  • Checkpoint loading respects the priority order above

Raises

AssertionError – If no valid model state can be obtained
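
The priority order that train() documents for obtaining a state can be sketched as a simple fall-through. All names below are hypothetical: the three callables stand in for checkpoint loading from bucket_path, model creation from scaling_index, and model-to-state conversion, and whether a loaded checkpoint is already a state is an assumption of this sketch.

```python
def resolve_initial_state(state, model, load_checkpoint, create_model, to_state):
    """Pick the training state following the documented priority order.

    Hypothetical stand-ins:
      load_checkpoint() -> state or None (nothing usable under bucket_path)
      create_model()    -> fresh model built from scaling_index, or None
      to_state(model)   -> state wrapping the given model
    Returns the state and a label naming which source was used.
    """
    if state is not None:  # 1. provided state
        return state, "state"
    if model is not None:  # 2. provided model, converted to state
        return to_state(model), "model"
    loaded = load_checkpoint()  # 3. checkpoint from bucket_path
    if loaded is not None:
        return loaded, "checkpoint"
    fresh = create_model()  # 4. new model from scaling_index
    result = to_state(fresh) if fresh is not None else None
    assert result is not None, "no valid model state could be obtained"
    return result, "new"


def wrap(model):
    """Toy state: a dictionary holding the model."""
    return {"model": model}


_, from_model = resolve_initial_state(None, "m", lambda: "ckpt", lambda: "fresh", wrap)
_, from_ckpt = resolve_initial_state(None, None, lambda: "ckpt", lambda: "fresh", wrap)
new_state, from_new = resolve_initial_state(None, None, lambda: None, lambda: "fresh", wrap)
```

Because the checks run top to bottom, passing both a model and a state means the model is ignored, and a checkpoint is only consulted when neither is given.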

trainer_module: type[easydel.trainers.base_trainer.BaseTrainer | easydel.trainers.trainer.trainer.Trainer]#