easydel.trainers.__init__

Contents

easydel.trainers.__init__#

class easydel.trainers.__init__.BaseTrainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#

Bases: BaseTrainerProtocol

apply_training_hooks(metrics: LossMetrics) LossMetrics[source]#

Apply training hooks to the model.

calculate_number_total_flops(params, is_training=True)[source]#
compile_aot() bool[source]#

Compile the state ahead of time for faster execution.

configure_dataloaders() TrainerConfigureDataloaderOutput[source]#

Configures the dataloaders for training and evaluation.

This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.

Returns

An object containing the configured dataloaders and the maximum number of training and evaluation steps.

Return type

TrainerConfigureDataloaderOutput

abstract configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

configure_model() TrainerConfigureModelOutput[source]#

Configures the model, optimizer, scheduler, and configuration.

This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.

Returns

An object containing the configured model, optimizer, scheduler, and configuration.

Return type

TrainerConfigureModelOutput

static count_model_parameters(prm)[source]#

Prints the number of model parameters in billions.

abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#

Creates a function to collect and process batches of data for training or evaluation.

This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable

create_progress_bar(total: int, desc: str = '', disabled: bool = False) BaseProgressBar[source]#

Create a progress bar of the specified type.

property evaluation_batch_size#
static finish()[source]#

Finalize the training process.

get_runstage_flops(is_training) Union[float, Tuple[float, bool]][source]#

Return the total number of FLOPs for the model.

initialize_trainer_utils()[source]#

Initializes various utilities used by the trainer.

This includes setting up Weights & Biases, initializing the training timer, configuring dataloaders, configuring the model and optimizer, sharding the model and reference model states, and configuring the training and evaluation functions.

property is_process_zero#
log_metrics(metrics: Any, pbar: BaseProgressBar, step: int, mode: str = 'train')[source]#

Log metrics and update progress bar.

property mesh#
property model#
on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

Hook called at the end of the step.

on_step_start(state: EasyDeLState, step: int) EasyDeLState[source]#

Hook called at the start of the step.

save_information(output_path: Union[str, Path]) None[source]#

Save the generated information to a markdown file.

Parameters

output_path – Path where the markdown file should be saved

save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)[source]#

Saves the model state as a checkpoint file or to a Torch-compatible directory.

specs_to_name_sharding(tree, mesh=None)[source]#

Convert specs to named sharding.

start_evaluation_hook()[source]#

Hook to run before evaluation starts.

start_training_hook()[source]#

Hook to run before training starts.

property training_batch_size#
class easydel.trainers.__init__.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'DPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, beta: float = 0.1, label_smoothing: float = 0.0, loss_type: ~typing.Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', use_weighting: bool = False, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, max_length: ~typing.Optional[int] = 512, max_prompt_length: ~typing.Optional[int] = 256, max_completion_length: ~typing.Optional[int] = None, is_encoder_decoder: ~typing.Optional[bool] = None, disable_dropout: bool = True, precompute_ref_log_probs: bool = False, dataset_num_proc: ~typing.Optional[int] = None, reference_free: bool = False, force_use_ref_model: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, rpo_alpha: ~typing.Optional[float] = None, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None)[source]#

Bases: TrainingArguments

Configuration class for Direct Preference Optimization (DPO) training.

Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.

beta#

Temperature parameter (β) controlling deviation from reference model. Higher values make training focus more on preference matching. Default: 0.1

Type

float

label_smoothing#

Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0

Type

float

loss_type#

Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’

Type

LOSS_FN_VARIENTS

use_weighting#

Whether to apply example weighting in loss calculation. Default: False

Type

bool

label_pad_token_id#

Token ID used for padding labels. Default: -100

Type

int

padding_value#

Value used for padding sequences. If None, uses model’s default padding token. Default: None

Type

int | None

max_length#

Maximum total sequence length (prompt + completion). Default: 512

Type

int | None

max_prompt_length#

Maximum length for prompt sequences. Default: 256

Type

int | None

max_completion_length#

Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None

Type

int | None

is_encoder_decoder#

Explicitly set if model is encoder-decoder. Auto-detected if None. Default: None

Type

