easydel.trainers.base_trainer


class easydel.trainers.base_trainer.BaseTrainer(arguments: TrainingArguments | None = None, model_state: EasyDeLState | None = None, model: tp.type[EasyDeLBaseModule] | None = None, dataset_train: Dataset | IterableDataset | ShardedDataSource | None = None, dataset_eval: Dataset | IterableDataset | ShardedDataSource | None = None, data_collator: tp.Callable | None = None, finetune: bool = True, processing_class: PreTrainedTokenizerBase | None = None, **deprecated_kwargs)[source]#

Bases: BaseTrainerProtocol

Base trainer class implementing core training functionality for EasyDeL models.

This class provides the foundation for training and evaluation workflows, including:
  • Checkpoint management and resumption

  • Dataloader configuration (Grain and TensorFlow datasets)

  • Model state initialization and sharding

  • Training and evaluation step compilation

  • Metrics logging and monitoring

  • Memory tracking and performance optimization

The trainer handles distributed training across multiple devices using JAX’s sharding capabilities and supports various optimization strategies.

arguments#

Training configuration arguments

Type

easydel.trainers.training_configurations.TrainingArguments

model_state#

Current state of the model including parameters and optimizer state

Type

easydel.infra.base_state.EasyDeLState

dataset_train#

Training dataset

Type

Any | None

dataset_eval#

Evaluation dataset

Type

Any | None

data_collator#

Function to collate batch data

Type

Optional[Callable]

finetune#

Whether this is a fine-tuning run

Type

bool

mesh#

Device mesh for distributed computation

Type

Any

checkpoint_manager#

Manager for saving/loading checkpoints

Type

Any

_train_source#

Internal ShardedDataSource for training data

Type

easydel.data.core.protocols.ShardedDataSource | None

_eval_source#

Internal ShardedDataSource for evaluation data

Type

easydel.data.core.protocols.ShardedDataSource | None

apply_training_hooks(metrics: LossMetrics) → LossMetrics[source]#

Apply training hooks to check for issues and enforce limits.

Parameters

metrics – Current training metrics including loss.

Returns

Potentially modified metrics.

Return type

LossMetrics

Raises

Note

Checks for NaN losses and training time limits based on configuration in training arguments.
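As a rough illustration of the two checks described in the note, the sketch below tests a loss value for NaN and compares the elapsed time with a limit. The function `check_training_limits`, the exception classes `NaNLossError` and `TimeLimitExceeded`, and all parameter names are hypothetical stand-ins. The real hook reads its settings from the training arguments and may signal these conditions differently.

```python
import math


class NaNLossError(RuntimeError):
    """Hypothetical exception for a NaN loss (not an EasyDeL class)."""


class TimeLimitExceeded(RuntimeError):
    """Hypothetical exception for an exhausted time budget (not an EasyDeL class)."""


def check_training_limits(loss, elapsed_seconds, *, break_on_nan=True, time_limit_seconds=None):
    """Raise if the loss is NaN or the time budget is spent; otherwise return the loss."""
    if break_on_nan and math.isnan(loss):
        raise NaNLossError("loss became NaN")
    if time_limit_seconds is not None and elapsed_seconds > time_limit_seconds:
        raise TimeLimitExceeded(
            f"ran for {elapsed_seconds:.0f}s, limit is {time_limit_seconds:.0f}s"
        )
    return loss
```

A training loop would call such a check once per step, right after the metrics for that step are available.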

calculate_number_total_flops(params, is_training=True)[source]#

Calculate total FLOPs for the model.

Parameters
  • params – Model parameters.

  • is_training – Whether calculating for training (includes backward pass).

Returns

Total FLOPs count.

Return type

int

compile_aot() → bool[source]#

Compile training and evaluation functions ahead-of-time.

This method compiles the training and evaluation step functions ahead of time (AOT) with JAX, so the compilation cost is paid once before the training loop starts rather than during the first step.

Returns

True if any functions were compiled, False otherwise

Return type

bool

Note

  • Compilation happens automatically on first call if not done AOT

  • AOT compilation can reduce first-step latency

  • Uses actual data batches to determine compilation shapes

configure_dataloaders() → TrainerConfigureDataloaderOutput[source]#

Configures the dataloaders for training and evaluation.

This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.

Returns

An object containing the configured dataloaders and the maximum number of training and evaluation steps.

Return type

TrainerConfigureDataloaderOutput
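The step counts mentioned above depend on dataset size and batch size. The sketch below shows one plausible way to derive them. It is an assumption, not the trainer's actual formula: the helper `estimate_max_steps` and its parameters are hypothetical, and the real method may treat partial batches, epochs, or explicit step limits differently.

```python
import math


def estimate_max_steps(num_examples, batch_size, num_epochs=1, step_limit=None):
    """Estimate how many steps are needed to cover the dataset.

    Assumes a final partial batch still counts as one step.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    # An explicit limit, when given, caps the derived value.
    return total_steps if step_limit is None else min(total_steps, step_limit)


train_steps = estimate_max_steps(1000, 32, num_epochs=3)
eval_steps = estimate_max_steps(100, 32)
```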

abstract configure_functions() → TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

configure_model() → TrainerConfigureModelOutput[source]#

Configures the model, optimizer, scheduler, and configuration.

This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.

Returns

An object containing the configured model, optimizer, scheduler, and configuration.

Return type

TrainerConfigureModelOutput

static count_model_parameters(prm)[source]#

Count total number of model parameters.

Parameters

prm – Model parameters (can be frozen or unfrozen).

Returns

Total number of parameters.

Return type

int
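To make the idea of counting parameters concrete, the sketch below walks a nested dictionary and sums the element count of every leaf. It is a simplified stand-in: the helper `count_parameters` is hypothetical, and the leaves are plain shape tuples rather than arrays so that the example needs only the standard library. With real JAX arrays one would usually sum the `size` of each leaf of the parameter pytree instead.

```python
import math


def count_parameters(tree):
    """Sum the element counts of every leaf shape in a nested dict."""
    if isinstance(tree, dict):
        return sum(count_parameters(child) for child in tree.values())
    return math.prod(tree)  # leaf: a shape tuple such as (64, 256)


params = {
    "embed": {"kernel": (1000, 64)},
    "mlp": {"kernel": (64, 256), "bias": (256,)},
}
total = count_parameters(params)  # 64000 + 16384 + 256
```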

abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable[source]#

Creates a function to collect and process batches of data for training or evaluation.

This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable
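Because this method is abstract, subclasses supply the actual collator. The sketch below shows what a minimal implementation of the pad-or-truncate behavior might look like. Several details are assumptions rather than documented behavior: the factory name `make_collect_fn`, right-side padding, the `pad_token_id` of 0, the output keys, and the reading of "keep_end" as keeping the last tokens and "keep_start" as keeping the first.

```python
from typing import Callable, Literal


def make_collect_fn(
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    pad_token_id: int = 0,  # assumed padding value
) -> Callable[[list[list[int]]], dict[str, list[list[int]]]]:
    """Build a collator that pads or truncates every sequence to one length."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def collect(batch):
        input_ids, attention_mask = [], []
        for tokens in batch:
            if len(tokens) > max_sequence_length:
                if truncation_mode == "keep_end":
                    tokens = tokens[-max_sequence_length:]  # keep the last tokens
                else:
                    tokens = tokens[:max_sequence_length]  # keep the first tokens
            padding = max_sequence_length - len(tokens)
            input_ids.append(list(tokens) + [pad_token_id] * padding)
            attention_mask.append([1] * len(tokens) + [0] * padding)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    return collect


collect = make_collect_fn(4, "keep_end")
batch = collect([[1, 2, 3, 4, 5, 6], [7, 8]])
```

Returning a closure keeps the length and truncation settings fixed at construction time, so the dataloader only needs to pass batches.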

create_generate_function(generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, *, shard_inputs: bool = True, config_overrides: dict[str, Any] | None = None, **generate_kwargs) → Callable[[EasyDeLState, Array, Array], tuple[jax.Array, jax.Array, jax.Array]][source]#

Build and return a compiled generation function that mirrors the model’s generate.

Parameters
  • generation_config – Optional generation configuration. If omitted, the model’s default configuration is used.

  • shard_inputs – Whether to shard the prompt tensors using the model’s partition manager before generation.

  • config_overrides – Optional attribute overrides applied to the copied generation configuration.

  • generate_kwargs – Extra keyword arguments forwarded to module.generate.

abstract create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable[source]#

Creates a function to collect and process batches of data for training or evaluation.

This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable

create_progress_bar(total: int, desc: str = '', disabled: bool = False) → BaseProgressBar[source]#

Create a progress bar of the specified type.

abstract create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable[source]#

Creates a function to collect and process batches of data for training or evaluation.

This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable

property evaluation_batch_size#

Get the evaluation batch size.

Returns

The batch size used for evaluation

static finish()[source]#

Clean up resources and finish any active logging sessions.

Notes

Currently only finishes wandb session if active. Safe to call even if wandb is not initialized.

generate_aio(input_ids: jax.Array | numpy.ndarray, attention_mask: jax.Array | numpy.ndarray | None = None, *, state: easydel.infra.base_state.EasyDeLState | None = None, generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, shard_inputs: bool | None = None, config_overrides: dict[str, Any] | None = None, return_metadata: bool = False, all_gather: bool = False, **generate_kwargs)[source]#

Convenience wrapper around the compiled generation function.

generate_unified(input_ids: jax.Array | numpy.ndarray | None = None, attention_mask: jax.Array | numpy.ndarray | None = None, prompts: str | list[str] | None = None, *, state: easydel.infra.base_state.EasyDeLState | None = None, use_esurge: bool | None = None, apply_chat_template: bool = False, generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, shard_inputs: bool | None = None, config_overrides: dict[str, Any] | None = None, all_gather: bool = False, **generate_kwargs) → GenerationResults[source]#

Unified generation interface supporting both compiled and eSurge generation.

This method provides a single interface for generation that automatically handles:
  • Conversion between text prompts and token IDs

  • Selection between eSurge and compiled generation based on configuration

  • Consistent output format regardless of generation method

  • Optional chat template application

Parameters
  • input_ids – Optional token IDs for the prompt. If None, must provide prompts.

  • attention_mask – Optional attention mask for input_ids.

  • prompts – Optional text prompt(s). If None, must provide input_ids.

  • state – Model state to use for generation. Defaults to self.model_state.

  • use_esurge – Whether to use eSurge generation. Defaults to self.arguments.use_esurge_generation.

  • apply_chat_template – Whether to apply chat template to prompts. Default False. If True and prompts is a string, wraps it in [{"role": "user", "content": prompt}].

  • generation_config – Optional generation configuration.

  • shard_inputs – Whether to shard inputs across devices.

  • config_overrides – Optional overrides for generation config.

  • all_gather – Whether to gather results from all devices.

  • **generate_kwargs – Additional kwargs passed to generation.

Returns

NamedTuple containing:
  • generation_results: The result text

  • prompt_ids: Token IDs for the prompt

  • prompt_mask: Attention mask for the prompt

  • sequences: Complete generated sequences (including prompt)

Return type

GenerationResults

Raises

ValueError – If neither input_ids nor prompts are provided.
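The input rules for this method (either `input_ids` or `prompts` must be given, and a plain string can be wrapped as a single user message) can be sketched as follows. The helper `normalize_prompts` is hypothetical and not part of the trainer. Wrapping every string in a list of prompts is an assumption, since the parameter description only covers the single-string case.

```python
def normalize_prompts(input_ids=None, prompts=None, apply_chat_template=False):
    """Validate the inputs and, if requested, wrap plain strings as chat messages."""
    if input_ids is None and prompts is None:
        raise ValueError("Provide either input_ids or prompts.")
    if prompts is None:
        return None  # token IDs are used as-is, no text handling needed
    if isinstance(prompts, str):
        prompts = [prompts]
    if apply_chat_template:
        # Assumption: each plain string becomes a one-message conversation.
        return [
            [{"role": "user", "content": p}] if isinstance(p, str) else p
            for p in prompts
        ]
    return list(prompts)


wrapped = normalize_prompts(prompts="Explain quantum computing", apply_chat_template=True)
```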

Example

>>> # Generate from token IDs (GRPO-style, no chat template)
>>> results = trainer.generate_unified(
...     input_ids=prompt_ids,
...     attention_mask=mask,
...     apply_chat_template=False  # Raw generation
... )
>>>
>>> # Generate with chat template (preview generation)
>>> results = trainer.generate_unified(
...     prompts="Explain quantum computing",
...     apply_chat_template=True  # Applies chat template
... )
>>>
>>> # eSurge with chat format
>>> results = trainer.generate_unified(
...     prompts=[{"role": "user", "content": "Hello"}],
...     use_esurge=True
... )

initialize_trainer_utils()[source]#

Initializes various utilities used by the trainer.

This method orchestrates the initialization of all trainer components in the correct order. It sets up:
  1. Weights & Biases logging (if enabled)

  2. Training timer for performance monitoring

  3. Dataloaders for training and evaluation

  4. Model, optimizer, and learning rate scheduler

  5. Model state sharding across devices

  6. Compiled training and evaluation functions

The initialization order is important as later steps depend on earlier ones. For example, the optimizer configuration depends on the number of training steps determined during dataloader configuration.

property is_enable#

Check if operations are enabled for this process.

Returns

True if operations are enabled, False if restricted to main process only

Notes

When process_zero_is_admin is True, only the main process will have operations enabled.
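The rule in the note can be written as a small predicate. This is a sketch of the described behavior only: the function `is_enabled` and its parameters are hypothetical, and in a JAX program the process index would typically come from `jax.process_index()`.

```python
def is_enabled(process_index: int, process_zero_is_admin: bool) -> bool:
    """Return True when this process is allowed to perform restricted operations."""
    is_process_zero = process_index == 0
    # With an admin process, only rank 0 is enabled; otherwise every process is.
    return is_process_zero if process_zero_is_admin else True
```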

property is_process_zero#

Check if this is the main process (rank 0).

Returns

True if this is the main process, False otherwise

log_metrics(metrics: dict[str, float | list | tuple | numpy.ndarray | Any], pbar: BaseProgressBar, step: int, mode: str = 'train')[source]#

Log metrics and update progress bar.

log_weight_distribution(state: EasyDeLState, step: int)[source]#

Log weight distribution statistics.

Parameters
  • state – Model state containing parameters.

  • step – Current training step.

Notes

Delegates to arguments.log_weight_distribution method.

maybe_generate(state: EasyDeLState, step: int, metrics: dict[str, float | list | tuple | numpy.ndarray | Any] | None = None) → None[source]#

Optionally run preview generation to monitor training progress.

Uses generate_unified for consistent generation across both eSurge and compiled modes.

property mesh#

Get the device mesh for distributed computation.

Returns

The device mesh used for sharding computations

property model#

Get the model instance.

Returns

The model instance used for training

on_step_end(state: EasyDeLState, metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int) → tuple[easydel.infra.base_state.EasyDeLState, dict[str, float | list | tuple | numpy.ndarray | Any]][source]#

Hook method called at the end of each training step.

Parameters
  • state (EasyDeLState) – The current model state

  • metrics (MetricsType) – The metrics computed for this step

  • step (int) – The current training step number

Returns

The potentially modified model state and metrics

Return type

tuple[EasyDeLState, MetricsType]

Notes

This method can be overridden in subclasses to implement custom logic at the end of each training step, such as custom logging or state modifications.

on_step_start(state: EasyDeLState, step: int) → EasyDeLState[source]#

Hook method called at the start of each training step.

Parameters
  • state (EasyDeLState) – The current model state

  • step (int) – The current training step number

Returns

The potentially modified model state

Return type

EasyDeLState

Notes

This method can be overridden in subclasses to implement custom logic at the beginning of each training step.
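The override pattern for `on_step_start` and `on_step_end` can be illustrated without the real trainer. In the sketch below, `MiniTrainer` is a hypothetical stand-in that models only the two hooks and a toy loop, and `TaggingTrainer` is a hypothetical subclass. Neither is an EasyDeL class, and the placeholder loss is invented for the demonstration.

```python
class MiniTrainer:
    """Stand-in for a trainer; it only models the two step hooks."""

    def on_step_start(self, state, step):
        return state

    def on_step_end(self, state, metrics, step):
        return state, metrics

    def run(self, state, num_steps):
        history = []
        for step in range(num_steps):
            state = self.on_step_start(state, step)
            metrics = {"loss": 1.0 / (step + 1)}  # placeholder for a real train step
            state, metrics = self.on_step_end(state, metrics, step)
            history.append(metrics)
        return state, history


class TaggingTrainer(MiniTrainer):
    """Overrides on_step_end to attach an extra metric to each step."""

    def on_step_end(self, state, metrics, step):
        metrics = {**metrics, "is_even_step": step % 2 == 0}
        return state, metrics


final_state, history = TaggingTrainer().run({"w": 0}, num_steps=3)
```

Both hooks return their inputs by default, so a subclass only needs to override the one it cares about.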

save_information(output_path: str | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath) → None[source]#

Save the generated information to a markdown file.

Parameters

output_path – Path where the markdown file should be saved

save_pretrained(state: EasyDeLState, save_directory: str | None = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: dict | None = None, torch_save_pretrained_kwargs: dict | None = None)[source]#

Saves the model state as a checkpoint file or to a Torch-compatible directory.

specs_to_name_sharding(tree, mesh=None)[source]#

Convert partition specs to named sharding.

Parameters
  • tree – PyTree structure with partition specs.

  • mesh – Device mesh for sharding (uses trainer’s mesh if None).

Returns

PyTree with named sharding specifications.

start_evaluation_hook()[source]#

Hook called at the start of evaluation.

Notes

Sets up static metrics and records evaluation start time.

start_training_hook()[source]#

Hook called at the start of training.

Notes

Sets up static metrics and records training start time.

property training_batch_size#

Calculate the effective training batch size.

Returns

The effective batch size including gradient accumulation

Notes

The effective batch size is calculated as: total_batch_size * gradient_accumulation_steps
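A short worked example of the formula above, using illustrative values. The function name `effective_batch_size` is hypothetical; only the multiplication itself comes from the note.

```python
def effective_batch_size(total_batch_size: int, gradient_accumulation_steps: int = 1) -> int:
    """Number of examples that contribute to one optimizer update."""
    return total_batch_size * gradient_accumulation_steps


# A batch of 8 accumulated over 4 steps behaves like a batch of 32.
size = effective_batch_size(8, gradient_accumulation_steps=4)
```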

class easydel.trainers.base_trainer.GenerationResults(generation_results: str | list[str], prompt_ids: jax.Array, prompt_mask: jax.Array, sequences: jax.Array, completion_ids: jax.Array, completion_mask: jax.Array, decoded_prompts: str | list[str], completion_prompts: list[str | list[dict[str, str]]] | None = None)[source]#

Bases: NamedTuple

Results from unified generation containing both text and token representations.

generation_results#

The generated text returned by the generation engine

Type

str | list[str]

prompt_ids#

Token IDs for the prompt (batch_size, max_seq_len) - left-padded

Type

jax.Array

prompt_mask#

Attention mask for the prompt (batch_size, max_seq_len)

Type

jax.Array

sequences#

Complete generated sequences including prompt (batch_size, max_seq_len + max_new_tokens)

Type

jax.Array

completion_ids#

Token IDs for only the generated completions (batch_size, max_new_tokens) - right-padded

Type

jax.Array

completion_mask#

Attention mask for completions (batch_size, max_new_tokens)

Type

jax.Array

completion_prompts#

Optional prompt objects (text or chat dicts) aligned one-to-one with completions.

Type

list[str | list[dict[str, str]]] | None

completion_ids: Array#

Alias for field number 4

completion_mask: Array#

Alias for field number 5

completion_prompts: list[str | list[dict[str, str]]] | None#

Alias for field number 7

decoded_prompts: str | list[str]#

Alias for field number 6

generation_results: str | list[str]#

Alias for field number 0

prompt_ids: Array#

Alias for field number 1

prompt_mask: Array#

Alias for field number 2

sequences: Array#

Alias for field number 3
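Since `GenerationResults` is a `NamedTuple`, each field can be read by name or by the position given in the "Alias for field number" entries. The sketch below mirrors the field order of the class signature with a hypothetical stand-in, `GenerationResultsSketch`, and uses plain lists in place of `jax.Array` values so that it runs with the standard library alone.

```python
from typing import NamedTuple, Optional


class GenerationResultsSketch(NamedTuple):
    """Stand-in with the same field order as GenerationResults."""

    generation_results: list
    prompt_ids: list
    prompt_mask: list
    sequences: list
    completion_ids: list
    completion_mask: list
    decoded_prompts: list
    completion_prompts: Optional[list] = None


result = GenerationResultsSketch(
    generation_results=["Hi there"],
    prompt_ids=[[0, 5, 6]],      # left-padded prompt
    prompt_mask=[[0, 1, 1]],
    sequences=[[0, 5, 6, 7, 8]],  # prompt followed by completion
    completion_ids=[[7, 8]],
    completion_mask=[[1, 1]],
    decoded_prompts=["Hello"],
)

# Name-based and index-based access refer to the same object.
same_object = result.sequences is result[3]
# Tuple unpacking also works, in field order.
text, prompt_ids, *rest = result
```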