easydel.trainers.trainer_protocol

Contents

easydel.trainers.trainer_protocol#

class easydel.trainers.trainer_protocol.BaseProgressBar[source]#

Bases: ABC

Abstract base class for progress bar implementations.

abstract close() None[source]#
abstract reset() None[source]#
abstract set_postfix(**kwargs) None[source]#
abstract update(n: int = 1) None[source]#
class easydel.trainers.trainer_protocol.BaseTrainerProtocol(arguments: TrainingArguments, model: EasyDeLBaseModule, dataset_train: Optional[Any] = None, dataset_eval: Optional[Any] = None, finetune: bool = True, checkpoint_path: Optional[Union[str, PathLike]] = None)[source]#

Bases: object

abstract apply_training_hooks(metrics: LossMetrics) LossMetrics[source]#

Apply training hooks to the model.

arguments: TrainingArguments#
checkpoint_manager: Any#
checkpoint_path: Optional[Union[str, PathLike]]#
abstract compile_aot() bool[source]#

Compile the state ahead of time for faster execution.

config: EasyDeLBaseConfig#
abstract configure_dataloaders() TrainerConfigureDataloaderOutput[source]#

Configures the dataloaders for training and evaluation.

abstract configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

abstract configure_model() TrainerConfigureModelOutput[source]#

Configures the model, optimizer, scheduler, and configuration.

abstract static count_model_parameters(prm)[source]#

Prints the number of model parameters in billions.

abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#

Creates a function to collect and process batches of data for training or evaluation.

abstract create_progress_bar(desc: str = '', disabled: bool = False) BaseProgressBar[source]#

Create a progress bar of the specified type.

data_collator: Optional[Callable]#
dataloader_eval: Optional[Iterator[ndarray]]#
dataloader_train: Iterator[ndarray]#
dataset_eval: Optional[Any]#
dataset_train: Optional[Any]#
dtype: Any#
abstract eval(model_state: EasyDeLState) Iterator[dict][source]#

Evaluates using the provided model state.

evalu_tracker: CompilationTracker#
abstract property evaluation_batch_size#
finetune: bool#
abstract finish()[source]#

Finalize the training process.

abstract get_runstage_flops(is_training: bool) float[source]#

Return the total number of FLOPs for the model.

abstract initialize_trainer_utils()[source]#

Initializes all trainer utilities.

abstract log_metrics(metrics: Dict[str, float], pbar: BaseProgressBar, step: int, mode: str = 'train') None[source]#

Log metrics and update progress bar.

abstract log_weight_distribution(state: EasyDeLState, step: int)[source]#

Log distribution of weights.

max_evaluation_steps: int#
max_training_steps: int#
memory_monitor: Any#
abstract property mesh#
abstract property model#
model_state: EasyDeLState#
abstract on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

Hook process to call at the end of the step.

abstract on_step_start(state: EasyDeLState, step: int) EasyDeLState[source]#

Hook process to call at the start of the step.

param_dtype: Any#
pruning_module: Any#
abstract save_information(output_path: Union[str, Path]) None[source]#

Save the generated information to a markdown file.

abstract save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, base_hf_auto_class=None, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)[source]#

Saves the model state as a checkpoint file or to a Torch compatible directory.

scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]#
sharded_evaluation_step_function: Callable#
sharded_training_step_function: Callable#
abstract specs_to_name_sharding(tree, mesh=None)[source]#

Convert specs to named sharding.

abstract start_evaluation_hook()[source]#

Hook to run before evaluation starts.

abstract start_training_hook()[source]#

Hook to run before training starts.

state: Any#
state_named_sharding: Any#
state_partition_spec: Any#
state_shape: Any#
timer: Timers#
abstract train(model_parameters: Optional[FrozenDict] = None, state: Optional[EasyDeLState] = None) Any[source]#

Train using the provided model state.

train_tracker: CompilationTracker#
abstract property training_batch_size#
tx: GradientTransformation#
wandb_runtime: Any#
class easydel.trainers.trainer_protocol.JSONProgressBar(desc='')[source]#