bool | None

disable_dropout#

Whether to disable dropout during training for deterministic behavior. Default: True

Type

bool

precompute_ref_log_probs#

Whether to precompute reference model log probabilities before training. Default: False

Type

bool

dataset_num_proc#

Number of processes for dataset preprocessing. Default: None (sequential processing)

Type

int | None

reference_free#

Whether to use reference-free variant of DPO. Default: False

Type

bool

force_use_ref_model#

Force use of the reference model even when reference_free=True. Default: False

Type

bool

sync_ref_model#

Whether to periodically sync reference model with training model. Default: False

Type

bool

learning_rate#

Optimizer learning rate. Default: 1e-6

Type

float

ref_model_mixup_alpha#

Alpha parameter for mixup between policy and reference models. Default: 0.9

Type

float

ref_model_sync_steps#

Number of steps between reference model syncs. Default: 64

Type

int

rpo_alpha#

Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None

Type

float | None

tools#

Additional tools for training process

Type

list[dict | Callable] | None

Example

>>> config = DPOConfig(
...   beta=0.2, loss_type="ipo", max_length=1024, learning_rate=5e-6
... )
beta: float = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
force_use_ref_model: bool = False#
is_encoder_decoder: Optional[bool] = None#
label_pad_token_id: int = -100#
label_smoothing: float = 0.0#
learning_rate: float = 1e-06#
loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid'#
max_completion_length: Optional[int] = None#
max_length: Optional[int] = 512#
max_prompt_length: Optional[int] = 256#
model_name: str = 'DPOTrainer'#
padding_value: Optional[int] = None#
precompute_ref_log_probs: bool = False#
ref_model_mixup_alpha: float = 0.9#
ref_model_sync_steps: int = 64#
reference_free: bool = False#
rpo_alpha: Optional[float] = None#
sync_ref_model: bool = False#
tools: Optional[List[Union[dict, Callable]]] = None#
use_weighting: bool = False#
class easydel.trainers.__init__.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#

Bases: Trainer

Trainer for Direct Preference Optimization (DPO).

This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.

arguments: DPOConfig#
compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) tuple[Any, Any][source]#

Computes log probabilities of the reference model for a single padded batch of a DPO-specific dataset.

Parameters
  • state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).

  • padded_batch (tp.Dict) – The padded batch of data.

Returns

A tuple containing the log probabilities for the chosen and rejected responses.

Return type

tuple[tp.Any, tp.Any]

configure_dataloaders()[source]#

Returns the training dataloader, potentially with precomputed reference log probabilities.

If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.

Returns

The training dataloader.

Return type

tensorflow.data.Dataset

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a data collection function for batching.

For DPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

Hook called at the end of the step.

static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#

Tokenize a row of the dataset.

Parameters
  • features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.

  • processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.

  • max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.

  • max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.

  • add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.

Returns

Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.

Return type

dict[str, list[int]]

class easydel.trainers.__init__.GRPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'GRPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: ~typing.Optional[bool] = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = 
EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, max_prompt_length: int = 512, max_completion_length: int = 256, dataset_num_proc: ~typing.Optional[int] = None, beta: float = 0.04, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None, skip_apply_chat_template: bool = False)[source]#

Bases: TrainingArguments

Configuration class for the GRPOTrainer.

beta: float = 0.04#
dataset_num_proc: Optional[int] = None#
learning_rate: float = 1e-06#
max_completion_length: int = 256#
max_prompt_length: int = 512#
model_name: str = 'GRPOTrainer'#
ref_model_mixup_alpha: float = 0.9#
ref_model_sync_steps: int = 64#
remove_unused_columns: Optional[bool] = False#
skip_apply_chat_template: bool = False#
sync_ref_model: bool = False#
tools: Optional[List[Union[dict, Callable]]] = None#
class easydel.trainers.__init__.GRPOTrainer(arguments: GRPOConfig, vinference: vInference, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]], reward_funcs: Union[EasyDeLBaseModule, EasyDeLState, Callable[[list, list], list[float]], list[Union[easydel.infra.base_module.EasyDeLBaseModule, easydel.infra.base_state.EasyDeLState, Callable[[list, list], list[float]]]]], train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None, reward_processing_classes: Optional[Any] = None, data_tokenize_fn: Optional[Callable] = None)[source]#

