easydel.trainers.base_trainer#
- class easydel.trainers.base_trainer.BaseTrainer(arguments: TrainingArguments | None = None, model_state: EasyDeLState | None = None, model: tp.type[EasyDeLBaseModule] | None = None, dataset_train: Dataset | IterableDataset | ShardedDataSource | None = None, dataset_eval: Dataset | IterableDataset | ShardedDataSource | None = None, data_collator: tp.Callable | None = None, finetune: bool = True, processing_class: PreTrainedTokenizerBase | None = None, **deprecated_kwargs)[source]#
Bases: BaseTrainerProtocol
Base trainer class implementing core training functionality for EasyDeL models.
This class provides the foundation for training and evaluation workflows, including:
- Checkpoint management and resumption
- Dataloader configuration (Grain and TensorFlow datasets)
- Model state initialization and sharding
- Training and evaluation step compilation
- Metrics logging and monitoring
- Memory tracking and performance optimization
The trainer handles distributed training across multiple devices using JAX’s sharding capabilities and supports various optimization strategies.
- arguments#
Training configuration arguments
- model_state#
Current state of the model including parameters and optimizer state
- dataset_train#
Training dataset
- Type
Any | None
- dataset_eval#
Evaluation dataset
- Type
Any | None
- data_collator#
Function to collate batch data
- Type
Optional[Callable]
- finetune#
Whether this is a fine-tuning run
- Type
bool
- mesh#
Device mesh for distributed computation
- Type
Any
- checkpoint_manager#
Manager for saving/loading checkpoints
- Type
Any
- _train_source#
Internal ShardedDataSource for training data
- Type
easydel.data.core.protocols.ShardedDataSource | None
- _eval_source#
Internal ShardedDataSource for evaluation data
- Type
easydel.data.core.protocols.ShardedDataSource | None
- apply_training_hooks(metrics: LossMetrics) LossMetrics[source]#
Apply training hooks to check for issues and enforce limits.
- Parameters
metrics – Current training metrics including loss.
- Returns
Potentially modified metrics.
- Return type
LossMetrics
- Raises
EasyDeLBreakRequest – If NaN loss detected and break_on_nan is True.
EasyDeLTimerError – If training time limit exceeded.
Note
Checks for NaN losses and training time limits based on configuration in training arguments.
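The two checks described above can be sketched in plain Python. This is only an illustration under stated assumptions, not EasyDeL's implementation: `BreakRequest`, `TimerError`, and `check_metrics` are hypothetical stand-ins for `EasyDeLBreakRequest`, `EasyDeLTimerError`, and the hook itself, and the loss is a plain float rather than a `LossMetrics` object.

```python
import math
import time


class BreakRequest(Exception):
    """Hypothetical stand-in for EasyDeLBreakRequest."""


class TimerError(Exception):
    """Hypothetical stand-in for EasyDeLTimerError."""


def check_metrics(loss, *, start_time, break_on_nan=True, time_limit_s=None, now=None):
    """Raise on a NaN loss or an exceeded time budget; otherwise return the loss."""
    # `now` is injectable so the check can be exercised deterministically.
    now = time.monotonic() if now is None else now
    if break_on_nan and math.isnan(loss):
        raise BreakRequest("NaN loss detected")
    if time_limit_s is not None and now - start_time > time_limit_s:
        raise TimerError("training time limit exceeded")
    return loss
```

In this sketch the metrics pass through unchanged when neither condition fires, which mirrors the "potentially modified metrics" return described above.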
- calculate_number_total_flops(params, is_training=True)[source]#
Calculate total FLOPs for the model.
- Parameters
params – Model parameters.
is_training – Whether calculating for training (includes backward pass).
- Returns
Total FLOPs count.
- Return type
int
- compile_aot() bool[source]#
Compile training and evaluation functions ahead-of-time.
This method performs AOT (Ahead-Of-Time) compilation of the training and evaluation step functions using JAX’s JIT compilation. This improves performance by compiling the functions once before the training loop.
- Returns
True if any functions were compiled, False otherwise
- Return type
bool
Note
Compilation happens automatically on first call if not done AOT
AOT compilation can reduce first-step latency
Uses actual data batches to determine compilation shapes
- configure_dataloaders() TrainerConfigureDataloaderOutput[source]#
Configures the dataloaders for training and evaluation.
This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.
- Returns
An object containing the configured dataloaders and the maximum number of training and evaluation steps.
- Return type
TrainerConfigureDataloaderOutput
- abstract configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method sets up the necessary functions for training and evaluation, including:
- Initialization of the model state.
- Sharding of the model parameters and optimizer state.
- JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- configure_model() TrainerConfigureModelOutput[source]#
Configures the model, optimizer, scheduler, and configuration.
This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.
- Returns
An object containing the configured model, optimizer, scheduler, and configuration.
- Return type
TrainerConfigureModelOutput
- static count_model_parameters(prm)[source]#
Count total number of model parameters.
- Parameters
prm – Model parameters (can be frozen or unfrozen).
- Returns
Total number of parameters.
- Return type
int
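As a rough illustration of what counting parameters means, the sketch below walks a nested dictionary and sums the element count of each leaf. It assumes a plain nested-dict parameter tree; `FakeArray` and `count_parameters` are hypothetical names used only here, and the real method may traverse the tree differently.

```python
import math
from dataclasses import dataclass


@dataclass
class FakeArray:
    """Hypothetical stand-in for an array leaf; only its shape matters here."""

    shape: tuple


def count_parameters(tree):
    """Sum the element counts of every leaf in a nested dict of arrays."""
    if isinstance(tree, dict):
        return sum(count_parameters(child) for child in tree.values())
    # A leaf contributes the product of its dimensions.
    return math.prod(tree.shape)


params = {
    "embed": {"weight": FakeArray((100, 8))},
    "dense": {"kernel": FakeArray((8, 4)), "bias": FakeArray((4,))},
}
total = count_parameters(params)  # 100*8 + 8*4 + 4
```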
- abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#
Creates a function to collect and process batches of data for training or evaluation.
This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.
- Returns
A function that takes a batch of data and returns a processed batch.
- Return type
tp.Callable
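The padding and truncation behaviour can be sketched as below. This is an assumption-laden illustration, not the library's collator: `make_collate_fn` is a hypothetical name, batches are plain lists of token IDs, padding is applied on the right with `pad_id=0`, and "keep_end" is read as keeping the last `max_sequence_length` tokens while "keep_start" keeps the first.

```python
def make_collate_fn(max_sequence_length, truncation_mode="keep_end", pad_id=0):
    """Return a function that pads or truncates each sequence to a fixed length."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def collate(batch):
        processed = []
        for ids in batch:
            ids = list(ids)
            if len(ids) > max_sequence_length:
                if truncation_mode == "keep_end":
                    ids = ids[-max_sequence_length:]  # drop the oldest tokens
                else:
                    ids = ids[:max_sequence_length]  # drop the newest tokens
            # Right-pad short sequences up to the target length.
            ids = ids + [pad_id] * (max_sequence_length - len(ids))
            processed.append(ids)
        return processed

    return collate


collate_end = make_collate_fn(4, "keep_end")
collate_start = make_collate_fn(4, "keep_start")
```

A subclass implementing the abstract method would return a callable of this general shape, adapted to its own batch structure.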
- create_generate_function(generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, *, shard_inputs: bool = True, config_overrides: dict[str, Any] | None = None, **generate_kwargs) Callable[[EasyDeLState, Array, Array], tuple[jax.Array, jax.Array, jax.Array]][source]#
Build and return a compiled generation function that mirrors the model’s generate.
- Parameters
generation_config – Optional generation configuration. If omitted, the model’s default configuration is used.
shard_inputs – Whether to shard the prompt tensors using the model’s partition manager before generation.
config_overrides – Optional attribute overrides applied to the copied generation configuration.
generate_kwargs – Extra keyword arguments forwarded to module.generate.
- abstract create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#
Creates a function to collect and process batches of data for training or evaluation.
This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.
- Returns
A function that takes a batch of data and returns a processed batch.
- Return type
tp.Callable
- create_progress_bar(total: int, desc: str = '', disabled: bool = False) BaseProgressBar[source]#
Create a progress bar of the specified type.
- abstract create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#
Creates a function to collect and process batches of data for training or evaluation.
This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.
- Returns
A function that takes a batch of data and returns a processed batch.
- Return type
tp.Callable
- property evaluation_batch_size#
Get the evaluation batch size.
- Returns
The batch size used for evaluation
- static finish()[source]#
Clean up resources and finish any active logging sessions.
Notes
Currently only finishes wandb session if active. Safe to call even if wandb is not initialized.
- generate_aio(input_ids: jax.Array | numpy.ndarray, attention_mask: jax.Array | numpy.ndarray | None = None, *, state: easydel.infra.base_state.EasyDeLState | None = None, generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, shard_inputs: bool | None = None, config_overrides: dict[str, Any] | None = None, return_metadata: bool = False, all_gather: bool = False, **generate_kwargs)[source]#
Convenience wrapper around the compiled generation function.
- generate_unified(input_ids: jax.Array | numpy.ndarray | None = None, attention_mask: jax.Array | numpy.ndarray | None = None, prompts: str | list[str] | None = None, *, state: easydel.infra.base_state.EasyDeLState | None = None, use_esurge: bool | None = None, apply_chat_template: bool = False, generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, shard_inputs: bool | None = None, config_overrides: dict[str, Any] | None = None, all_gather: bool = False, **generate_kwargs) GenerationResults[source]#
Unified generation interface supporting both compiled and eSurge generation.
This method provides a single interface for generation that automatically handles:
- Conversion between text prompts and token IDs
- Selection between eSurge and compiled generation based on configuration
- Consistent output format regardless of generation method
- Optional chat template application
- Parameters
input_ids – Optional token IDs for the prompt. If None, must provide prompts.
attention_mask – Optional attention mask for input_ids.
prompts – Optional text prompt(s). If None, must provide input_ids.
state – Model state to use for generation. Defaults to self.model_state.
use_esurge – Whether to use eSurge generation. Defaults to self.arguments.use_esurge_generation.
apply_chat_template – Whether to apply chat template to prompts. Default False. If True and prompts is a string, wraps it in [{“role”: “user”, “content”: prompt}].
generation_config – Optional generation configuration.
shard_inputs – Whether to shard inputs across devices.
config_overrides – Optional overrides for generation config.
all_gather – Whether to gather results from all devices.
**generate_kwargs – Additional kwargs passed to generation.
- Returns
- NamedTuple containing:
generation_results: The result text
prompt_ids: Token IDs for the prompt
prompt_mask: Attention mask for the prompt
sequences: Complete generated sequences (including prompt)
- Return type
GenerationResults
- Raises
ValueError – If neither input_ids nor prompts are provided.
Example
>>> # Generate from token IDs (GRPO-style, no chat template)
>>> results = trainer.generate_unified(
...     input_ids=prompt_ids,
...     attention_mask=mask,
...     apply_chat_template=False  # Raw generation
... )
>>>
>>> # Generate with chat template (preview generation)
>>> results = trainer.generate_unified(
...     prompts="Explain quantum computing",
...     apply_chat_template=True  # Applies chat template
... )
>>>
>>> # eSurge with chat format
>>> results = trainer.generate_unified(
...     prompts=[{"role": "user", "content": "Hello"}],
...     use_esurge=True
... )
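The input validation and chat wrapping described for `generate_unified` can be sketched independently of the trainer. `normalize_prompts` is a hypothetical helper written for this illustration. The documentation confirms only the `ValueError` and the wrapping of a single string; how a list of prompts is handled here is an assumption.

```python
def normalize_prompts(input_ids=None, prompts=None, apply_chat_template=False):
    """Validate the inputs and optionally wrap text prompts as chat messages."""
    if input_ids is None and prompts is None:
        raise ValueError("Either input_ids or prompts must be provided.")
    if prompts is None:
        return None  # token IDs were given, nothing to normalize
    if isinstance(prompts, str):
        prompts = [prompts]
    if apply_chat_template:
        # Assumption: entries that are already chat-formatted pass through as-is.
        prompts = [
            [{"role": "user", "content": p}] if isinstance(p, str) else p
            for p in prompts
        ]
    return prompts


wrapped = normalize_prompts(prompts="Explain quantum computing", apply_chat_template=True)
```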
- initialize_trainer_utils()[source]#
Initializes various utilities used by the trainer.
This method orchestrates the initialization of all trainer components in the correct order. It sets up:
1. Weights & Biases logging (if enabled)
2. Training timer for performance monitoring
3. Dataloaders for training and evaluation
4. Model, optimizer, and learning rate scheduler
5. Model state sharding across devices
6. Compiled training and evaluation functions
The initialization order is important as later steps depend on earlier ones. For example, the optimizer configuration depends on the number of training steps determined during dataloader configuration.
- property is_enable#
Check if operations are enabled for this process.
- Returns
True if operations are enabled, False if restricted to main process only
Notes
When process_zero_is_admin is True, only the main process will have operations enabled.
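The gating rule in the note above reduces to a small predicate. The function below is a hypothetical sketch that takes the process index and the flag as explicit arguments; the real property reads them from the trainer and its arguments.

```python
def is_enabled(process_index, process_zero_is_admin):
    """Return True when this process is allowed to run the gated operations."""
    if process_zero_is_admin:
        # Only the main process (index 0) keeps operations enabled.
        return process_index == 0
    return True


# With the flag set, only process 0 is enabled in a four-process run.
enabled = [is_enabled(i, process_zero_is_admin=True) for i in range(4)]
```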
- property is_process_zero#
Check if this is the main process (rank 0).
- Returns
True if this is the main process, False otherwise
- log_metrics(metrics: dict[str, float | list | tuple | numpy.ndarray | Any], pbar: BaseProgressBar, step: int, mode: str = 'train')[source]#
Log metrics and update progress bar.
- log_weight_distribution(state: EasyDeLState, step: int)[source]#
Log weight distribution statistics.
- Parameters
state – Model state containing parameters.
step – Current training step.
Notes
Delegates to arguments.log_weight_distribution method.
- maybe_generate(state: EasyDeLState, step: int, metrics: dict[str, float | list | tuple | numpy.ndarray | Any] | None = None) None[source]#
Optionally run preview generation to monitor training progress.
Uses generate_unified for consistent generation across both eSurge and compiled modes.
- property mesh#
Get the device mesh for distributed computation.
- Returns
The device mesh used for sharding computations
- property model#
Get the model instance.
- Returns
The model instance used for training
- on_step_end(state: EasyDeLState, metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int) tuple[easydel.infra.base_state.EasyDeLState, dict[str, float | list | tuple | numpy.ndarray | Any]][source]#
Hook method called at the end of each training step.
- Parameters
state (EasyDeLState) – The current model state
metrics (MetricsType) – The metrics computed for this step
step (int) – The current training step number
- Returns
The potentially modified model state and metrics
- Return type
tuple[EasyDeLState, MetricsType]
Notes
This method can be overridden in subclasses to implement custom logic at the end of each training step, such as custom logging or state modifications.
- on_step_start(state: EasyDeLState, step: int) EasyDeLState[source]#
Hook method called at the start of each training step.
- Parameters
state (EasyDeLState) – The current model state
step (int) – The current training step number
- Returns
The potentially modified model state
- Return type
EasyDeLState
Notes
This method can be overridden in subclasses to implement custom logic at the beginning of each training step.
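The override pattern for the two step hooks might look like the sketch below. `TinyTrainer` is a hypothetical stand-in that models only where the hooks are called; it is not `BaseTrainer`, and the loop body is a placeholder for a real training step.

```python
class TinyTrainer:
    """Hypothetical stand-in for a trainer; only the hook call order is modeled."""

    def on_step_start(self, state, step):
        return state

    def on_step_end(self, state, metrics, step):
        return state, metrics

    def run(self, state, num_steps):
        history = []
        for step in range(num_steps):
            state = self.on_step_start(state, step)
            metrics = {"loss": 1.0 / (step + 1)}  # placeholder for a real train step
            state, metrics = self.on_step_end(state, metrics, step)
            history.append(metrics)
        return state, history


class TaggingTrainer(TinyTrainer):
    """Adds a custom metric at the end of every step."""

    def on_step_end(self, state, metrics, step):
        metrics = {**metrics, "custom/step_parity": step % 2}
        return state, metrics


final_state, history = TaggingTrainer().run(state={"params": None}, num_steps=3)
```

Both hooks return their inputs, possibly modified, so a subclass that only wants to add logging can return them unchanged.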
- save_information(output_path: str | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath) None[source]#
Save the generated information to a markdown file.
- Parameters
output_path – Path where the markdown file should be saved
- save_pretrained(state: EasyDeLState, save_directory: str | None = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: dict | None = None, torch_save_pretrained_kwargs: dict | None = None)[source]#
Saves the model state as a checkpoint file or to a Torch-compatible directory.
- specs_to_name_sharding(tree, mesh=None)[source]#
Convert partition specs to named sharding.
- Parameters
tree – PyTree structure with partition specs.
mesh – Device mesh for sharding (uses trainer’s mesh if None).
- Returns
PyTree with named sharding specifications.
- start_evaluation_hook()[source]#
Hook called at the start of evaluation.
Notes
Sets up static metrics and records evaluation start time.
- start_training_hook()[source]#
Hook called at the start of training.
Notes
Sets up static metrics and records training start time.
- property training_batch_size#
Calculate the effective training batch size.
- Returns
The effective batch size including gradient accumulation
Notes
The effective batch size is calculated as: total_batch_size * gradient_accumulation_steps
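A worked instance of the formula above, using illustrative values that are not defaults from the library:

```python
def effective_batch_size(total_batch_size, gradient_accumulation_steps=1):
    """Effective batch size = total_batch_size * gradient_accumulation_steps."""
    return total_batch_size * gradient_accumulation_steps


# Example values chosen for illustration: 8 samples per step, 4 accumulation steps.
effective = effective_batch_size(total_batch_size=8, gradient_accumulation_steps=4)
```

With these values each optimizer update sees 32 samples.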
- class easydel.trainers.base_trainer.GenerationResults(generation_results: str | list[str], prompt_ids: jax.Array, prompt_mask: jax.Array, sequences: jax.Array, completion_ids: jax.Array, completion_mask: jax.Array, decoded_prompts: str | list[str], completion_prompts: list[str | list[dict[str, str]]] | None = None)[source]#
Bases: NamedTuple
Results from unified generation containing both text and token representations.
- generation_results#
The generation results from engine
- Type
str | list[str]
- sequences#
Complete generated sequences including prompt (batch_size, max_seq_len + max_new_tokens)
- Type
jax.Array
- completion_ids#
Token IDs for only the generated completions (batch_size, max_new_tokens) - right-padded
- Type
jax.Array
- completion_prompts#
Optional prompt objects (text or chat dicts) aligned one-to-one with completions.
- Type
list[str | list[dict[str, str]]] | None
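Because `GenerationResults` is a `NamedTuple`, each field can be read by name or by position. The sketch below uses a hypothetical `Results` type with the same field order as the signature above, with plain lists standing in for `jax.Array` values.

```python
from typing import NamedTuple, Optional


class Results(NamedTuple):
    """Hypothetical stand-in mirroring the field order of GenerationResults."""

    generation_results: list
    prompt_ids: list
    prompt_mask: list
    sequences: list
    completion_ids: list
    completion_mask: list
    decoded_prompts: list
    completion_prompts: Optional[list] = None


r = Results(
    generation_results=["Hi there"],
    prompt_ids=[[1, 2]],
    prompt_mask=[[1, 1]],
    sequences=[[1, 2, 3, 4]],
    completion_ids=[[3, 4]],
    completion_mask=[[1, 1]],
    decoded_prompts=["Hello"],
)

# Positional and named access refer to the same field.
first_field = r[0]
seventh_field = r[6]
```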
- decoded_prompts: str | list[str]#
Alias for field number 6
- generation_results: str | list[str]#
Alias for field number 0