easydel.trainers.trainer_protocol#
- class easydel.trainers.trainer_protocol.BaseProgressBar[source]#
Bases:
ABC
Abstract base class for progress bar implementations.
- class easydel.trainers.trainer_protocol.BaseTrainerProtocol(arguments: TrainingArguments, model: EasyDeLBaseModule, dataset_train: Optional[Any] = None, dataset_eval: Optional[Any] = None, finetune: bool = True, checkpoint_path: Optional[Union[str, PathLike]] = None)[source]#
Bases:
object
- abstract apply_training_hooks(metrics: LossMetrics) -> LossMetrics[source]#
Apply training hooks to the model.
- arguments: TrainingArguments#
- checkpoint_manager: Any#
- checkpoint_path: Optional[Union[str, PathLike]]#
- config: EasyDeLBaseConfig#
- abstract configure_dataloaders() -> TrainerConfigureDataloaderOutput[source]#
Configures the dataloaders for training and evaluation.
- abstract configure_functions() -> TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- abstract configure_model() -> TrainerConfigureModelOutput[source]#
Configures the model, optimizer, scheduler, and configuration.
- abstract static count_model_parameters(prm)[source]#
Prints the number of model parameters in billions.
- abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) -> Callable[source]#
Creates a function to collect and process batches of data for training or evaluation.
- abstract create_progress_bar(desc: str = '', disabled: bool = False) -> BaseProgressBar[source]#
Create a progress bar of the specified type.
- data_collator: Optional[Callable]#
- dataset_eval: Optional[Any]#
- dataset_train: Optional[Any]#
- dtype: Any#
- abstract eval(model_state: EasyDeLState) -> Iterator[dict][source]#
Evaluates using the provided model state.
- evalu_tracker: CompilationTracker#
- abstract property evaluation_batch_size#
- finetune: bool#
- abstract get_runstage_flops(is_training: bool) -> float[source]#
Return the total number of FLOPs for the model.
- abstract log_metrics(metrics: Dict[str, float], pbar: BaseProgressBar, step: int, mode: str = 'train') -> None[source]#
Log metrics and update progress bar.
- abstract log_weight_distribution(state: EasyDeLState, step: int)[source]#
Log distribution of weights.
- max_evaluation_steps: int#
- max_training_steps: int#
- memory_monitor: Any#
- abstract property mesh#
- abstract property model#
- model_state: EasyDeLState#
- abstract on_step_end(state: EasyDeLState, metrics: Any, step: int) -> Tuple[EasyDeLState, Any][source]#
Hook process to call at the end of the step.
- abstract on_step_start(state: EasyDeLState, step: int) -> EasyDeLState[source]#
Hook process to call at the start of the step.
- param_dtype: Any#
- pruning_module: Any#
- abstract save_information(output_path: Union[str, Path]) -> None[source]#
Save the generated information to a markdown file.
- abstract save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, base_hf_auto_class=None, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)[source]#
Saves the model state as a checkpoint file or to a Torch compatible directory.
- scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]#
- sharded_evaluation_step_function: Callable#
- sharded_training_step_function: Callable#
- state: Any#
- state_named_sharding: Any#
- state_partition_spec: Any#
- state_shape: Any#
- abstract train(model_parameters: Optional[FrozenDict] = None, state: Optional[EasyDeLState] = None) -> Any[source]#
Train using the provided model state.
- train_tracker: CompilationTracker#
- abstract property training_batch_size#
- tx: GradientTransformation#
- wandb_runtime: Any#
- class easydel.trainers.trainer_protocol.JSONProgressBar(desc='')[source]#
Bases:
BaseProgressBar
Wrapper for JSON
- class easydel.trainers.trainer_protocol.MetricsColumn(metrics_to_show=None)[source]#
Bases:
ProgressColumn
A custom progress column for displaying metrics.
- class easydel.trainers.trainer_protocol.MetricsTracker[source]#
Bases:
object
Tracks and aggregates training metrics over time.
- class easydel.trainers.trainer_protocol.NullProgressBar[source]#
Bases:
BaseProgressBar
Dummy progress bar that does nothing - useful for multiprocessing.
- class easydel.trainers.trainer_protocol.RichProgressBar(progress: Progress, task_id: TaskID)[source]#
Bases:
BaseProgressBar
Wrapper for rich progress bar.
- class easydel.trainers.trainer_protocol.StepMetrics(arguments)[source]#
Bases:
object
Handles calculation and tracking of training metrics.
- calculate(metrics: LossMetrics, current_step: int, epoch: int, flops: float, batch_size: int, seq_length: int, learning_rate: float, mode: Optional[Literal['eval', 'train']] = None, **extras) -> Dict[str, float][source]#
Calculate comprehensive metrics for the training step.
- class easydel.trainers.trainer_protocol.TqdmProgressBar(pbar: tqdm)[source]#
Bases:
BaseProgressBar
Wrapper for tqdm progress bar.
- class easydel.trainers.trainer_protocol.TrainerConfigureDataloaderOutput(dataloader_train: 'tp.Iterator[np.ndarray]', max_training_steps: 'int', dataloader_eval: 'tp.Optional[tp.Iterator[np.ndarray]]' = None, max_evaluation_steps: 'tp.Optional[int]' = None)[source]#
Bases:
object
- max_evaluation_steps: Optional[int] = None#
- max_training_steps: int#
- replace(**kwargs)#
- class easydel.trainers.trainer_protocol.TrainerConfigureFunctionOutput(sharded_training_step_function: 'JitWrapped', mesh: 'Mesh', checkpoint_manager: 'CheckpointManager', sharded_evaluation_step_function: 'tp.Optional[JitWrapped]' = None)[source]#
Bases:
object
- checkpoint_manager: CheckpointManager#
- replace(**kwargs)#
- sharded_evaluation_step_function: Optional[Callable] = None#
- sharded_training_step_function: Callable#
- class easydel.trainers.trainer_protocol.TrainerConfigureModelOutput(model: 'EasyDeLBaseModule', tx: 'GradientTransformation', scheduler: 'Schedule', config: 'tp.Optional[EasyDeLBaseConfig]' = None)[source]#
Bases:
object
- config: Optional[EasyDeLBaseConfig] = None#
- model: EasyDeLBaseModule#
- replace(**kwargs)#
- scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]#
- tx: GradientTransformation#
- class easydel.trainers.trainer_protocol.TrainerOutput(state: 'EasyDeLState', mesh: 'tp.Optional[jax.sharding.Mesh]', last_save_file_name: 'tp.Optional[str]' = None, checkpoint_path: 'tp.Optional[str]' = None)[source]#
Bases:
object
- checkpoint_path: Optional[str] = None#
- last_save_file_name: Optional[str] = None#
- replace(**kwargs)#
- state: EasyDeLState#