Bases: Trainer

arguments: GRPOConfig#
configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

Hook called at the end of the step.

class easydel.trainers.__init__.JaxDistributedConfig[source]#

Bases: object

From EasyLM: utility class for initializing JAX distributed.

static get_default_config(updates=None)[source]#
classmethod initialize(config=None)[source]#
class easydel.trainers.__init__.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'ORPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, max_length: ~typing.Optional[int] = 1024, max_prompt_length: ~typing.Optional[int] = 512, max_completion_length: ~typing.Optional[int] = None, beta: float = 0.1, disable_dropout: bool = True, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, generate_during_eval: bool = False, is_encoder_decoder: ~typing.Optional[bool] = None, dataset_num_proc: ~typing.Optional[int] = None)[source]#

Bases: TrainingArguments

Configuration class for ORPO training settings.

This class inherits from TrainingArguments and holds configuration parameters specific to the ORPO model training. The dataclass automatically generates an initializer, and the __post_init__ method further processes some of the parameters after object initialization.

model_name#

The name of the model. Default is “ORPOTrainer”.

Type

str

learning_rate#

The learning rate used during training. Default is 1e-6.

Type

float

max_length#

The maximum allowed sequence length for the input. Default is 1024.

Type

Optional[int]

max_prompt_length#

The maximum allowed length of the prompt portion of the input. Default is 512.

Type

Optional[int]

max_completion_length#

The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length.

Type

Optional[int]

beta#

A hyperparameter beta, with a default value of 0.1.

Type

float

disable_dropout#

Flag to disable dropout during training. Default is True.

Type

bool

label_pad_token_id#

The token id used for padding labels. Default is -100.

Type

int

padding_value#

The value used for padding sequences. Default is None.

Type

Optional[int]

generate_during_eval#

Flag indicating whether to generate sequences during evaluation. Default is False.

Type

bool

is_encoder_decoder#

Flag to indicate if the model is encoder-decoder. Default is None.

Type

Optional[bool]

model_init_kwargs#

Additional keyword arguments for model initialization. Default is None.

Type

Optional[Dict[str, Any]]

dataset_num_proc#

Number of processes to use for dataset processing. Default is None.

Type

Optional[int]

max_sequence_length#

Computed attribute representing the maximum sequence length used for training. It is set in the __post_init__ method.

Type

int

beta: float = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
generate_during_eval: bool = False#
is_encoder_decoder: Optional[bool] = None#
label_pad_token_id: int = -100#
learning_rate: float = 1e-06#
max_completion_length: Optional[int] = None#
max_length: Optional[int] = 1024#
max_prompt_length: Optional[int] = 512#
model_name: str = 'ORPOTrainer'#
padding_value: Optional[int] = None#
class easydel.trainers.__init__.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#

Bases: Trainer

arguments: ORPOConfig#
build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#

Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.

Parameters
  • prompt (str) – The prompt text.

  • answer (str) – The answer text.

Returns

A dictionary containing the tokenized prompt and answer, along with attention masks.

Return type

tp.Dict[str, np.ndarray]

Raises

ValueError – If there’s a mismatch in token lengths.

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a data collection function for batching.

For ORPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#

Tokenizes a single row of data from the ORPO dataset.

This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data for ORPO training.

Parameters
  • feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.

  • state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.

Returns

A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.

Return type

tp.Dict

Raises

ValueError – If the input data types are incorrect.

class easydel.trainers.__init__.RewardConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 1024, max_training_steps: tp.Optional[int] = None, model_name: str = 'RewardTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, disable_dropout: bool = True, dataset_num_proc: ~typing.Optional[int] = None, center_rewards_coefficient: ~typing.Optional[float] = 0.1)[source]#

Bases: TrainingArguments

Configuration class for the [RewardTrainer].

