easydel.trainers.trainer.__init__

class easydel.trainers.trainer.__init__.Trainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)

Bases: BaseTrainer

configure_functions() → TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the step functions used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation step functions with the appropriate static arguments and sharding constraints, and creates the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput
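The shape of the object returned by configure_functions can be pictured with the pure-Python sketch below. It is not EasyDeL code: ConfiguredFunctions, configure_functions_sketch, train_step, and eval_step are hypothetical names, functools.partial only stands in for binding the fixed ("static") arguments that the real method hands to the JIT compiler, and plain strings stand in for the device mesh and checkpoint manager. No compilation or sharding happens here.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable


@dataclass
class ConfiguredFunctions:
    # Hypothetical stand-in for TrainerConfigureFunctionOutput; field names
    # mirror the attributes documented above.
    sharded_training_step_function: Callable[..., Any]
    sharded_evaluation_step_function: Callable[..., Any]
    mesh: Any
    checkpoint_manager: Any


def train_step(state: float, batch: list[float], learning_rate: float) -> float:
    # Toy "training step": move the state toward the batch mean.
    mean = sum(batch) / len(batch)
    return state - learning_rate * (state - mean)


def eval_step(state: float, batch: list[float]) -> float:
    # Toy "evaluation step": mean squared error between state and batch.
    return sum((x - state) ** 2 for x in batch) / len(batch)


def configure_functions_sketch(learning_rate: float) -> ConfiguredFunctions:
    # Bind configuration that stays fixed for the whole run, loosely analogous
    # to static arguments of a JIT-compiled function (assumption, no jax used).
    return ConfiguredFunctions(
        sharded_training_step_function=partial(train_step, learning_rate=learning_rate),
        sharded_evaluation_step_function=eval_step,
        mesh="placeholder-mesh",
        checkpoint_manager="placeholder-checkpointer",
    )


functions = configure_functions_sketch(learning_rate=0.5)
state = functions.sharded_training_step_function(0.0, [2.0, 4.0])
print(state)  # 1.5
print(functions.sharded_evaluation_step_function(state, [2.0, 4.0]))  # 3.25
```

Bundling the compiled functions with the mesh and checkpointer in one object lets the training loop unpack everything it needs from a single call.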

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable

Creates a collate/collect function to process batches of data for training or evaluation.

The returned callable takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also truncates each sequence to at most max_sequence_length tokens, keeping either the end or the start of the sequence as selected by truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable
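The truncation behaviour described above can be sketched in plain Python. This is a hypothetical analogue, not EasyDeL's implementation: create_collect_function_sketch is an illustrative name, the batch keys input_ids and attention_mask are assumed for the example, the sketch truncates unconditionally instead of only for “ForCausalLMLoss” models, and it returns nested lists rather than JAX arrays, so sequences of different lengths are left unpadded.

```python
from typing import Callable, Literal


def create_collect_function_sketch(
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
) -> Callable[[list[dict]], dict]:
    """Hypothetical pure-Python analogue of Trainer.create_collect_function."""
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def truncate(sequence: list) -> list:
        if len(sequence) <= max_sequence_length:
            return sequence
        if truncation_mode == "keep_end":
            # Drop tokens from the start, keeping the last max_sequence_length.
            return sequence[len(sequence) - max_sequence_length:]
        # Drop tokens from the end, keeping the first max_sequence_length.
        return sequence[:max_sequence_length]

    def collect(batch: list[dict]) -> dict:
        # Gather each key across the examples, truncating every sequence.
        # The real method produces JAX arrays instead of Python lists.
        return {key: [truncate(example[key]) for example in batch] for key in batch[0]}

    return collect


batch = [
    {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1]},
    {"input_ids": [6, 7], "attention_mask": [1, 1]},
]
print(create_collect_function_sketch(3)(batch)["input_ids"])  # [[3, 4, 5], [6, 7]]
print(create_collect_function_sketch(3, "keep_start")(batch)["input_ids"])  # [[1, 2, 3], [6, 7]]
```

Keeping the end is the usual default for causal language modelling, since the most recent tokens are the ones the model is asked to continue.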

eval(model_state: EasyDeLState) → Iterator[dict]

Evaluates the model using the provided model state.

This method iterates over the evaluation dataset, runs a forward pass on each batch, computes and logs the evaluation metrics, and yields the metrics for each evaluation step.

Parameters

model_state (EasyDeLState) – The state of the model (including parameters and configuration) to be used for evaluation.

Yields

Iterator[dict] – An iterator yielding a dictionary of evaluation metrics for each evaluation step.

Raises

AssertionError – If the evaluation dataloader is not set.
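The generator pattern behind eval can be illustrated with the minimal sketch below. EvalLoopSketch, mse, and the metric keys are hypothetical; a float stands in for EasyDeLState, a list of batches stands in for the evaluation dataloader, and appending to a list stands in for real logging.

```python
from typing import Callable, Iterable, Iterator, Optional


class EvalLoopSketch:
    """Hypothetical minimal trainer showing the shape of an eval() generator."""

    def __init__(
        self,
        dataloader_eval: Optional[Iterable[list[float]]],
        eval_step: Callable[[float, list[float]], float],
    ) -> None:
        self.dataloader_eval = dataloader_eval
        self.eval_step = eval_step
        self.logged: list[dict] = []

    def eval(self, model_state: float) -> Iterator[dict]:
        assert self.dataloader_eval is not None, "evaluation dataloader is not set"
        total = 0.0
        for step, batch in enumerate(self.dataloader_eval, start=1):
            loss = self.eval_step(model_state, batch)  # forward pass + metric
            total += loss
            metrics = {"step": step, "eval/loss": loss, "eval/mean_loss": total / step}
            self.logged.append(metrics)  # stand-in for a real logger
            yield metrics


def mse(state: float, batch: list[float]) -> float:
    # Toy metric: mean squared error of the batch around the model "state".
    return sum((x - state) ** 2 for x in batch) / len(batch)


trainer = EvalLoopSketch([[1.0, 3.0], [2.0, 2.0]], mse)
results = list(trainer.eval(model_state=2.0))
print(results[-1])  # {'step': 2, 'eval/loss': 0.0, 'eval/mean_loss': 0.5}
```

Because this sketch's eval is a generator function, the AssertionError surfaces when iteration starts (the first next() call), not when eval() itself is called.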

train() → TrainerOutput

Executes the complete training process.

This method sets up initial metrics and logging, runs the training loop, and finalizes training. It calls the training hook at the beginning and returns a TrainerOutput object at the end.

Returns

An object containing the final training state, metrics, and any additional outputs.

Return type

TrainerOutput
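The documented flow of train() (hook at the start, initial metrics, training loop, finalization into one output object) can be sketched as follows. TrainerOutputSketch, TrainLoopSketch, toy_step, and on_train_start are hypothetical names chosen for illustration; the real TrainerOutput fields and hook name may differ, and a float stands in for the model state.

```python
from dataclasses import dataclass, field
from typing import Callable, Iterable, Optional


@dataclass
class TrainerOutputSketch:
    # Hypothetical stand-in for TrainerOutput: final state, metrics, extras.
    state: float
    metrics: dict
    extras: dict = field(default_factory=dict)


class TrainLoopSketch:
    """Hypothetical minimal trainer illustrating the documented train() flow."""

    def __init__(
        self,
        initial_state: float,
        batches: Iterable[list[float]],
        train_step: Callable[[float, list[float]], tuple[float, float]],
        on_train_start: Optional[Callable[[], None]] = None,
    ) -> None:
        self.initial_state = initial_state
        self.batches = batches
        self.train_step = train_step
        self.on_train_start = on_train_start

    def train(self) -> TrainerOutputSketch:
        if self.on_train_start is not None:
            self.on_train_start()  # training hook, called at the beginning
        state = self.initial_state
        metrics = {"steps": 0, "last_loss": None}  # initial metrics
        for batch in self.batches:  # training loop
            state, loss = self.train_step(state, batch)
            metrics["steps"] += 1
            metrics["last_loss"] = loss
        # Finalize: package the results into a single output object.
        return TrainerOutputSketch(state=state, metrics=metrics)


def toy_step(state: float, batch: list[float]) -> tuple[float, float]:
    # Toy update: loss is the squared gap to the batch mean; halve the gap.
    target = sum(batch) / len(batch)
    loss = (target - state) ** 2
    return state + 0.5 * (target - state), loss


hook_calls: list[str] = []
trainer = TrainLoopSketch(
    0.0, [[2.0], [2.0]], toy_step, on_train_start=lambda: hook_calls.append("start")
)
output = trainer.train()
print(output.state, output.metrics)  # 1.5 {'steps': 2, 'last_loss': 1.0}
```

Returning one object instead of a tuple keeps the call site stable if further outputs are added later.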