Bases: BaseProgressBar

Wrapper for JSON progress bar.

close() None[source]#
reset() None[source]#
set_postfix(**kwargs) None[source]#
update(n: int = 1) None[source]#
class easydel.trainers.trainer_protocol.MetricsColumn(metrics_to_show=None)[source]#

Bases: ProgressColumn

A custom progress column for displaying metrics.

render(task: Task) Text[source]#

Render the metrics in an organized way.

class easydel.trainers.trainer_protocol.MetricsTracker[source]#

Bases: object

Tracks and aggregates training metrics over time.

reset(step)[source]#

Reset tracked metrics.

update(loss, accuracy, step)[source]#

Update tracked metrics with new values.

class easydel.trainers.trainer_protocol.NullProgressBar[source]#

Bases: BaseProgressBar

Dummy progress bar that does nothing - useful for multiprocessing.

close() None[source]#
reset() None[source]#
set_postfix(**kwargs) None[source]#
update(n: int = 1) None[source]#
class easydel.trainers.trainer_protocol.RichProgressBar(progress: Progress, task_id: TaskID)[source]#

Bases: BaseProgressBar

Wrapper for rich progress bar.

close() None[source]#
reset() None[source]#
set_postfix(**kwargs) None[source]#
update(n: int = 1) None[source]#
class easydel.trainers.trainer_protocol.StepMetrics(arguments)[source]#

Bases: object

Handles calculation and tracking of training metrics.

calculate(metrics: LossMetrics, current_step: int, epoch: int, flops: float, batch_size: int, seq_length: int, learning_rate: float, mode: Optional[Literal['eval', 'train']] = None, **extras) Dict[str, float][source]#

Calculate comprehensive metrics for the training step.

start_step()[source]#

Mark the start of a training step.

class easydel.trainers.trainer_protocol.TqdmProgressBar(pbar: tqdm)[source]#

Bases: BaseProgressBar

Wrapper for tqdm progress bar.

close() None[source]#
reset() None[source]#
set_postfix(**kwargs) None[source]#
update(n: int = 1) None[source]#
class easydel.trainers.trainer_protocol.TrainerConfigureDataloaderOutput(dataloader_train: 'tp.Iterator[np.ndarray]', max_training_steps: 'int', dataloader_eval: 'tp.Optional[tp.Iterator[np.ndarray]]' = None, max_evaluation_steps: 'tp.Optional[int]' = None)[source]#

Bases: object

dataloader_eval: Optional[Iterator[ndarray]] = None#
dataloader_train: Iterator[ndarray]#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

max_evaluation_steps: Optional[int] = None#
max_training_steps: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.trainers.trainer_protocol.TrainerConfigureFunctionOutput(sharded_training_step_function: 'JitWrapped', mesh: 'Mesh', checkpoint_manager: 'CheckpointManager', sharded_evaluation_step_function: 'tp.Optional[JitWrapped]' = None)[source]#

Bases: object

checkpoint_manager: CheckpointManager#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

mesh: Mesh#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

sharded_evaluation_step_function: Optional[Callable] = None#
sharded_training_step_function: Callable#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.trainers.trainer_protocol.TrainerConfigureModelOutput(model: 'EasyDeLBaseModule', tx: 'GradientTransformation', scheduler: 'Schedule', config: 'tp.Optional[EasyDeLBaseConfig]' = None)[source]#

Bases: object

config: Optional[EasyDeLBaseConfig] = None#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

model: EasyDeLBaseModule#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

tx: GradientTransformation#
class easydel.trainers.trainer_protocol.TrainerOutput(state: 'EasyDeLState', mesh: 'tp.Optional[jax.sharding.Mesh]', last_save_file_name: 'tp.Optional[str]' = None, checkpoint_path: 'tp.Optional[str]' = None)[source]#

Bases: object

checkpoint_path: Optional[str] = None#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

last_save_file_name: Optional[str] = None#
mesh: Optional[Mesh]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

state: EasyDeLState#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.