Parameters
  • model_name (str) – The name of the model. Defaults to “RewardTrainer”.

  • max_sequence_length (int, optional) – Maximum length of the sequences (prompt + completion) in the batch; entries that exceed the limit are filtered out. Defaults to 1024.

  • disable_dropout (bool, optional) – Whether to disable dropout in the model. Defaults to True.

  • dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Defaults to None.

  • center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to 0.1.

  • remove_unused_columns (bool, optional) – Whether to remove the columns that are not used by the model’s forward pass. Can be True only if the dataset is pretokenized. Defaults to False.

center_rewards_coefficient: Optional[float] = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
max_sequence_length: Optional[int] = 1024#
model_name: str = 'RewardTrainer'#
remove_unused_columns: bool = False#
class easydel.trainers.__init__.RewardTrainer(arguments: RewardConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, data_collator: Optional[RewardDataCollatorWithPadding] = None)[source]#

Bases: Trainer

This trainer extends the Trainer and provides functionalities specific to reward model training.

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length, truncation_mode='keep_end')[source]#

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable

class easydel.trainers.__init__.SFTConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 2e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'SFTTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, dataset_text_field: ~typing.Optional[str] = None, add_special_tokens: bool = False, packing: bool = False, dataset_num_proc: ~typing.Optional[int] = None, dataset_batch_size: int = 1000, dataset_kwargs: ~typing.Optional[dict[str, typing.Any]] = None, eval_packing: ~typing.Optional[bool] = None, num_of_sequences: int = 1024, chars_per_token: float = 3.6)[source]#

Bases: TrainingArguments

Configuration class for the [SFTTrainer].

Parameters
  • model_name (str) – The name of the model. Defaults to “SFTTrainer”.

  • dataset_text_field (str, optional) – Name of the text field of the dataset. If provided, the trainer will automatically create a [ConstantLengthDataset] based on dataset_text_field. Defaults to None.

  • packing (bool, optional) – Controls whether the [ConstantLengthDataset] packs the sequences of the dataset. Defaults to False.

  • learning_rate (float, optional) – Initial learning rate for the [AdamW] optimizer. The default value overrides that of [TrainingArguments] (5e-5). Defaults to 2e-5.

  • dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Only used when packing=False. Defaults to None.

  • dataset_batch_size (int, optional) – Number of examples to tokenize per batch. If dataset_batch_size <= 0 or dataset_batch_size is None, tokenizes the full dataset as a single batch. Defaults to 1000.

  • dataset_kwargs (dict[str, Any], optional) – Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. Defaults to None.

  • eval_packing (bool, optional) – Whether to pack the eval dataset. If None, uses the same value as packing. Defaults to None.

  • num_of_sequences (int, optional) – Number of sequences to use for the [ConstantLengthDataset]. Defaults to 1024.

  • chars_per_token (float, optional) – Number of characters per token to use for the [ConstantLengthDataset]. See [chars_token_ratio](huggingface/trl) for more details. Defaults to 3.6.

add_special_tokens: bool = False#
chars_per_token: float = 3.6#
dataset_batch_size: int = 1000#
dataset_kwargs: Optional[dict[str, Any]] = None#
dataset_num_proc: Optional[int] = None#
dataset_text_field: Optional[str] = None#
eval_packing: Optional[bool] = None#
learning_rate: float = 2e-05#
model_name: str = 'SFTTrainer'#
num_of_sequences: int = 1024#
packing: bool = False#
class easydel.trainers.__init__.SFTTrainer(arguments: SFTConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, formatting_func: Optional[Callable] = None, data_collator: Optional[DataCollatorForCompletionOnlyLM] = None)[source]#

Bases: Trainer

Trainer class for Supervised Fine-Tuning (SFT) of language models.

This trainer extends the Trainer and provides functionalities specific to supervised fine-tuning tasks.

class easydel.trainers.__init__.Trainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#

Bases: BaseTrainer

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable

eval(model_state: EasyDeLState) Iterator[dict][source]#

Evaluates the model using the provided model state.

This method iterates over the evaluation dataset, performs forward passes, calculates evaluation metrics, logs the metrics, and yields the metrics for each evaluation step.

Parameters

model_state (EasyDeLState) – The state of the model (including parameters and configuration) to be used for evaluation.

