easydel.trainers.trainer.trainer
- class easydel.trainers.trainer.trainer.Trainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)
Bases: BaseTrainer
- configure_functions() TrainerConfigureFunctionOutput
Configures and JIT-compiles the training and evaluation step functions.
This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.
- Returns
An object containing:
sharded_training_step_function: The compiled training step function.
sharded_evaluation_step_function: The compiled evaluation step function.
mesh: The device mesh used for computation.
checkpoint_manager: The checkpointer for saving/loading model state.
- Return type
TrainerConfigureFunctionOutput
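The JIT-compilation and sharding setup described for configure_functions could look roughly like the following in plain JAX. This is a minimal sketch, not EasyDeL's implementation: the toy linear model, the `dp` mesh axis, the fixed learning rate, and the names `training_step`, `replicated`, `batch_sharded`, and `sharded_training_step` are assumptions made for illustration. Only `jax.jit`, `Mesh`, `NamedSharding`, and `PartitionSpec` are real JAX APIs.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional device mesh over every available device (assumed axis name "dp").
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
replicated = NamedSharding(mesh, PartitionSpec())        # full copy on each device
batch_sharded = NamedSharding(mesh, PartitionSpec("dp"))  # split the leading batch axis


def training_step(params, batch, learning_rate):
    """Toy SGD step on a linear model; stands in for a real training step."""

    def loss_fn(p):
        predictions = batch["x"] @ p["w"]
        return jnp.mean((predictions - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


# Bind the hyperparameter first, then compile with explicit sharding constraints:
# parameters are replicated, every array in the batch is split along "dp".
sharded_training_step = jax.jit(
    functools.partial(training_step, learning_rate=0.1),
    in_shardings=(replicated, batch_sharded),
    out_shardings=(replicated, replicated),
)

batch_size = 2 * jax.device_count()  # keeps the batch divisible by the mesh size
batch = {"x": jnp.ones((batch_size, 4)), "y": jnp.full((batch_size,), 2.0)}
params = {"w": jnp.zeros((4,))}

params, first_loss = sharded_training_step(params, batch)
params, second_loss = sharded_training_step(params, batch)
```

Compiling once and reusing the compiled function for every step is what makes this setup worthwhile: the first call pays the tracing and compilation cost, and later calls with the same shapes reuse the compiled program.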
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
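The truncation behaviour described for create_collect_function can be sketched as a small closure. This is an illustrative stand-in, not the library's code: the name `make_collate_fn` is hypothetical, every key is truncated along its last axis, and all examples in a batch are assumed to already share the same length (for instance because they were padded earlier).

```python
import typing as tp

import jax.numpy as jnp


def make_collate_fn(
    max_sequence_length: int,
    truncation_mode: tp.Literal["keep_end", "keep_start"] = "keep_end",
) -> tp.Callable[[list], dict]:
    """Return a collate function that stacks and truncates a list of examples."""

    def collate(batch: list) -> dict:
        collated = {}
        for key in batch[0]:
            # Stack per-example sequences into one (batch, sequence) array.
            stacked = jnp.stack([jnp.asarray(example[key]) for example in batch])
            if truncation_mode == "keep_end":
                # Keep the last max_sequence_length tokens.
                stacked = stacked[..., -max_sequence_length:]
            else:
                # Keep the first max_sequence_length tokens.
                stacked = stacked[..., :max_sequence_length]
            collated[key] = stacked
        return collated

    return collate


examples = [
    {"input_ids": [1, 2, 3, 4, 5, 6]},
    {"input_ids": [7, 8, 9, 10, 11, 12]},
]
kept_end = make_collate_fn(4, "keep_end")(examples)
kept_start = make_collate_fn(4, "keep_start")(examples)
```

With "keep_end" the most recent tokens survive, which tends to suit causal language modelling where the end of the context matters most; "keep_start" preserves the beginning of each sequence instead.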
- eval(model_state: EasyDeLState) Iterator[dict]
Evaluates the model using the provided model state.
This method iterates over the evaluation dataset, performs forward passes, calculates evaluation metrics, logs the metrics, and yields the metrics for each evaluation step.
- Parameters
model_state (EasyDeLState) – The state of the model (including parameters and configuration) to be used for evaluation.
- Yields
Iterator[dict] – An iterator yielding a dictionary of evaluation metrics for each evaluation step.
- Raises
AssertionError – If the evaluation dataloader is not set.
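Because eval yields metrics step by step rather than returning them all at once, the caller has to iterate over the result for evaluation to run. The sketch below shows one way to consume such an iterator and average a metric. The generator `fake_eval` is a hypothetical stand-in for the real method, and the metric keys `eval/step` and `eval/loss` are assumptions, not names confirmed by the library.

```python
from typing import Iterator


def fake_eval(losses: list) -> Iterator[dict]:
    """Hypothetical stand-in for Trainer.eval: yields one metrics dict per step."""
    for step, loss in enumerate(losses, start=1):
        yield {"eval/step": step, "eval/loss": loss}


def mean_eval_loss(metrics_iterator: Iterator[dict]) -> float:
    """Consume the iterator and return the average of the assumed 'eval/loss' key."""
    total, count = 0.0, 0
    for metrics in metrics_iterator:
        total += metrics["eval/loss"]
        count += 1
    if count == 0:
        raise ValueError("the evaluation iterator produced no metrics")
    return total / count


average_loss = mean_eval_loss(fake_eval([2.0, 1.5, 1.0]))
```

Streaming metrics this way lets a caller log, plot, or stop early without waiting for the whole evaluation set to finish.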
- train() TrainerOutput
Executes the complete training process.
This method sets up initial metrics and logging, runs the training loop, and finalizes training. It calls the training hook at the beginning and returns a TrainerOutput object at the end.
- Returns
An object containing the final training state, metrics, and any additional outputs.
- Return type
TrainerOutput