easydel.trainers.trainer_protocol
- class easydel.trainers.trainer_protocol.BaseTrainerProtocol(arguments: TrainingArguments | None = None, model_state: EasyDeLState | None = None, model: tp.type[EasyDeLBaseModule] | None = None, dataset_train: Dataset | None = None, dataset_eval: Dataset | None = None, data_collator: tp.Callable | None = None, finetune: bool = True, **deprecated_kwargs)
Bases: object
Abstract base protocol defining the interface for all trainer implementations.
This protocol ensures that all trainer implementations provide the necessary methods and attributes for training and evaluation workflows. It defines the contract that concrete trainer classes must fulfill.
The protocol covers:
  - Initialization and configuration methods
  - Training and evaluation loops
  - Checkpoint management
  - Metrics logging and monitoring
  - Hook methods for customization
All methods marked with @abstractmethod must be implemented by subclasses.
- abstract apply_training_hooks(metrics: LossMetrics) → LossMetrics
Apply training hooks to check for issues and enforce limits.
- Parameters
metrics – Current training metrics including loss.
- Returns
Potentially modified metrics.
- Return type
LossMetrics
- Raises
EasyDeLBreakRequest – If NaN loss detected and break_on_nan is True.
EasyDeLTimerError – If training time limit exceeded.
Note
Checks for NaN losses and training time limits based on configuration in training arguments.
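A minimal sketch of the two checks described above, operating on a plain float loss instead of a LossMetrics object. The `check_training_hooks` helper, its `time_limit_seconds` and `now` parameters, and the two exception classes are hypothetical stand-ins for EasyDeLBreakRequest and EasyDeLTimerError; an actual implementation may read these settings from its training arguments differently.

```python
import math
import time
from typing import Optional


class BreakRequest(Exception):
    """Hypothetical stand-in for EasyDeLBreakRequest."""


class TimerError(Exception):
    """Hypothetical stand-in for EasyDeLTimerError."""


def check_training_hooks(
    loss: float,
    start_time: float,
    break_on_nan: bool,
    time_limit_seconds: Optional[float],
    now: Optional[float] = None,
) -> float:
    """Apply the NaN and time-limit checks; return the loss unchanged if both pass."""
    # Stop on a NaN loss only when the caller asked for it.
    if break_on_nan and math.isnan(loss):
        raise BreakRequest("NaN loss detected")
    # `now` can be injected for testing; otherwise use the wall clock.
    current = time.time() if now is None else now
    if time_limit_seconds is not None and current - start_time > time_limit_seconds:
        raise TimerError(f"training time limit of {time_limit_seconds}s exceeded")
    return loss
```

Returning the input unchanged mirrors the "potentially modified metrics" contract: a hook that passes its checks simply hands the value back to the training loop.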
- arguments: TrainingArguments
- abstract calculate_number_total_flops(params, is_training=True)
Calculate total FLOPs for the model.
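The protocol does not specify how FLOPs are counted. A widely used back-of-the-envelope estimate for dense transformers is about 2 FLOPs per parameter per token for a forward pass and about 6 for a training step (forward plus backward). The sketch below applies that rule of thumb; the `estimate_total_flops` name, the `tokens` argument, and the 2x/6x factors are assumptions and not necessarily what EasyDeL's implementations compute.

```python
def estimate_total_flops(num_params: int, tokens: int, is_training: bool = True) -> int:
    """Approximate FLOPs for processing `tokens` tokens with a dense model.

    Rule of thumb: ~2 * N FLOPs per token for a forward pass and
    ~6 * N per token for training (forward + backward). Attention-specific
    terms and other overheads are ignored.
    """
    flops_per_token = (6 if is_training else 2) * num_params
    return flops_per_token * tokens
```

The `is_training` flag plays the same role as in calculate_number_total_flops: training roughly triples the cost because of the backward pass.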
- checkpoint_manager: Any
- config: EasyDeLBaseConfig
- abstract configure_dataloaders() → TrainerConfigureDataloaderOutput
Configure dataloaders for training and evaluation.
Creates training and evaluation dataloaders using the provided datasets and data collator. Determines the maximum number of training and evaluation steps based on dataset sizes and arguments.
- Returns
TrainerConfigureDataloaderOutput containing:
dataloader_train: Training data iterator
max_training_steps: Total training steps
dataloader_eval: Optional evaluation data iterator
max_evaluation_steps: Optional total evaluation steps
- Return type
TrainerConfigureDataloaderOutput
Note
Automatically selects between Grain and TensorFlow dataloaders based on arguments.use_grain setting.
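How the maximum step counts are derived is implementation specific. A minimal sketch, assuming the count is the number of full batches per epoch times the number of epochs, optionally capped by a user limit. The helper and its parameter names (`num_examples`, `batch_size`, `num_epochs`, `max_steps_limit`) are hypothetical.

```python
from typing import Optional


def compute_max_steps(
    num_examples: int,
    batch_size: int,
    num_epochs: int = 1,
    max_steps_limit: Optional[int] = None,
) -> int:
    """Number of optimizer steps, dropping the final incomplete batch.

    Dropping the remainder keeps every batch the same shape, which avoids
    recompiling a jitted step function for a smaller last batch.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    steps = (num_examples // batch_size) * num_epochs
    # An explicit limit, if given, takes precedence when it is smaller.
    if max_steps_limit is not None:
        steps = min(steps, max_steps_limit)
    return steps
```

For example, 1000 examples with a batch size of 32 give 31 full batches per epoch, so 3 epochs yield 93 steps.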
- abstract configure_functions() → TrainerConfigureFunctionOutput
Configure and JIT-compile training and evaluation step functions.
This method prepares the computational graph for training and evaluation, including setting up sharding specifications, compiling functions with appropriate static arguments, and initializing the checkpoint manager.
- Returns
Compiled functions and infrastructure
- Return type
TrainerConfigureFunctionOutput
- abstract configure_model() → TrainerConfigureModelOutput
Configure model, optimizer, scheduler, and configuration.
Retrieves model configuration from the model state and creates the optimizer and scheduler using training arguments.
- Returns
TrainerConfigureModelOutput containing:
model: The EasyDeL model instance
tx: Gradient transformation (optimizer)
scheduler: Learning rate schedule
config: Optional model configuration
- Return type
TrainerConfigureModelOutput
Note
If pruning_module is set, it wraps the optimizer for structured pruning support.
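The scheduler returned here is typed as a callable from a step count to a learning rate. To illustrate that contract only, here is a hypothetical linear warmup followed by linear decay in plain Python. In practice such schedules are usually built with Optax (for example `optax.linear_schedule` pieces combined with `optax.join_schedules`), and the schedule a trainer actually uses depends on its training arguments.

```python
def warmup_linear_decay(peak_lr: float, warmup_steps: int, total_steps: int):
    """Return a schedule mapping a step count to a learning rate.

    The rate rises linearly from 0 to peak_lr over warmup_steps, then
    decays linearly back to 0 at total_steps and stays at 0 afterwards.
    """
    if not 0 < warmup_steps < total_steps:
        raise ValueError("need 0 < warmup_steps < total_steps")

    def schedule(step: int) -> float:
        if step < warmup_steps:
            return peak_lr * step / warmup_steps
        remaining = max(total_steps - step, 0)
        return peak_lr * remaining / (total_steps - warmup_steps)

    return schedule
```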
- abstract static count_model_parameters(prm)
Count total number of model parameters.
- Parameters
prm – Model parameters (can be frozen or unfrozen PyTree).
- Returns
Total number of parameters in the model.
- Return type
int
Note
Handles both frozen and unfrozen Flax parameter dictionaries.
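A minimal sketch of parameter counting over a nested PyTree, assuming leaves are array-like objects that NumPy can size (such as NumPy or JAX arrays). Flax's FrozenDict is a `Mapping`, so walking mappings generically covers both frozen and unfrozen dictionaries. The `count_parameters` name is hypothetical; with JAX installed, `jax.tree_util.tree_leaves` could replace the hand-written recursion.

```python
from collections.abc import Mapping

import numpy as np


def count_parameters(tree) -> int:
    """Sum element counts over every array leaf of a nested parameter tree."""
    if tree is None:
        return 0  # JAX treats None as an empty subtree
    if isinstance(tree, Mapping):
        # Covers plain dicts and Flax's FrozenDict, which is a Mapping.
        return sum(count_parameters(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_parameters(v) for v in tree)
    return int(np.size(tree))  # array leaf: its number of elements
```

For example, a dense layer with a (4, 3) kernel and a (3,) bias contributes 12 + 3 = 15 parameters.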
- abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable
Create a generic data collection function for batch processing.
- Parameters
max_sequence_length – Maximum allowed sequence length for padding/truncation.
truncation_mode – How to truncate sequences that exceed the maximum length:
  - "keep_end": keep the end of the sequence
  - "keep_start": keep the beginning of the sequence
- Returns
Callable that processes batches of data.
Note
This is a generic version that can be used with any dataloader type. Implementations should handle padding, truncation, and format conversion.
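A minimal sketch of the padding and truncation logic such a collate function might apply to lists of token IDs, assuming right padding with a hypothetical `pad_token_id` and an `attention_mask` that marks real tokens. The `collate_batch` name and the output keys are illustrative, not necessarily the keys EasyDeL emits.

```python
from typing import Literal

import numpy as np


def collate_batch(
    sequences: list[list[int]],
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    pad_token_id: int = 0,
) -> dict[str, np.ndarray]:
    """Truncate each token-ID list to max_sequence_length, then right-pad it."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
    batch = len(sequences)
    input_ids = np.full((batch, max_sequence_length), pad_token_id, dtype=np.int32)
    attention_mask = np.zeros((batch, max_sequence_length), dtype=np.int32)
    for row, seq in enumerate(sequences):
        if len(seq) > max_sequence_length:
            # keep_end drops tokens from the front; keep_start drops them from the back.
            if truncation_mode == "keep_end":
                seq = seq[-max_sequence_length:]
            else:
                seq = seq[:max_sequence_length]
        input_ids[row, : len(seq)] = seq
        attention_mask[row, : len(seq)] = 1  # 1 marks real tokens, 0 marks padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```

Fixed-length outputs keep every batch the same shape for jitted step functions. Right padding is one common convention; left padding is often preferred when batches are used for generation.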
- abstract create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable
Create a Grain data collection function for batch processing.
- Parameters
max_sequence_length – Maximum allowed sequence length for padding/truncation.
truncation_mode – How to truncate sequences that exceed the maximum length:
  - "keep_end": keep the end of the sequence
  - "keep_start": keep the beginning of the sequence
- Returns
Callable that processes batches for Grain dataloader.
Note
Function should handle padding, truncation, and data format conversion compatible with Grain’s data pipeline.
- abstract create_progress_bar(total: int, desc: str = '', disabled: bool = False) → BaseProgressBar
Create a progress bar of the specified type.
- Parameters
total – Total number of steps for the progress bar.
desc – Description text to display.
disabled – Whether to disable the progress bar.
- Returns
Progress bar instance of the configured type.
- Return type
BaseProgressBar
Note
The type is determined by arguments.progress_bar_type:
  - "tqdm": standard tqdm progress bar
  - "rich": Rich library progress bar with metrics
  - "json": JSON-formatted progress output
Passing disabled=True returns a NullProgressBar.
- abstract create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable
Create a TensorFlow Dataset collection function for batch processing.
- Parameters
max_sequence_length – Maximum allowed sequence length for padding/truncation.
truncation_mode – How to truncate sequences that exceed the maximum length:
  - "keep_end": keep the end of the sequence
  - "keep_start": keep the beginning of the sequence
- Returns
Callable that processes batches for TensorFlow datasets.
Note
Function should handle padding, truncation, and data format conversion compatible with tf.data API.
- data_collator: Optional[Callable]
- dtype: Any
- abstract eval(model_state: EasyDeLState) → Iterator[dict]
Evaluate the model on the evaluation dataset.
This method runs the model in evaluation mode, computing metrics without updating parameters. It yields metrics for each evaluation step, allowing for streaming evaluation and progress monitoring.
- Parameters
model_state – The model state to evaluate
- Yields
dict – Evaluation metrics for each step
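Because eval yields one metrics dict per step, callers can consume results as they stream in and aggregate them incrementally. A self-contained sketch of that pattern, where `fake_eval` is a hypothetical stand-in for a trainer's eval method and the `loss` key is an assumption about the metric names:

```python
from typing import Iterator


def fake_eval(losses: list[float]) -> Iterator[dict]:
    """Hypothetical stand-in for BaseTrainerProtocol.eval: one dict per step."""
    for step, loss in enumerate(losses, start=1):
        yield {"step": step, "loss": loss}


def mean_eval_loss(metrics_stream: Iterator[dict]) -> float:
    """Consume a stream of per-step metrics and return the average loss."""
    total, count = 0.0, 0
    for metrics in metrics_stream:
        total += metrics["loss"]
        count += 1
    if count == 0:
        raise ValueError("evaluation produced no steps")
    return total / count
```

With a concrete trainer, the stream would come from a call such as `trainer.eval(trainer.model_state)`, assuming the yielded dicts contain a loss entry under that key.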
- evalu_tracker: CompilationTracker
- abstract property evaluation_batch_size
- finetune: bool
- abstract initialize_trainer_utils()
Initialize all trainer utilities in the correct order.
This orchestration method sets up all trainer components in the following order:
  1. Weights & Biases logging (if enabled)
  2. Training timer for performance monitoring
  3. Dataloaders for training and evaluation
  4. Model, optimizer, and learning rate scheduler
  5. Model state sharding across devices
  6. Compiled training and evaluation functions
Note
The initialization order is critical as later steps depend on earlier ones being completed.
- abstract property is_enable
- abstract property is_process_zero
- abstract log_metrics(metrics: dict[str, float], pbar: BaseProgressBar, step: int, mode: str = 'train') → None
Log metrics and update progress bar.
- Parameters
metrics – Dictionary of metric names and values.
pbar – Progress bar instance to update.
step – Current step number.
mode – “train” or “eval” to prefix metrics.
Note
  - Updates the progress bar every log_steps steps.
  - Logs to wandb/tensorboard every report_steps steps.
  - Filters out internal metrics (mlperf, grad_norm).
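A minimal sketch of the cadence and filtering rules in the note. It assumes metric names are prefixed as `"<mode>/<name>"`, that a metric counts as internal if its name contains `mlperf` or `grad_norm`, and that the filter applies only to the progress-bar view while the tracker receives everything. The `route_metrics` helper and these rules are assumptions for illustration, not EasyDeL's log_metrics.

```python
from typing import Optional

INTERNAL_MARKERS = ("mlperf", "grad_norm")  # assumed substrings of internal metric names


def route_metrics(
    metrics: dict[str, float],
    step: int,
    mode: str = "train",
    log_steps: int = 10,
    report_steps: int = 5,
) -> tuple[Optional[dict[str, float]], Optional[dict[str, float]]]:
    """Return (pbar_metrics, report_metrics); either is None on steps that skip it."""
    pbar_metrics = None
    if step % log_steps == 0:
        # Progress-bar view: prefixed names, internal metrics hidden.
        pbar_metrics = {
            f"{mode}/{name}": value
            for name, value in metrics.items()
            if not any(marker in name for marker in INTERNAL_MARKERS)
        }
    report_metrics = None
    if step % report_steps == 0:
        # Tracker view (e.g. wandb/tensorboard): every metric, prefixed by mode.
        report_metrics = {f"{mode}/{name}": value for name, value in metrics.items()}
    return pbar_metrics, report_metrics
```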
- abstract log_weight_distribution(state: EasyDeLState, step: int)
Log weight distribution statistics.
- Parameters
state – Model state containing parameters.
step – Current training step.
Note
Logs statistics like mean, std, min, max of weights for monitoring training stability.
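A minimal sketch of the per-parameter statistics mentioned in the note, computed with NumPy over a flat mapping of parameter names to arrays. The `weight_statistics` name, the flat-dict input, and the `"<name>/<stat>"` key format are assumptions; a real implementation would first flatten the model's parameter tree and forward the result to its logger.

```python
import numpy as np


def weight_statistics(params: dict[str, np.ndarray]) -> dict[str, float]:
    """Return mean, std, min, and max for every named parameter array."""
    stats: dict[str, float] = {}
    for name, value in params.items():
        # Upcast to float64 so low-precision weights do not distort the statistics.
        arr = np.asarray(value, dtype=np.float64)
        stats[f"{name}/mean"] = float(arr.mean())
        stats[f"{name}/std"] = float(arr.std())
        stats[f"{name}/min"] = float(arr.min())
        stats[f"{name}/max"] = float(arr.max())
    return stats
```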
- max_evaluation_steps: int
- max_training_steps: int
- memory_monitor: Any
- abstract property mesh
- abstract property model
- model_state: EasyDeLState
- abstract on_step_end(state: EasyDeLState, metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int) → tuple[easydel.infra.base_state.EasyDeLState, dict[str, float | list | tuple | numpy.ndarray | Any]]
Hook called at the end of each training step. Returns the (possibly modified) state and metrics.
- abstract on_step_start(state: EasyDeLState, step: int) → EasyDeLState
Hook called at the start of each training step. Returns the (possibly modified) state.
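The two hooks bracket every training step: on_step_start can adjust the state before the update, and on_step_end can adjust both the state and the metrics afterwards. A self-contained sketch of how a loop might thread values through them; `ToyTrainer`, its integer "state", and its `training_step` and `run` methods are hypothetical and do not reflect EasyDeL's actual training loop.

```python
from typing import Any


class ToyTrainer:
    """Hypothetical trainer showing where the step hooks are called."""

    def on_step_start(self, state: int, step: int) -> int:
        return state  # a subclass could adjust the state before the update

    def on_step_end(self, state: int, metrics: dict[str, Any], step: int) -> tuple[int, dict[str, Any]]:
        metrics = {**metrics, "step": step}  # e.g. attach extra metrics
        return state, metrics

    def training_step(self, state: int) -> tuple[int, dict[str, Any]]:
        new_state = state + 1  # stand-in for a real optimizer update
        return new_state, {"loss": 1.0 / new_state}

    def run(self, state: int, num_steps: int) -> tuple[int, list[dict[str, Any]]]:
        history = []
        for step in range(num_steps):
            state = self.on_step_start(state, step)
            state, metrics = self.training_step(state)
            state, metrics = self.on_step_end(state, metrics, step)
            history.append(metrics)
        return state, history
```

Because each hook returns the values the loop continues with, a subclass can override just one hook without touching the loop itself.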
- param_dtype: Any
- pruning_module: Any
- abstract save_information(output_path: str | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath) → None
Save the generated information to a markdown file.
- abstract save_pretrained(state: EasyDeLState, save_directory: str | None = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: dict | None = None, torch_save_pretrained_kwargs: dict | None = None)
Saves the model state as a checkpoint or, when to_torch is True, to a Torch-compatible directory.
- scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]
- sharded_evaluation_step_function: Any
- sharded_training_step_function: Any
- state: Any
- state_named_sharding: Any
- state_partition_spec: Any
- state_shape: Any
- state_shardings: Any
- abstract train() → Any
Execute the complete training process.
This is the main entry point for training. It orchestrates the entire training workflow including initialization, training loops, evaluation, checkpointing, and finalization.
- Returns
TrainerOutput or similar object containing final state and metrics
- train_tracker: CompilationTracker
- abstract property training_batch_size
- tx: GradientTransformation
- wandb_runtime: Any
- class easydel.trainers.trainer_protocol.TrainerConfigureDataloaderOutput(dataloader_train: Iterator[ndarray], max_training_steps: int, dataloader_eval: Optional[Iterator[ndarray]] = None, max_evaluation_steps: int | None = None)
Bases: object
Output configuration for dataloader setup.
Contains the configured dataloaders and computed maximum steps for training and evaluation phases.
- dataloader_train
Iterator over training batches
- Type
Iterator[numpy.ndarray]
- max_training_steps
Total number of training steps
- Type
int
- dataloader_eval
Optional iterator over evaluation batches
- Type
Optional[Iterator[numpy.ndarray]]
- max_evaluation_steps
Optional total number of evaluation steps
- Type
int | None
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- max_training_steps: int
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
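The replace(**kwargs) method creates a new instance with the given fields replaced. The standard-library `dataclasses.replace` gives analogous semantics, so the sketch below uses it on a hypothetical frozen stand-in with the same field names; it is not the EasyDeL class itself, and whether the real class is frozen is an assumption.

```python
import dataclasses
from typing import Any, Iterator, Optional


@dataclasses.dataclass(frozen=True)
class DataloaderOutputSketch:
    """Hypothetical stand-in mirroring TrainerConfigureDataloaderOutput's fields."""

    dataloader_train: Iterator[Any]
    max_training_steps: int
    dataloader_eval: Optional[Iterator[Any]] = None
    max_evaluation_steps: Optional[int] = None

    def replace(self, **kwargs) -> "DataloaderOutputSketch":
        # Build a new instance; the original object is left untouched.
        return dataclasses.replace(self, **kwargs)
```

For example, `out.replace(max_training_steps=50)` returns a capped copy while `out` keeps its original step count, and unchanged fields such as the dataloader are shared between the two.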
- class easydel.trainers.trainer_protocol.TrainerConfigureFunctionOutput(sharded_training_step_function: Any, mesh: Mesh, checkpoint_manager: AsyncCheckpointManager, sharded_evaluation_step_function: Any | None = None)
Bases: object
Output configuration for training and evaluation functions.
Contains the compiled step functions and supporting infrastructure.
- sharded_training_step_function
JIT-compiled training step function
- Type
Any
- mesh
Device mesh for distributed computation
- Type
jax.sharding.Mesh
- checkpoint_manager
Manager for saving/loading checkpoints
- Type
eformer.serialization.async_manager.AsyncCheckpointManager
- sharded_evaluation_step_function
Optional JIT-compiled evaluation function
- Type
Any | None
- checkpoint_manager: AsyncCheckpointManager
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- sharded_training_step_function: Any
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.trainers.trainer_protocol.TrainerConfigureModelOutput(model: EasyDeLBaseModule, tx: GradientTransformation, scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]], config: easydel.infra.base_config.EasyDeLBaseConfig | None = None)
Bases: object
Output configuration for model setup.
Contains the configured model, optimizer, scheduler, and model configuration.
- model
The initialized EasyDeL model
- Type
EasyDeLBaseModule
- tx
Gradient transformation (optimizer) for training
- Type
optax._src.base.GradientTransformation
- scheduler
Learning rate schedule function
- Type
collections.abc.Callable[[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]]
- config
Optional model configuration object
- Type
easydel.infra.base_config.EasyDeLBaseConfig | None
- config: easydel.infra.base_config.EasyDeLBaseConfig | None = None
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- model: EasyDeLBaseModule
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- tx: GradientTransformation
- class easydel.trainers.trainer_protocol.TrainerOutput(state: EasyDeLState, mesh: jax._src.mesh.Mesh | None, last_save_file_name: str | None = None, checkpoint_path: str | None = None)
Bases: object
Final output from the training process.
Contains the final model state and checkpoint information.
- state
Final model state after training
- Type
EasyDeLState
- mesh
Device mesh used during training
- Type
jax._src.mesh.Mesh | None
- last_save_file_name
Name of the last saved checkpoint file
- Type
str | None
- checkpoint_path
Full path to the last checkpoint
- Type
str | None
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- mesh: jax._src.mesh.Mesh | None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- state: EasyDeLState
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.