Yields

Iterator[dict] – An iterator yielding a dictionary of evaluation metrics for each evaluation step.

Raises

AssertionError – If the evaluation dataloader is not set.

train() TrainerOutput[source]#

Executes the complete training process.

This method sets up initial metrics and logging, runs the training loop, and finalizes training. It calls the training hook at the beginning and returns a TrainerOutput object at the end.

Returns

An object containing the final training state, metrics, and any additional outputs.

Return type

TrainerOutput

class easydel.trainers.__init__.TrainingArguments(auto_shard_states: 'bool' = True, aux_loss_enabled: 'bool' = False, backend: 'tp.Optional[str]' = None, clip_grad: 'tp.Optional[float]' = None, custom_scheduler: 'tp.Optional[tp.Callable[[int], tp.Any]]' = None, dataloader_num_workers: 'tp.Optional[int]' = 0, dataloader_pin_memory: 'tp.Optional[bool]' = False, do_eval: 'bool' = False, do_last_save: 'bool' = True, do_train: 'bool' = True, eval_batch_size: 'tp.Optional[int]' = None, evaluation_steps: 'tp.Optional[int]' = None, extra_optimizer_kwargs: 'dict' = <factory>, frozen_parameters: 'tp.Optional[str]' = None, gradient_accumulation_steps: 'int' = 1, ids_to_pop_from_dataset: 'tp.Optional[tp.List[str]]' = <factory>, is_fine_tuning: 'bool' = True, init_tx: 'bool' = True, jax_distributed_config: 'tp.Optional[dict]' = None, learning_rate: 'float' = 5e-05, learning_rate_end: 'tp.Optional[float]' = None, log_all_workers: 'bool' = False, log_grad_norms: 'bool' = True, report_metrics: 'bool' = True, log_steps: 'int' = 10, loss_config: 'tp.Optional[LossConfig]' = None, low_mem_usage: 'bool' = True, max_evaluation_steps: 'tp.Optional[int]' = None, max_sequence_length: 'tp.Optional[int]' = 4096, max_training_steps: 'tp.Optional[int]' = None, model_name: 'str' = 'BaseTrainer', model_parameters: 'tp.Optional[dict]' = None, metrics_to_show_in_rich_pbar: 'tp.Optional[tp.List[str]]' = None, num_train_epochs: 'int' = 10, offload_dataset: 'bool' = False, offload_device_type: 'str' = 'cpu', offload_device_index: 'int' = 0, optimizer: 'AVAILABLE_OPTIMIZERS' = <EasyDeLOptimizers.ADAMW: 'adamw'>, performance_mode: 'bool' = False, pruning_module: 'tp.Any' = None, process_zero_is_admin: 'bool' = True, progress_bar_type: "tp.Literal['tqdm', 'rich', 'json']" = 'tqdm', remove_ckpt_after_load: 'bool' = False, remove_unused_columns: 'bool' = True, report_steps: 'int' = 5, save_directory: 'str' = 'EasyDeL-Checkpoints', save_optimizer_state: 'bool' = True, save_steps: 'tp.Optional[int]' = None, 
save_total_limit: 'tp.Optional[int]' = None, scheduler: 'AVAILABLE_SCHEDULERS' = <EasyDeLSchedulers.NONE: 'None'>, sparsify_module: 'bool' = False, sparse_module_type: 'AVAILABLE_SPARSE_MODULE_TYPES' = 'bcoo', state_apply_fn_kwarguments_to_model: 'tp.Optional[dict]' = None, step_partition_spec: 'PartitionSpec' = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: 'tp.Optional[int]' = None, shuffle_train_dataset: 'bool' = True, total_batch_size: 'int' = 32, training_time_limit: 'tp.Optional[str]' = None, train_on_inputs: 'bool' = True, truncation_mode: "tp.Literal['keep_end', 'keep_start']" = 'keep_end', tx_mu_dtype: 'tp.Optional[jnp.dtype]' = None, track_memory: 'bool' = False, use_data_collactor: 'bool' = True, use_wandb: 'bool' = True, verbose: 'bool' = True, wandb_entity: 'tp.Optional[str]' = None, warmup_steps: 'int' = 0, weight_decay: 'float' = 0.01)[source]#

