easydel.trainers.base_trainer#

class easydel.trainers.base_trainer.BaseTrainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#

Bases: BaseTrainerProtocol

apply_training_hooks(metrics: LossMetrics) LossMetrics[source]#

Apply training hooks to the model.

calculate_number_total_flops(params, is_training=True)[source]#
compile_aot() bool[source]#

Compile the state ahead of time for faster execution.

configure_dataloaders() TrainerConfigureDataloaderOutput[source]#

Configures the dataloaders for training and evaluation.

This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.

Returns

An object containing the configured dataloaders and the maximum number of training and evaluation steps.

Return type

TrainerConfigureDataloaderOutput

abstract configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

configure_model() TrainerConfigureModelOutput[source]#

Configures the model, optimizer, scheduler, and configuration.

This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.

Returns

An object containing the configured model, optimizer, scheduler, and configuration.

Return type

TrainerConfigureModelOutput

static count_model_parameters(prm)[source]#

Prints the number of model parameters in billions.

abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#

Creates a function to collect and process batches of data for training or evaluation.

This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable

create_progress_bar(total: int, desc: str = '', disabled: bool = False) BaseProgressBar[source]#

Create a progress bar of the specified type.

property evaluation_batch_size#
static finish()[source]#

Finalize the training process.

get_runstage_flops(is_training) Union[float, Tuple[float, bool]][source]#

Return the total number of FLOPs for the model.

initialize_trainer_utils()[source]#

Initializes various utilities used by the trainer.

This includes setting up Weights & Biases, initializing the training timer, configuring dataloaders, configuring the model and optimizer, sharding the model and reference model states, and configuring the training and evaluation functions.

property is_process_zero#
log_metrics(metrics: Any, pbar: BaseProgressBar, step: int, mode: str = 'train')[source]#

Log metrics and update progress bar.

log_weight_distribution(state: EasyDeLState, step: int)[source]#

Log distribution of weights.

property mesh#
property model#
on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

Hook process to call at the end of the step.

on_step_start(state: EasyDeLState, step: int) EasyDeLState[source]#

Hook process to call at the start of the step.

save_information(output_path: Union[str, Path]) None[source]#

Save the generated information to a markdown file.

Parameters

output_path – Path where the markdown file should be saved

save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)[source]#

Saves the model state as a checkpoint file or to a Torch compatible directory.

specs_to_name_sharding(tree, mesh=None)[source]#

Convert specs to named sharding.

start_evaluation_hook()[source]#

Hook to run before evaluation starts.

start_training_hook()[source]#

Hook to run before training starts.

property training_batch_size#