easydel.trainers.base_trainer
- class easydel.trainers.base_trainer.BaseTrainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)
Bases: BaseTrainerProtocol
- apply_training_hooks(metrics: LossMetrics) → LossMetrics
Apply the configured training hooks to the loss metrics of a step and return the (possibly updated) metrics.
- configure_dataloaders() → TrainerConfigureDataloaderOutput
Configures the dataloaders for training and evaluation.
This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.
- Returns
An object containing the configured dataloaders and the maximum number of training and evaluation steps.
- Return type
TrainerConfigureDataloaderOutput
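The step arithmetic described above can be illustrated with a small standalone sketch that derives a maximum step count from a dataset size, a batch size and an epoch count. `compute_max_steps`, its `drop_last` flag, and the rule that an explicit `max_steps` caps the epoch-based total are all hypothetical. They are assumptions made for illustration and are not EasyDeL's exact logic.

```python
import math
from typing import Optional


def compute_max_steps(
    num_examples: int,
    batch_size: int,
    num_epochs: int,
    max_steps: Optional[int] = None,
    drop_last: bool = True,
) -> int:
    """Hypothetical helper: total number of steps for a run.

    Assumes one step per batch (no gradient accumulation) and that an
    explicit ``max_steps``, when given, caps the epoch-based count.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Batches per epoch: either drop the final partial batch or keep it.
    if drop_last:
        steps_per_epoch = num_examples // batch_size
    else:
        steps_per_epoch = math.ceil(num_examples / batch_size)
    total = steps_per_epoch * num_epochs
    # The explicit cap wins only when it is smaller than the epoch-based total.
    return min(total, max_steps) if max_steps is not None else total
```

For example, 1000 examples at batch size 32 for 3 epochs give 31 full batches per epoch, so 93 steps, or 96 if the partial batch is kept.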
- abstract configure_functions() → TrainerConfigureFunctionOutput
Configures and JIT-compiles the training and evaluation step functions.
This method sets up the necessary functions for training and evaluation, including:
  - initialization of the model state;
  - sharding of the model parameters and optimizer state;
  - JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- configure_model() → TrainerConfigureModelOutput
Configures the model, optimizer, scheduler, and configuration.
This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.
- Returns
An object containing the configured model, optimizer, scheduler, and configuration.
- Return type
TrainerConfigureModelOutput
- abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable
Creates a function to collect and process batches of data for training or evaluation.
The returned function pads or truncates each sequence to max_sequence_length, choosing which tokens to keep according to truncation_mode.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode: "keep_end" keeps the last max_sequence_length tokens and "keep_start" keeps the first. Defaults to "keep_end".
- Returns
A function that takes a batch of data and returns a processed batch.
- Return type
tp.Callable
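The padding and truncation contract described above can be sketched as a plain-Python collate-function factory. Several details are assumptions made for illustration and do not come from EasyDeL: the batch format (a list of dicts with an `input_ids` list), right padding, `pad_token_id=0`, the returned `attention_mask`, and the name `make_collate_fn`. A real subclass would return whatever structure its step functions expect.

```python
from typing import Callable, Dict, List, Literal


def make_collate_fn(
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    pad_token_id: int = 0,  # assumed padding id
) -> Callable[[List[Dict[str, List[int]]]], Dict[str, List[List[int]]]]:
    """Hypothetical factory mirroring the create_collect_function contract."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def collate(batch: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
        input_ids, attention_mask = [], []
        for example in batch:
            ids = list(example["input_ids"])
            if len(ids) > max_sequence_length:
                # keep_end drops tokens from the front; keep_start drops from the back.
                if truncation_mode == "keep_end":
                    ids = ids[-max_sequence_length:]
                else:
                    ids = ids[:max_sequence_length]
            pad = max_sequence_length - len(ids)
            # Right-pad so every row has the same fixed length.
            input_ids.append(ids + [pad_token_id] * pad)
            attention_mask.append([1] * len(ids) + [0] * pad)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    return collate
```

With `max_sequence_length=4`, the sequence `[1, 2, 3, 4, 5, 6]` becomes `[3, 4, 5, 6]` under "keep_end" and `[1, 2, 3, 4]` under "keep_start". A short sequence such as `[7, 8]` is padded to `[7, 8, 0, 0]`.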
- create_progress_bar(total: int, desc: str = '', disabled: bool = False) → BaseProgressBar
Create a progress bar of the configured type with total steps and the description desc. The bar is suppressed when disabled is True.
- property evaluation_batch_size
- get_runstage_flops(is_training) → Union[float, Tuple[float, bool]]
Return the total number of FLOPs for the model in the run stage selected by is_training (training or evaluation).
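The entry above does not say how the FLOP count is obtained. For intuition only, the sketch below uses the widely cited dense-transformer rule of thumb: about 2 FLOPs per parameter per token for a forward pass, and about 6 for forward plus backward. This approximation is swapped in for illustration and is not necessarily what EasyDeL computes. It also does not model the bool in the tuple return type. `approx_transformer_flops` is a hypothetical name.

```python
def approx_transformer_flops(num_params: int, num_tokens: int, is_training: bool) -> float:
    """Rule-of-thumb FLOP estimate for a dense transformer (hypothetical helper).

    ~6 FLOPs per parameter per token when training (forward + backward),
    ~2 when only running the forward pass; attention terms are ignored.
    """
    flops_per_param_per_token = 6 if is_training else 2
    return float(flops_per_param_per_token * num_params * num_tokens)
```

For example, a 1M-parameter model processing 1,000 tokens costs roughly 6e9 FLOPs per training step and 2e9 per evaluation step under this approximation.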
- initialize_trainer_utils()
Initializes various utilities used by the trainer.
This includes setting up Weights & Biases, initializing the training timer, configuring dataloaders, configuring the model and optimizer, sharding the model and reference model states, and configuring the training and evaluation functions.
- property is_process_zero
- log_metrics(metrics: Any, pbar: BaseProgressBar, step: int, mode: str = 'train')
Log metrics and update progress bar.
- log_weight_distribution(state: EasyDeLState, step: int)
Log the distribution of the model weights in state at the given step.
- property mesh
- property model
- on_step_end(state: EasyDeLState, metrics: Any, step: int) → Tuple[EasyDeLState, Any]
Hook called at the end of each training step. Returns the (possibly updated) state and metrics.
- on_step_start(state: EasyDeLState, step: int) → EasyDeLState
Hook called at the start of each training step. Returns the (possibly updated) state.
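The two hooks above bracket each training step. Each returns what it was given, possibly modified: the state for on_step_start, and the state and metrics for on_step_end. The self-contained sketch below shows that calling order in a toy loop. `HookedLoop`, `ToyState`, `train_step` and the metrics dict are hypothetical stand-ins, not EasyDeL classes. In practice one would presumably override `on_step_start` / `on_step_end` on a `BaseTrainer` subclass instead.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class ToyState:
    """Stand-in for EasyDeLState: only tracks a step counter."""
    step: int = 0


@dataclass
class HookedLoop:
    """Hypothetical loop showing where per-step hooks fire (not EasyDeL code)."""
    calls: List[str] = field(default_factory=list)

    def on_step_start(self, state: ToyState, step: int) -> ToyState:
        self.calls.append(f"start:{step}")
        return state

    def on_step_end(
        self, state: ToyState, metrics: Dict[str, Any], step: int
    ) -> Tuple[ToyState, Dict[str, Any]]:
        self.calls.append(f"end:{step}")
        # Example modification: tag the metrics with the step before logging.
        return state, {**metrics, "step": step}

    def train_step(self, state: ToyState) -> Tuple[ToyState, Dict[str, Any]]:
        # Placeholder for the compiled training step function.
        return ToyState(step=state.step + 1), {"loss": 1.0 / (state.step + 1)}

    def run(self, num_steps: int) -> Tuple[ToyState, List[Dict[str, Any]]]:
        state, history = ToyState(), []
        for step in range(num_steps):
            state = self.on_step_start(state, step)
            state, metrics = self.train_step(state)
            state, metrics = self.on_step_end(state, metrics, step)
            history.append(metrics)
        return state, history
```

Running `HookedLoop().run(2)` records the calls `start:0, end:0, start:1, end:1`. Because both hooks return their inputs, a subclass can change the state or metrics without altering the loop itself.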
- save_information(output_path: Union[str, Path]) → None
Save the generated information to a markdown file.
- Parameters
output_path – Path where the markdown file should be saved.
- save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)
Saves the model state as a checkpoint file or, when to_torch is True, to a PyTorch-compatible directory.
- property training_batch_size