Bases: object

auto_shard_states: bool = True#
aux_loss_enabled: bool = False#
backend: Optional[str] = None#
clip_grad: Optional[float] = None#
custom_scheduler: Optional[Callable[[int], Any]] = None#
dataloader_num_workers: Optional[int] = 0#
dataloader_pin_memory: Optional[bool] = False#
do_eval: bool = False#
do_last_save: bool = True#
do_train: bool = True#
ensure_checkpoint_path()[source]#

Creates the checkpoint directory if it doesn’t exist.

ensure_training_time_limit(time_passed)[source]#
eval_batch_size: Optional[int] = None#
evaluation_steps: Optional[int] = None#
extra_optimizer_kwargs: dict#
classmethod from_dict(config: Dict[str, Any]) TrainingArguments[source]#

Creates a TrainingArguments instance from a dictionary.

Parameters

config (tp.Dict[str, tp.Any]) – The configuration dictionary.

Returns

A TrainingArguments object initialized with values from the dictionary.

Return type

TrainingArguments

frozen_parameters: Optional[str] = None#
get_optimizer_and_scheduler(steps: Optional[int] = None)[source]#

Returns the configured optimizer and learning rate scheduler.

Parameters

steps (tp.Optional[int]) – The number of training steps. If not provided, uses the value from self.optimizer_kwargs.

Returns

A tuple containing the optimizer and scheduler.

Return type

tuple

get_path() Path[source]#

Returns the path to the checkpoint directory.

Returns

The path to the checkpoint directory.

Return type

Path

get_streaming_checkpointer()[source]#

Returns the checkpoint manager, responsible for saving model checkpoints.

Returns

The checkpoint manager.

Return type

CheckpointManager

get_tensorboard()[source]#

Returns the TensorBoard SummaryWriter, used for logging metrics.

Returns

The TensorBoard SummaryWriter.

Return type

flax.metrics.tensorboard.SummaryWriter

get_wandb_init()[source]#

Initializes Weights & Biases for experiment tracking if enabled.

Returns

The WandB run object if initialized, else None.

Return type

tp.Optional[wandb.sdk.wandb_run.Run]

gradient_accumulation_steps: int = 1#
ids_to_pop_from_dataset: Optional[List[str]]#
init_tx: bool = True#
is_fine_tuning: bool = True#
property is_process_zero#
jax_distributed_config: Optional[dict] = None#
learning_rate: float = 5e-05#
learning_rate_end: Optional[float] = None#
log_all_workers: bool = False#
log_grad_norms: bool = True#
log_metrics(metrics: Any, step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#

Logs training metrics to Weights & Biases and/or TensorBoard.

Parameters
  • metrics (tp.Dict[str, tp.Union[float, tp.List, tp.Tuple, np.ndarray, 'jnp.ndarray', 'torch.Tensor']]) – A dictionary where keys are metric names and values are metric values.

  • step (int) – The current training step or iteration.

log_steps: int = 10#
loss_config: Optional[LossConfig] = None#
low_mem_usage: bool = True#
max_evaluation_steps: Optional[int] = None#
max_sequence_length: Optional[int] = 4096#
max_training_steps: Optional[int] = None#
metrics_to_show_in_rich_pbar: Optional[List[str]] = None#
model_name: str = 'BaseTrainer'#
model_parameters: Optional[dict] = None#
num_train_epochs: int = 10#
offload_dataset: bool = False#
property offload_device#
offload_device_index: int = 0#
offload_device_type: str = 'cpu'#
optimizer: Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = 'adamw'#
performance_mode: bool = False#
process_zero_is_admin: bool = True#
progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
pruning_module: Any = None#
remove_ckpt_after_load: bool = False#
remove_unused_columns: bool = True#
report_metrics: bool = True#
report_steps: int = 5#
save_directory: str = 'EasyDeL-Checkpoints'#
save_optimizer_state: bool = True#
save_steps: Optional[int] = None#
save_total_limit: Optional[int] = None#
scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
shuffle_train_dataset: bool = True#
sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
sparsify_module: bool = False#
state_apply_fn_kwarguments_to_model: Optional[dict] = None#
step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
step_start_point: Optional[int] = None#
to_dict() Dict[str, Any][source]#

Converts the TrainingArguments object into a dictionary.

Returns

A dictionary representation of the TrainingArguments.

Return type

tp.Dict[str, tp.Any]

total_batch_size: int = 32#
track_memory: bool = False#
train_on_inputs: bool = True#
training_time_limit: Optional[str] = None#
property training_time_seconds: int#
truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
tx_mu_dtype: Optional[dtype] = None#
use_data_collactor: bool = True#
use_wandb: bool = True#
verbose: bool = True#
wandb_entity: Optional[str] = None#
warmup_steps: int = 0#
weight_decay: float = 0.01#
easydel.trainers.__init__.conversations_formatting_function(processing_class: AutoTokenizer, messages_field: Literal['messages', 'conversations'], tools: Optional[list] = None)[source]#

Returns a callable that takes in a “messages” dataset and returns a formatted dataset by applying the chat template of the processing_class to it.

easydel.trainers.__init__.create_constant_length_dataset(processing_class, dataset, dataset_text_field: Optional[str] = None, formatting_func: Optional[Callable] = None, infinite: bool = False, seq_length: int = 1024, num_of_sequences: int = 1024, chars_per_token: float = 3.6, eos_token_id: int = 0, shuffle: bool = True, append_concat_token: bool = True, add_special_tokens: bool = True) Callable[[], Iterator[Dict[str, Array]]][source]#

Creates a generator function that yields constant length chunks of tokens from a stream of text files.

Parameters
  • processing_class – The processor used for processing the data.

  • dataset – Dataset with text files.

  • dataset_text_field – Name of the field in the dataset that contains the text.

  • formatting_func – Function that formats the text before tokenization.

  • infinite – If True, the iterator is reset when the dataset reaches its end; otherwise iteration stops.

  • seq_length – Length of token sequences to return.

  • num_of_sequences – Number of token sequences to keep in buffer.

  • chars_per_token – Number of characters per token used to estimate number of tokens in text buffer.

  • eos_token_id – Id of the end of sequence token if the passed processing_class does not have an EOS token.

  • shuffle – Shuffle the examples before they are returned.

  • append_concat_token – If true, appends eos_token_id at the end of each sample being packed.

  • add_special_tokens – If true, processing_class adds special tokens to each sample being packed.

Returns

A generator function that yields dictionaries containing input_ids and attention_mask as jnp.arrays

easydel.trainers.__init__.create_prompt_creator(processing_class)[source]#
easydel.trainers.__init__.get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], processing_class: AutoTokenizer, tools: Optional[list] = None) Optional[Callable][source]#
easydel.trainers.__init__.instructions_formatting_function(processing_class: AutoTokenizer)[source]#

Adapted from TRL. Returns a callable that takes in an “instructions” dataset and returns a formatted dataset by applying the chat template of the processing_class to it.

easydel.trainers.__init__.pack_sequences(dataset: Any, max_length: int = 512, pad_token_id: int = 0, reset_position_ids: bool = False, num_proc: Optional[int] = None)[source]#

Pack sequences together with their attention masks and position IDs

Example:

    # With continuous position IDs
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=False
    )

    # With reset position IDs for each sequence
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=True
    )

    # Example output format for a packed sequence with two sequences:
    # reset_position_ids=False:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1,1,1,0,1,1,1,0,0,0],
        'position_ids': [0,1,2,3,4,5,6,7,0,0]
    }

    # reset_position_ids=True:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1,1,1,0,1,1,1,0,0,0],
        'position_ids': [0,1,2,0,0,1,2,0,0,0]
    }

Parameters
  • dataset – Dataset containing ‘input_ids’ and ‘attention_mask’

  • max_length – Maximum length of packed sequence

  • pad_token_id – Token ID used for padding

  • reset_position_ids – If True, reset position IDs for each sequence in the pack

Returns

Dataset with packed sequences, attention masks, and position IDs

Return type

packed_dataset