easydel.trainers.__init__#
- class easydel.trainers.__init__.BaseTrainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#
Bases: BaseTrainerProtocol
- apply_training_hooks(metrics: LossMetrics) LossMetrics[source]#
Applies the configured training hooks to a step's loss metrics and returns the resulting metrics.
- configure_dataloaders() TrainerConfigureDataloaderOutput[source]#
Configures the dataloaders for training and evaluation.
This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.
- Returns
- An object containing the configured dataloaders and the
maximum number of training and evaluation steps.
- Return type
TrainerConfigureDataloaderOutput
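The step budget mentioned above can be illustrated with a small sketch. The helper name `estimate_max_training_steps`, the drop-last rounding of a partial final batch, and the rule that an explicit `max_training_steps` replaces the computed value are all assumptions for illustration, not EasyDeL's actual code.

```python
from typing import Optional


def estimate_max_training_steps(
    num_examples: int,
    total_batch_size: int,
    num_train_epochs: int,
    max_training_steps: Optional[int] = None,
) -> int:
    """Hypothetical helper: derive a training step budget from dataset size."""
    if max_training_steps is not None:
        # Assumption: an explicit budget is used as-is.
        return max_training_steps
    # Assumption: a partial final batch is dropped (floor division).
    steps_per_epoch = num_examples // total_batch_size
    return steps_per_epoch * num_train_epochs


print(estimate_max_training_steps(10_000, 32, 10))        # 312 steps/epoch * 10 epochs
print(estimate_max_training_steps(10_000, 32, 10, 1000))  # explicit budget
```

If a real dataloader keeps the partial batch, replace the floor division with `math.ceil(num_examples / total_batch_size)`.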
- abstract configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- configure_model() TrainerConfigureModelOutput[source]#
Configures the model, optimizer, scheduler, and configuration.
This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.
- Returns
An object containing the configured model, optimizer, scheduler, and configuration.
- Return type
TrainerConfigureModelOutput
- abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#
Creates a function to collect and process batches of data for training or evaluation.
This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.
- Returns
A function that takes a batch of data and returns a processed batch.
- Return type
tp.Callable
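Because create_collect_function is abstract, each trainer supplies its own version. The sketch below shows one plausible shape for the returned callable. The `input_ids` key, right-side padding, and `pad_token_id=0` are assumptions for illustration only.

```python
from typing import Callable, Dict, List, Literal


def make_collect_fn(
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    pad_token_id: int = 0,  # hypothetical pad id
) -> Callable[[List[Dict[str, List[int]]]], Dict[str, List[List[int]]]]:
    """Hypothetical collate function: truncate, then right-pad to a fixed length."""
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def collect(batch):
        input_ids, attention_mask = [], []
        for example in batch:
            ids = list(example["input_ids"])
            if len(ids) > max_sequence_length:
                # keep_end drops tokens from the front; keep_start drops them from the back.
                if truncation_mode == "keep_end":
                    ids = ids[-max_sequence_length:]
                else:
                    ids = ids[:max_sequence_length]
            pad = max_sequence_length - len(ids)
            input_ids.append(ids + [pad_token_id] * pad)
            attention_mask.append([1] * len(ids) + [0] * pad)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    return collect


collect = make_collect_fn(4, "keep_end")
print(collect([{"input_ids": [1, 2, 3, 4, 5, 6]}, {"input_ids": [7, 8]}]))
# {'input_ids': [[3, 4, 5, 6], [7, 8, 0, 0]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 0, 0]]}
```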
- create_progress_bar(total: int, desc: str = '', disabled: bool = False) BaseProgressBar[source]#
Create a progress bar of the specified type.
- property evaluation_batch_size#
- get_runstage_flops(is_training) Union[float, Tuple[float, bool]][source]#
Return the total number of FLOPs for the model.
- initialize_trainer_utils()[source]#
Initializes various utilities used by the trainer.
This includes setting up Weights & Biases, initializing the training timer, configuring dataloaders, configuring the model and optimizer, sharding the model and reference model states, and configuring the training and evaluation functions.
- property is_process_zero#
- log_metrics(metrics: Any, pbar: BaseProgressBar, step: int, mode: str = 'train')[source]#
Log metrics and update progress bar.
- log_weight_distribution(state: EasyDeLState, step: int)[source]#
Log distribution of weights.
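The training arguments include a `weight_distribution_pattern` (default `'.*?(layernorm|norm).*?'`) that appears to select which parameters get logged. The sketch below uses that pattern to filter a flat name-to-array mapping and compute summary statistics. The flat mapping, the use of `re.match`, and the mean/std statistics are assumptions. EasyDeL parameters are nested pytrees and would have to be flattened first.

```python
import re

import numpy as np


def weight_distribution_stats(params, pattern=r".*?(layernorm|norm).*?"):
    """Hypothetical helper: mean/std of parameters whose name matches `pattern`."""
    regex = re.compile(pattern)
    stats = {}
    for name, value in params.items():
        # re.match anchors at the start; the leading ".*?" lets the match begin anywhere.
        if regex.match(name):
            arr = np.asarray(value, dtype=np.float64)
            stats[name] = {"mean": float(arr.mean()), "std": float(arr.std())}
    return stats


params = {
    "model/layers_0/input_layernorm/kernel": np.ones(8),
    "model/layers_0/mlp/up_proj/kernel": np.zeros((8, 16)),
    "model/norm/kernel": np.full(8, 2.0),
}
print(sorted(weight_distribution_stats(params)))
# ['model/layers_0/input_layernorm/kernel', 'model/norm/kernel']
```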
- property mesh#
- property model#
- on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#
Hook called at the end of each training step; returns the (possibly updated) state and metrics.
- on_step_start(state: EasyDeLState, step: int) EasyDeLState[source]#
Hook called at the start of each training step; returns the (possibly updated) state.
- save_information(output_path: Union[str, Path]) None[source]#
Save the generated information to a markdown file.
- Parameters
output_path – Path where the markdown file should be saved
- save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)[source]#
Saves the model state as a checkpoint file or, when to_torch=True, to a Torch-compatible directory.
- property training_batch_size#
- class easydel.trainers.__init__.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'DPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, beta: float = 0.1, label_smoothing: float = 0.0, loss_type: ~typing.Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', use_weighting: bool = False, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, max_length: ~typing.Optional[int] = 512, max_prompt_length: ~typing.Optional[int] = 256, max_completion_length: ~typing.Optional[int] = None, is_encoder_decoder: ~typing.Optional[bool] = None, disable_dropout: bool = True, precompute_ref_log_probs: bool = False, dataset_num_proc: ~typing.Optional[int] = None, reference_free: bool = False, force_use_ref_model: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, rpo_alpha: ~typing.Optional[float] = None, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None)[source]#
Bases: TrainingArguments
Configuration class for Direct Preference Optimization (DPO) training.
Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.
- beta#
Temperature parameter (β) controlling how far the policy may deviate from the reference model. Higher values keep the policy closer to the reference model. Default: 0.1
- Type
float
- label_smoothing#
Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0
- Type
float
- loss_type#
Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’
- Type
LOSS_FN_VARIENTS
- use_weighting#
Whether to apply example weighting in loss calculation. Default: False
- Type
bool
- label_pad_token_id#
Token ID used for padding labels. Default: -100
- Type
int
- padding_value#
Value used for padding sequences. If None, uses model’s default padding token. Default: None
- Type
int | None
- max_length#
Maximum total sequence length (prompt + completion). Default: 512
- Type
int | None
- max_prompt_length#
Maximum length for prompt sequences. Default: 256
- Type
int | None
- max_completion_length#
Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None
- Type
int | None
- is_encoder_decoder#
Explicitly set if model is encoder-decoder. Auto-detected if None. Default: None
- Type
bool | None
- disable_dropout#
Whether to disable dropout during training for deterministic behavior. Default: True
- Type
bool
- precompute_ref_log_probs#
Whether to precompute reference model log probabilities before training. Default: False
- Type
bool
- dataset_num_proc#
Number of processes for dataset preprocessing. Default: None (sequential processing)
- Type
int | None
- reference_free#
Whether to use the reference-free variant of DPO. Default: False
- Type
bool
- force_use_ref_model#
Force use of the reference model even when reference_free=True. Default: False
- Type
bool
- sync_ref_model#
Whether to periodically sync the reference model with the policy model being trained. Default: False
- Type
bool
- learning_rate#
Optimizer learning rate. Default: 1e-6
- Type
float
- ref_model_mixup_alpha#
Alpha parameter for mixup between policy and reference models. Default: 0.9
- Type
float
- ref_model_sync_steps#
Number of steps between reference model syncs. Default: 64
- Type
int
- rpo_alpha#
Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None
- Type
float | None
- tools#
Additional tools for the training process.
- Type
list[dict | Callable] | None
Example
>>> config = DPOConfig(
...     beta=0.2,
...     loss_type="ipo",
...     max_length=1024,
...     learning_rate=5e-6,
... )
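The default loss_type="sigmoid" corresponds to the original objective in the DPO paper cited above. The sketch below shows how beta and label_smoothing enter that loss, computed from summed sequence log-probabilities. The label-smoothing form (mixing in the loss for the flipped preference) follows TRL's "conservative DPO" variant. It is an assumption that EasyDeL uses exactly this form.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)


def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps,
                     beta=0.1, label_smoothing=0.0):
    """Per-example sigmoid DPO loss from summed sequence log-probabilities."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = pi_logratios - ref_logratios
    # label_smoothing > 0 blends in the loss for the flipped preference.
    return (-log_sigmoid(beta * logits) * (1.0 - label_smoothing)
            - log_sigmoid(-beta * logits) * label_smoothing)


zeros = np.zeros(2)
print(dpo_sigmoid_loss(zeros, zeros, zeros, zeros))  # log(2) per example when policy and reference agree
```

With reference_free=True, the reference log-probabilities can be treated as zero, leaving only the policy's chosen/rejected margin.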
- beta: float = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- force_use_ref_model: bool = False#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- is_encoder_decoder: Optional[bool] = None#
- label_pad_token_id: int = -100#
- label_smoothing: float = 0.0#
- learning_rate: float = 1e-06#
- loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid'#
- max_completion_length: Optional[int] = None#
- max_length: Optional[int] = 512#
- max_prompt_length: Optional[int] = 256#
- model_name: str = 'DPOTrainer'#
- padding_value: Optional[int] = None#
- precompute_ref_log_probs: bool = False#
- ref_model_mixup_alpha: float = 0.9#
- ref_model_sync_steps: int = 64#
- reference_free: bool = False#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- rpo_alpha: Optional[float] = None#
- sync_ref_model: bool = False#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- tools: Optional[List[Union[dict, Callable]]] = None#
- use_weighting: bool = False#
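When sync_ref_model is enabled, ref_model_mixup_alpha and ref_model_sync_steps together control how often, and how strongly, the reference model follows the policy. The sketch below assumes the TR-DPO-style soft update used by TRL, π_ref ← α·π_θ + (1 − α)·π_ref, applied every `ref_model_sync_steps` steps (for example from an on_step_end hook). Plain dicts of NumPy arrays stand in for parameter pytrees; with JAX you would use `jax.tree_util.tree_map` instead. Both the update rule and the hook placement are assumptions about EasyDeL.

```python
import numpy as np


def mixup_reference(policy_params, ref_params, alpha=0.9):
    """Soft update, leaf by leaf: ref <- alpha * policy + (1 - alpha) * ref."""
    return {name: alpha * policy_params[name] + (1.0 - alpha) * ref_params[name]
            for name in ref_params}


def maybe_sync(step, policy_params, ref_params, sync_ref_model=True,
               ref_model_sync_steps=64, ref_model_mixup_alpha=0.9):
    # Sync only when enabled and on every `ref_model_sync_steps`-th step.
    if sync_ref_model and step > 0 and step % ref_model_sync_steps == 0:
        return mixup_reference(policy_params, ref_params, ref_model_mixup_alpha)
    return ref_params


policy = {"w": np.array([1.0])}
ref = {"w": np.array([0.0])}
print(maybe_sync(64, policy, ref)["w"])  # moved 90% of the way toward the policy
```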
- class easydel.trainers.__init__.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#
Bases: Trainer
Trainer for Direct Preference Optimization (DPO).
This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.
- compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) tuple[Any, Any][source]#
Computes the reference model's log probabilities for a single padded batch of a DPO-specific dataset.
- Parameters
state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).
padded_batch (tp.Dict) – The padded batch of data.
- Returns
A tuple containing the log probabilities for the chosen and rejected responses.
- Return type
tuple[tp.Any, tp.Any]
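A common way to turn a model's logits into the per-sequence log probabilities returned above is to gather each label's log-softmax score and sum over positions, ignoring positions equal to label_pad_token_id (default -100). The sketch below assumes causal-LM label shifting and NumPy inputs of shape (batch, seq_len, vocab). It does not claim to reproduce EasyDeL's exact implementation.

```python
import numpy as np


def sequence_logps(logits, labels, label_pad_token_id=-100):
    """Sum of per-token label log-probs, skipping positions equal to label_pad_token_id."""
    # Causal-LM shift: logits at position t score the label at position t + 1.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    mask = labels != label_pad_token_id
    safe_labels = np.where(mask, labels, 0)  # any valid index; masked out below
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    token_logps = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return (token_logps * mask).sum(axis=-1)


# Uniform logits over a 5-token vocabulary: each scored token contributes log(1/5).
logits = np.zeros((1, 4, 5))
labels = np.array([[-100, 2, 3, -100]])
print(sequence_logps(logits, labels))  # two scored tokens -> -2 * log(5)
```

Running this once on the chosen half of a batch and once on the rejected half yields the pair of values this method returns.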
- configure_dataloaders()[source]#
Returns the training dataloader, potentially with precomputed reference log probabilities.
If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.
- Returns
The training dataloader.
- Return type
tensorflow.data.Dataset
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For DPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
- on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#
Hook called at the end of each training step; returns the (possibly updated) state and metrics.
- static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
- static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
Tokenize a row of the dataset.
- Parameters
features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.
processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.
max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.
max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.
add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.
- Returns
Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.
- Return type
dict[str, list[int]]
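The rules described above can be sketched as follows: optional BOS/EOS around the prompt, EOS always appended to completions, the prompt truncated from the left, and completions truncated from the right. Applying truncation after the special tokens are added mirrors TRL's version of this function; it is an assumption that EasyDeL does the same. `ToyTokenizer` is a hypothetical stand-in for a PreTrainedTokenizerBase.

```python
from typing import Dict, List, Optional


class ToyTokenizer:
    """Hypothetical whitespace tokenizer standing in for a real processing_class."""
    bos_token_id = 1
    eos_token_id = 2

    def __init__(self):
        self.vocab: Dict[str, int] = {}

    def __call__(self, text: str, add_special_tokens: bool = False) -> Dict[str, List[int]]:
        # New words get ids 3, 4, 5, ...; repeated words reuse their id.
        ids = [self.vocab.setdefault(tok, len(self.vocab) + 3) for tok in text.split()]
        return {"input_ids": ids}


def tokenize_row(features, processing_class, max_prompt_length: Optional[int],
                 max_completion_length: Optional[int], add_special_tokens: bool):
    tok = processing_class
    prompt = tok(features["prompt"], add_special_tokens=False)["input_ids"]
    chosen = tok(features["chosen"], add_special_tokens=False)["input_ids"]
    rejected = tok(features["rejected"], add_special_tokens=False)["input_ids"]
    if add_special_tokens:
        prompt = [tok.bos_token_id] + prompt + [tok.eos_token_id]
    # Completions always end with EOS.
    chosen = chosen + [tok.eos_token_id]
    rejected = rejected + [tok.eos_token_id]
    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]  # keep the end of the prompt
    if max_completion_length is not None:
        chosen = chosen[:max_completion_length]  # keep the start of each completion
        rejected = rejected[:max_completion_length]
    return {"prompt_input_ids": prompt, "chosen_input_ids": chosen, "rejected_input_ids": rejected}


row = {"prompt": "what is two plus two", "chosen": "four", "rejected": "five"}
print(tokenize_row(row, ToyTokenizer(), max_prompt_length=3,
                   max_completion_length=None, add_special_tokens=False))
# {'prompt_input_ids': [5, 6, 5], 'chosen_input_ids': [7, 2], 'rejected_input_ids': [8, 2]}
```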
- class easydel.trainers.__init__.GRPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'GRPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: ~typing.Optional[bool] = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = 
EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_prompt_length: int = 512, max_completion_length: int = 256, dataset_num_proc: ~typing.Optional[int] = None, beta: float = 0.04, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None, skip_apply_chat_template: bool = False)[source]#
Bases: TrainingArguments
Configuration class for the GRPOTrainer.
- beta: float = 0.04#
- dataset_num_proc: Optional[int] = None#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- learning_rate: float = 1e-06#
- max_completion_length: int = 256#
- max_prompt_length: int = 512#
- model_name: str = 'GRPOTrainer'#
- ref_model_mixup_alpha: float = 0.9#
- ref_model_sync_steps: int = 64#
- remove_unused_columns: Optional[bool] = False#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- skip_apply_chat_template: bool = False#
- sync_ref_model: bool = False#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- tools: Optional[List[Union[dict, Callable]]] = None#
- class easydel.trainers.__init__.GRPOTrainer(arguments: GRPOConfig, vinference: vInference, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]], reward_funcs: Union[EasyDeLBaseModule, EasyDeLState, Callable[[list, list], list[float]], list[Union[easydel.infra.base_module.EasyDeLBaseModule, easydel.infra.base_state.EasyDeLState, Callable[[list, list], list[float]]]]], train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None, reward_processing_classes: Optional[Any] = None, data_tokenize_fn: Optional[Callable] = None)[source]#
Bases: Trainer
- arguments: GRPOConfig#
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#
Hook called at the end of each training step; returns the (possibly updated) state and metrics.
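The reward_funcs type in the GRPOTrainer signature allows plain callables of type Callable[[list, list], list[float]]. The sketch below assumes the two lists are prompts and completions, and shows a toy reward together with the group-relative advantage normalization described in the GRPO paper, where each reward is standardized within the group of completions sampled for the same prompt. The reward itself, the `num_generations` grouping parameter, the population standard deviation, and `eps` are illustrative assumptions, not documented EasyDeL behavior.

```python
from typing import List

import numpy as np


def length_reward(prompts: list, completions: list) -> List[float]:
    """Hypothetical reward: prefer completions close to 20 characters."""
    return [-abs(len(c) - 20) / 20.0 for c in completions]


def group_relative_advantages(rewards, num_generations: int, eps: float = 1e-4):
    """Standardize rewards within each group of completions for the same prompt."""
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    mean = r.mean(axis=1, keepdims=True)
    std = r.std(axis=1, keepdims=True)
    return ((r - mean) / (std + eps)).reshape(-1)


prompts = ["p0", "p0", "p1", "p1"]  # two sampled completions per prompt
completions = ["a" * 20, "a" * 10, "b" * 30, "b" * 20]
rewards = length_reward(prompts, completions)
print(rewards)  # [0.0, -0.5, -0.5, 0.0]
print(group_relative_advantages(rewards, num_generations=2))  # roughly [1, -1, -1, 1]
```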
- class easydel.trainers.__init__.JaxDistributedConfig[source]#
Bases: object
Utility class for initializing JAX distributed, adapted from EasyLM.
- class easydel.trainers.__init__.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'ORPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_length: ~typing.Optional[int] = 1024, max_prompt_length: ~typing.Optional[int] = 512, max_completion_length: ~typing.Optional[int] = None, beta: float = 0.1, disable_dropout: bool = True, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, generate_during_eval: bool = False, is_encoder_decoder: ~typing.Optional[bool] = None, dataset_num_proc: ~typing.Optional[int] = None)[source]#
Bases: TrainingArguments
Configuration class for ORPO training settings.
This class inherits from TrainingArguments and holds configuration parameters specific to the ORPO model training. The dataclass automatically generates an initializer, and the __post_init__ method further processes some of the parameters after object initialization.
- model_name#
The name of the model. Default is “ORPOTrainer”.
- Type
str
- learning_rate#
The learning rate used during training. Default is 1e-6.
- Type
float
- max_length#
The maximum allowed sequence length for the input. Default is 1024.
- Type
Optional[int]
- max_prompt_length#
The maximum allowed length of the prompt portion of the input. Default is 512.
- Type
Optional[int]
- max_completion_length#
The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length.
- Type
Optional[int]
- beta#
Weight of the odds-ratio term in the ORPO loss (denoted λ in the ORPO paper). Default is 0.1.
- Type
float
- disable_dropout#
Flag to disable dropout during training. Default is True.
- Type
bool
- label_pad_token_id#
The token id used for padding labels. Default is -100.
- Type
int
- padding_value#
The value used for padding sequences. Default is None.
- Type
Optional[int]
- generate_during_eval#
Flag indicating whether to generate sequences during evaluation. Default is False.
- Type
bool
- is_encoder_decoder#
Flag to indicate if the model is encoder-decoder. Default is None.
- Type
Optional[bool]
- model_init_kwargs#
Additional keyword arguments for model initialization. Default is None.
- Type
Optional[Dict[str, Any]]
- dataset_num_proc#
Number of processes to use for dataset processing. Default is None.
- Type
Optional[int]
- max_sequence_length#
Computed attribute representing the maximum sequence length used for training. It is set in the __post_init__ method.
- Type
int
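The documented fallback for max_completion_length (max_length - max_prompt_length when it is not provided) is simple arithmetic. The sketch below spells it out. The function name `resolve_lengths` is hypothetical and only illustrates the rule stated above.

```python
from typing import Optional, Tuple


def resolve_lengths(
    max_length: Optional[int] = 1024,
    max_prompt_length: Optional[int] = 512,
    max_completion_length: Optional[int] = None,
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Hypothetical helper: fill in max_completion_length from the other two limits."""
    if max_completion_length is None and max_length is not None and max_prompt_length is not None:
        max_completion_length = max_length - max_prompt_length
    return max_length, max_prompt_length, max_completion_length


print(resolve_lengths())  # ORPO defaults: completions get 1024 - 512 = 512 tokens
```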
- beta: float = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- generate_during_eval: bool = False#
- is_encoder_decoder: Optional[bool] = None#
- label_pad_token_id: int = -100#
- learning_rate: float = 1e-06#
- max_completion_length: Optional[int] = None#
- max_length: Optional[int] = 1024#
- max_prompt_length: Optional[int] = 512#
- model_name: str = 'ORPOTrainer'#
- padding_value: Optional[int] = None#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.__init__.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#
Bases: Trainer
- arguments: ORPOConfig#
- build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#
Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.
- Parameters
prompt (str) – The prompt text.
answer (str) – The answer text.
- Returns
A dictionary containing the tokenized prompt and answer, along with attention masks.
- Return type
tp.Dict[str, np.ndarray]
- Raises
ValueError – If there’s a mismatch in token lengths.
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For ORPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
- tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#
Tokenizes a single row of data from the ORPO dataset.
This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the paired (chosen/rejected) data as model input.
- Parameters
feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.
state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.
- Returns
- A dictionary containing the tokenized prompt, chosen response, and rejected response,
along with attention masks and labels.
- Return type
tp.Dict
- Raises
ValueError – If the input data types are incorrect.
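For reference, the ORPO paper defines odds(y|x) = P(y|x) / (1 − P(y|x)), using the length-averaged sequence log-probability for log P. The loss is the chosen response's SFT negative log-likelihood plus λ·(−log σ(log odds-ratio)). The sketch below assumes ORPOConfig.beta plays the role of λ, as in TRL's port. Treating the inputs as average per-token log-probabilities (which must be strictly negative) is also an assumption.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    return -np.logaddexp(0.0, -x)


def orpo_loss(chosen_avg_logps, rejected_avg_logps, chosen_nll, beta=0.1):
    """ORPO objective: SFT NLL on chosen + beta * mean odds-ratio loss."""
    # log odds(y|x) = log p - log(1 - p), where log p is the average token log-prob.
    log_odds = (chosen_avg_logps - rejected_avg_logps) - (
        np.log1p(-np.exp(chosen_avg_logps)) - np.log1p(-np.exp(rejected_avg_logps))
    )
    odds_ratio_loss = -log_sigmoid(log_odds)
    return chosen_nll + beta * odds_ratio_loss.mean(), log_odds


total, log_odds = orpo_loss(np.array([-1.0]), np.array([-1.0]), chosen_nll=1.0)
print(total)  # equal likelihoods: 1.0 + 0.1 * log(2)
```

A positive log odds-ratio means the policy already prefers the chosen response, which drives the odds-ratio term toward zero.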
- class easydel.trainers.__init__.RewardConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 1024, max_training_steps: tp.Optional[int] = None, model_name: str = 'RewardTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, disable_dropout: bool = True, dataset_num_proc: ~typing.Optional[int] = None, center_rewards_coefficient: ~typing.Optional[float] = 0.1)[source]#
Bases:
TrainingArguments
Configuration class for the [RewardTrainer].
- Parameters
model_name (str) – The name of the model. Defaults to “RewardTrainer”.
max_sequence_length (int, optional) – Maximum length of the sequences (prompt + completion) in the batch. Entries that exceed the limit are filtered out. Defaults to 1024.
disable_dropout (bool, optional) – Whether to disable dropout in the model. Defaults to True.
dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Defaults to None.
center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to 0.1.
remove_unused_columns (bool, optional) – Whether to remove the columns that are not used by the model’s forward pass. Can be True only if the dataset is pretokenized. Defaults to False.
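The effect of `center_rewards_coefficient` can be sketched as follows. This assumes the TRL-style formulation, in which a pairwise Bradley-Terry loss on chosen and rejected rewards is combined with a penalty on the mean of `(r_chosen + r_rejected) ** 2`. The names `reward_loss` and `softplus` are hypothetical, and this is not EasyDeL's code.

```python
import math
from typing import Optional, Sequence


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def reward_loss(chosen: Sequence[float], rejected: Sequence[float],
                center_rewards_coefficient: Optional[float] = 0.1) -> float:
    """Hypothetical pairwise reward loss with an optional centering penalty."""
    if len(chosen) != len(rejected) or not chosen:
        raise ValueError("chosen and rejected must be non-empty and equally long")
    n = len(chosen)
    # -log(sigmoid(r_c - r_r)) == softplus(r_r - r_c)
    loss = sum(softplus(r - c) for c, r in zip(chosen, rejected)) / n
    if center_rewards_coefficient is not None:
        # Penalize the mean squared summed reward, pulling rewards toward zero mean
        centering = sum((c + r) ** 2 for c, r in zip(chosen, rejected)) / n
        loss += center_rewards_coefficient * centering
    return loss
```

The pairwise term depends only on reward differences, so adding the same constant to every reward leaves it unchanged. The centering penalty selects the zero-mean solution among these equivalent ones.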
- center_rewards_coefficient: Optional[float] = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- max_sequence_length: Optional[int] = 1024#
- model_name: str = 'RewardTrainer'#
- remove_unused_columns: bool = False#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.__init__.RewardTrainer(arguments: RewardConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, data_collator: Optional[RewardDataCollatorWithPadding] = None)[source]#
Bases:
Trainer
Trainer for reward modeling. This trainer extends the Trainer with functionality specific to training reward models.
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.
- Returns
- An object containing:
sharded_training_step_function: The compiled training step function.
sharded_evaluation_step_function: The compiled evaluation step function.
mesh: The device mesh used for computation.
checkpoint_manager: The checkpointer for saving/loading model state.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length, truncation_mode='keep_end')[source]#
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
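The two truncation modes can be illustrated with a minimal pure-Python collate function. `make_collate_fn` is hypothetical. It returns nested lists instead of JAX arrays and assumes the examples are already padded to a common length where stacking matters. It only demonstrates the `keep_end` / `keep_start` behavior described above.

```python
from typing import Dict, List, Literal


def make_collate_fn(max_sequence_length: int,
                    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end"):
    """Return a hypothetical collate function that truncates every field."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def collate(batch: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
        out: Dict[str, List[List[int]]] = {}
        for key in batch[0]:
            rows = []
            for example in batch:
                seq = example[key]
                # keep_end drops tokens from the front; keep_start drops them from the back
                if truncation_mode == "keep_end":
                    seq = seq[-max_sequence_length:]
                else:
                    seq = seq[:max_sequence_length]
                rows.append(list(seq))
            out[key] = rows
        return out

    return collate
```

`keep_end` is the usual choice when the most recent tokens (for example, the final answer) matter most.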
- class easydel.trainers.__init__.SFTConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 2e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'SFTTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, dataset_text_field: ~typing.Optional[str] = None, add_special_tokens: bool = False, packing: bool = False, dataset_num_proc: ~typing.Optional[int] = None, dataset_batch_size: int = 1000, dataset_kwargs: ~typing.Optional[dict[str, typing.Any]] = None, eval_packing: ~typing.Optional[bool] = None, num_of_sequences: int = 1024, chars_per_token: float = 3.6)[source]#
Bases:
TrainingArguments
Configuration class for the [SFTTrainer].
- Parameters
model_name (str) – The name of the model. Defaults to “SFTTrainer”.
dataset_text_field (str, optional) – Name of the text field of the dataset. If provided, the trainer will automatically create a [ConstantLengthDataset] based on dataset_text_field. Defaults to None.
packing (bool, optional) – Controls whether the [ConstantLengthDataset] packs the sequences of the dataset. Defaults to False.
learning_rate (float, optional) – Initial learning rate for the optimizer (AdamW by default). This default overrides the TrainingArguments default of 5e-5. Defaults to 2e-5.
dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Only used when packing=False. Defaults to None.
dataset_batch_size (int, optional) – Number of examples to tokenize per batch. If dataset_batch_size <= 0 or dataset_batch_size is None, tokenizes the full dataset as a single batch. Defaults to 1000.
dataset_kwargs (dict[str, Any], optional) – Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. Defaults to None.
eval_packing (bool, optional) – Whether to pack the eval dataset. If None, uses the same value as packing. Defaults to None.
num_of_sequences (int, optional) – Number of sequences to use for the [ConstantLengthDataset]. Defaults to 1024.
chars_per_token (float, optional) – Number of characters per token to use for the [ConstantLengthDataset]. See [chars_token_ratio](huggingface/trl) for more details. Defaults to 3.6.
- add_special_tokens: bool = False#
- chars_per_token: float = 3.6#
- dataset_batch_size: int = 1000#
- dataset_kwargs: Optional[dict[str, Any]] = None#
- dataset_num_proc: Optional[int] = None#
- dataset_text_field: Optional[str] = None#
- eval_packing: Optional[bool] = None#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- learning_rate: float = 2e-05#
- model_name: str = 'SFTTrainer'#
- num_of_sequences: int = 1024#
- packing: bool = False#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.__init__.SFTTrainer(arguments: SFTConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, formatting_func: Optional[Callable] = None, data_collator: Optional[DataCollatorForCompletionOnlyLM] = None)[source]#
Bases:
Trainer
Trainer class for Supervised Fine-Tuning (SFT) of language models.
This trainer extends the Trainer and provides functionalities specific to supervised fine-tuning tasks.
- class easydel.trainers.__init__.Trainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#
Bases:
BaseTrainer
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.
- Returns
- An object containing:
sharded_training_step_function: The compiled training step function.
sharded_evaluation_step_function: The compiled evaluation step function.
mesh: The device mesh used for computation.
checkpoint_manager: The checkpointer for saving/loading model state.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
- eval(model_state: EasyDeLState) Iterator[dict][source]#
Evaluates the model using the provided model state.
This method iterates over the evaluation dataset, performs forward passes, calculates evaluation metrics, logs the metrics, and yields the metrics for each evaluation step.
- Parameters
model_state (EasyDeLState) – The state of the model (including parameters and configuration) to be used for evaluation.
- Yields
Iterator[dict] – An iterator yielding a dictionary of evaluation metrics for each evaluation step.
- Raises
AssertionError – If the evaluation dataloader is not set.
- train() TrainerOutput[source]#
Executes the complete training process.
This method sets up initial metrics and logging, runs the training loop, and finalizes training. It calls the training hook at the beginning and returns a TrainerOutput object at the end.
- Returns
An object containing the final training state, metrics, and any additional outputs.
- Return type
TrainerOutput
- class easydel.trainers.__init__.TrainingArguments(auto_shard_states: 'bool' = True, aux_loss_enabled: 'bool' = False, backend: 'tp.Optional[str]' = None, clip_grad: 'tp.Optional[float]' = None, custom_scheduler: 'tp.Optional[tp.Callable[[int], tp.Any]]' = None, dataloader_num_workers: 'tp.Optional[int]' = 0, dataloader_pin_memory: 'tp.Optional[bool]' = False, do_eval: 'bool' = False, do_last_save: 'bool' = True, do_train: 'bool' = True, eval_batch_size: 'tp.Optional[int]' = None, evaluation_steps: 'tp.Optional[int]' = None, extra_optimizer_kwargs: 'dict' = <factory>, frozen_parameters: 'tp.Optional[str]' = None, gradient_accumulation_steps: 'int' = 1, ids_to_pop_from_dataset: 'tp.Optional[tp.List[str]]' = <factory>, is_fine_tuning: 'bool' = True, init_tx: 'bool' = True, jax_distributed_config: 'tp.Optional[dict]' = None, learning_rate: 'float' = 5e-05, learning_rate_end: 'tp.Optional[float]' = None, log_all_workers: 'bool' = False, log_grad_norms: 'bool' = True, report_metrics: 'bool' = True, log_steps: 'int' = 10, loss_config: 'tp.Optional[LossConfig]' = None, low_mem_usage: 'bool' = True, max_evaluation_steps: 'tp.Optional[int]' = None, max_sequence_length: 'tp.Optional[int]' = 4096, max_training_steps: 'tp.Optional[int]' = None, model_name: 'str' = 'BaseTrainer', model_parameters: 'tp.Optional[dict]' = None, metrics_to_show_in_rich_pbar: 'tp.Optional[tp.List[str]]' = None, num_train_epochs: 'int' = 10, offload_dataset: 'bool' = False, offload_device_type: 'str' = 'cpu', offload_device_index: 'int' = 0, optimizer: 'AVAILABLE_OPTIMIZERS' = <EasyDeLOptimizers.ADAMW: 'adamw'>, performance_mode: 'bool' = False, pruning_module: 'tp.Any' = None, process_zero_is_admin: 'bool' = True, progress_bar_type: "tp.Literal['tqdm', 'rich', 'json']" = 'tqdm', remove_ckpt_after_load: 'bool' = False, remove_unused_columns: 'bool' = True, report_steps: 'int' = 5, save_directory: 'str' = 'EasyDeL-Checkpoints', save_optimizer_state: 'bool' = True, save_steps: 'tp.Optional[int]' = 
None, save_total_limit: 'tp.Optional[int]' = None, scheduler: 'AVAILABLE_SCHEDULERS' = <EasyDeLSchedulers.NONE: 'None'>, sparsify_module: 'bool' = False, sparse_module_type: 'AVAILABLE_SPARSE_MODULE_TYPES' = 'bcoo', state_apply_fn_kwarguments_to_model: 'tp.Optional[dict]' = None, step_partition_spec: 'PartitionSpec' = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: 'tp.Optional[int]' = None, shuffle_train_dataset: 'bool' = True, total_batch_size: 'int' = 32, training_time_limit: 'tp.Optional[str]' = None, train_on_inputs: 'bool' = True, truncation_mode: "tp.Literal['keep_end', 'keep_start']" = 'keep_end', tx_mu_dtype: 'tp.Optional[jnp.dtype]' = None, track_memory: 'bool' = False, use_data_collactor: 'bool' = True, use_wandb: 'bool' = True, verbose: 'bool' = True, wandb_entity: 'tp.Optional[str]' = None, warmup_steps: 'int' = 0, weight_decay: 'float' = 0.01, weight_distribution_pattern: 'str' = '.*?(layernorm|norm).*?', weight_distribution_log_steps: 'int' = 0)[source]#
Bases:
object
- auto_shard_states: bool = True#
- aux_loss_enabled: bool = False#
- backend: Optional[str] = None#
- clip_grad: Optional[float] = None#
- custom_scheduler: Optional[Callable[[int], Any]] = None#
- dataloader_num_workers: Optional[int] = 0#
- dataloader_pin_memory: Optional[bool] = False#
- do_eval: bool = False#
- do_last_save: bool = True#
- do_train: bool = True#
- eval_batch_size: Optional[int] = None#
- evaluation_steps: Optional[int] = None#
- extra_optimizer_kwargs: dict#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- frozen_parameters: Optional[str] = None#
- get_optimizer_and_scheduler(steps: Optional[int] = None)[source]#
Returns the configured optimizer and learning rate scheduler.
- Parameters
steps (tp.Optional[int]) – The number of training steps. If not provided, uses the value from self.optimizer_kwargs.
- Returns
A tuple containing the optimizer and scheduler.
- Return type
tuple
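The documentation does not spell out how `scheduler`, `learning_rate`, `learning_rate_end`, and `warmup_steps` interact. The hypothetical `learning_rate_at` below sketches one common interpretation in pure Python:

- Linear warmup from zero over `warmup_steps`.
- Then, for `"linear"` or `"cosine"`, decay from `learning_rate` to `learning_rate_end`. A value of None is assumed to mean 0.0.
- For `"none"`, a constant rate after warmup.

This is an assumption for illustration only, not EasyDeL's scheduler code.

```python
import math
from typing import Literal, Optional


def learning_rate_at(step: int, total_steps: int, learning_rate: float = 5e-5,
                     learning_rate_end: Optional[float] = None, warmup_steps: int = 0,
                     scheduler: Literal["linear", "cosine", "none"] = "linear") -> float:
    """Hypothetical learning-rate schedule: warmup, then linear/cosine/constant."""
    end = 0.0 if learning_rate_end is None else learning_rate_end  # assumption
    # Linear warmup from 0 to learning_rate
    if warmup_steps > 0 and step < warmup_steps:
        return learning_rate * step / warmup_steps
    if scheduler == "none":
        return learning_rate
    # Fraction of the post-warmup decay phase completed, clipped to [0, 1]
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min(max((step - warmup_steps) / decay_steps, 0.0), 1.0)
    if scheduler == "linear":
        return learning_rate + (end - learning_rate) * progress
    if scheduler == "cosine":
        return end + 0.5 * (learning_rate - end) * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown scheduler: {scheduler!r}")
```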
- get_path() Path[source]#
Returns the path to the checkpoint directory.
- Returns
The path to the checkpoint directory.
- Return type
Path
- get_streaming_checkpointer()[source]#
Returns the checkpoint manager, responsible for saving model checkpoints.
- Returns
The checkpoint manager.
- Return type
- get_tensorboard()[source]#
Returns the TensorBoard SummaryWriter, used for logging metrics.
- Returns
The TensorBoard SummaryWriter.
- Return type
flax.metrics.tensorboard.SummaryWriter
- get_wandb_init()[source]#
Initializes Weights & Biases for experiment tracking if enabled.
- Returns
The WandB run object if initialized, else None.
- Return type
tp.Optional[wandb.sdk.wandb_run.Run]
- gradient_accumulation_steps: int = 1#
- ids_to_pop_from_dataset: Optional[List[str]]#
- init_tx: bool = True#
- is_fine_tuning: bool = True#
- property is_process_zero#
- jax_distributed_config: Optional[dict] = None#
- learning_rate: float = 5e-05#
- learning_rate_end: Optional[float] = None#
- log_all_workers: bool = False#
- log_grad_norms: bool = True#
- log_metrics(metrics: Any, step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#
Logs training metrics to Weights & Biases and/or TensorBoard.
- Parameters
metrics (tp.Dict[str, tp.Union[float, tp.List, tp.Tuple, np.ndarray, 'jnp.ndarray', 'torch.Tensor']]) – A dictionary where keys are metric names and values are metric values.
step (int) – The current training step or iteration.
- log_steps: int = 10#
- loss_config: Optional[LossConfig] = None#
- low_mem_usage: bool = True#
- max_evaluation_steps: Optional[int] = None#
- max_sequence_length: Optional[int] = 4096#
- max_training_steps: Optional[int] = None#
- metrics_to_show_in_rich_pbar: Optional[List[str]] = None#
- model_name: str = 'BaseTrainer'#
- model_parameters: Optional[dict] = None#
- num_train_epochs: int = 10#
- offload_dataset: bool = False#
- property offload_device#
- offload_device_index: int = 0#
- offload_device_type: str = 'cpu'#
- optimizer: Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = 'adamw'#
- performance_mode: bool = False#
- process_zero_is_admin: bool = True#
- progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
- pruning_module: Any = None#
- remove_ckpt_after_load: bool = False#
- remove_unused_columns: bool = True#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- report_metrics: bool = True#
- report_steps: int = 5#
- save_directory: str = 'EasyDeL-Checkpoints'#
- save_optimizer_state: bool = True#
- save_steps: Optional[int] = None#
- save_total_limit: Optional[int] = None#
- scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
- shuffle_train_dataset: bool = True#
- sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
- sparsify_module: bool = False#
- state_apply_fn_kwarguments_to_model: Optional[dict] = None#
- step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
- step_start_point: Optional[int] = None#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- total_batch_size: int = 32#
- track_memory: bool = False#
- train_on_inputs: bool = True#
- training_time_limit: Optional[str] = None#
- property training_time_seconds: int#
- truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
- use_data_collactor: bool = True#
- use_wandb: bool = True#
- verbose: bool = True#
- wandb_entity: Optional[str] = None#
- warmup_steps: int = 0#
- weight_decay: float = 0.01#
- weight_distribution_log_steps: int = 0#
- weight_distribution_pattern: str = '.*?(layernorm|norm).*?'#
- easydel.trainers.__init__.conversations_formatting_function(processing_class: AutoTokenizer, messages_field: Literal['messages', 'conversations'], tools: Optional[list] = None)[source]#
Returns a callable that takes a “messages” (or “conversations”) dataset and formats it by applying the processing_class chat template.
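A minimal sketch of such a formatting function follows. It is modeled on the TRL pattern of calling `apply_chat_template(..., tokenize=False, tools=tools)` per conversation, and it handles both a single example and a batched column.

`ToyTokenizer` and `conversations_formatter` are hypothetical stand-ins. A real Hugging Face tokenizer renders its own chat template rather than the toy `<|role|>` markup used here.

```python
from typing import Any, Dict, List, Optional


class ToyTokenizer:
    """Hypothetical stand-in for a tokenizer that has a chat template."""

    def apply_chat_template(self, conversation: List[Dict[str, str]],
                            tokenize: bool = False, tools: Optional[list] = None) -> str:
        # Toy template: one "<|role|>content" line per message
        return "".join(f"<|{m['role']}|>{m['content']}\n" for m in conversation)


def conversations_formatter(processing_class: Any, messages_field: str = "messages",
                            tools: Optional[list] = None):
    """Return a hypothetical function that renders conversations to text."""

    def format_dataset(examples: Dict[str, Any]):
        field = examples[messages_field]
        # Batched input: a list of conversations, each a list of message dicts
        if isinstance(field[0], list):
            return [processing_class.apply_chat_template(conv, tokenize=False, tools=tools)
                    for conv in field]
        # Single example: one conversation
        return processing_class.apply_chat_template(field, tokenize=False, tools=tools)

    return format_dataset
```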
- easydel.trainers.__init__.create_constant_length_dataset(processing_class, dataset, dataset_text_field: Optional[str] = None, formatting_func: Optional[Callable] = None, infinite: bool = False, seq_length: int = 1024, num_of_sequences: int = 1024, chars_per_token: float = 3.6, eos_token_id: int = 0, shuffle: bool = True, append_concat_token: bool = True, add_special_tokens: bool = True) Callable[[], Iterator[Dict[str, Array]]][source]#
Creates a generator function that yields constant-length chunks of tokens from a stream of text files.
- Parameters
processing_class – The processor used for processing the data.
dataset – Dataset with text files.
dataset_text_field – Name of the field in the dataset that contains the text.
formatting_func – Function that formats the text before tokenization.
infinite – If True, the iterator restarts once the dataset is exhausted. Otherwise, it stops.
seq_length – Length of token sequences to return.
num_of_sequences – Number of token sequences to keep in buffer.
chars_per_token – Number of characters per token used to estimate number of tokens in text buffer.
eos_token_id – Id of the end of sequence token if the passed processing_class does not have an EOS token.
shuffle – Shuffle the examples before they are returned.
append_concat_token – If True, appends eos_token_id at the end of each sample being packed.
add_special_tokens – If True, processing_class adds special tokens to each sample being packed.
- Returns
A generator function that yields dictionaries containing input_ids and attention_mask as jnp.arrays
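The buffering and chunking scheme can be sketched in pure Python. The sketch assumes the TRL `ConstantLengthDataset` approach:

1. Fill a character buffer of roughly `seq_length * chars_per_token * num_of_sequences` characters.
2. Tokenize each text, optionally appending `eos_token_id` as a separator.
3. Concatenate the tokens and cut them into `seq_length` chunks, dropping the incomplete remainder.

`constant_length_chunks` and its `tokenize` callable are hypothetical. The sketch returns an iterator directly rather than a generator function and omits `infinite` mode.

```python
import random
from typing import Callable, Dict, Iterable, Iterator, List


def constant_length_chunks(texts: Iterable[str], tokenize: Callable[[str], List[int]],
                           seq_length: int = 1024, num_of_sequences: int = 1024,
                           chars_per_token: float = 3.6, eos_token_id: int = 0,
                           append_concat_token: bool = True, shuffle: bool = False,
                           seed: int = 0) -> Iterator[Dict[str, List[int]]]:
    """Hypothetical packer yielding fixed-length input_ids/attention_mask chunks."""
    # Characters expected to hold about num_of_sequences chunks of seq_length tokens
    max_buffer_chars = seq_length * chars_per_token * num_of_sequences
    rng = random.Random(seed)
    iterator = iter(texts)
    exhausted = False
    while not exhausted:
        buffer: List[str] = []
        buffer_chars = 0
        while buffer_chars < max_buffer_chars:
            try:
                text = next(iterator)
            except StopIteration:
                exhausted = True
                break
            buffer.append(text)
            buffer_chars += len(text)
        all_ids: List[int] = []
        for text in buffer:
            ids = tokenize(text)
            if append_concat_token:
                ids = ids + [eos_token_id]
            all_ids.extend(ids)
        # Keep only full-length chunks; the trailing remainder is dropped
        chunks = [all_ids[i:i + seq_length]
                  for i in range(0, len(all_ids) - seq_length + 1, seq_length)]
        if shuffle:
            rng.shuffle(chunks)
        for chunk in chunks:
            yield {"input_ids": chunk, "attention_mask": [1] * len(chunk)}
```

Usage with a toy tokenizer that maps each word to its 1-based position: `list(constant_length_chunks(["a b c", "d e"], lambda s: list(range(1, len(s.split()) + 1)), seq_length=3, num_of_sequences=4))` packs `[1, 2, 3, 0, 1, 2, 0]` into two chunks and drops the final token.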
- easydel.trainers.__init__.get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], processing_class: AutoTokenizer, tools: Optional[list] = None) Optional[Callable][source]#
- easydel.trainers.__init__.instructions_formatting_function(processing_class: AutoTokenizer)[source]#
Adapted from TRL. Returns a callable that takes an “instructions” dataset and formats it by applying the processing_class chat template.
- easydel.trainers.__init__.pack_sequences(dataset: Any, max_length: int = 512, pad_token_id: int = 0, reset_position_ids: bool = False, num_proc: Optional[int] = None)[source]#
Pack sequences together with their attention masks and position IDs.
Example:
    # With continuous position IDs
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=False
    )
    # With reset position IDs for each sequence
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=True
    )
    # Example output for a packed sequence holding two sequences of three tokens each
    # reset_position_ids=False:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1, 1, 1, 0, 1, 1, 1, 0, 0, 0],
        'position_ids': [0, 1, 2, 3, 4, 5, 6, 7, 0, 0]
    }
    # reset_position_ids=True:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1, 1, 1, 0, 1, 1, 1, 0, 0, 0],
        'position_ids': [0, 1, 2, 0, 0, 1, 2, 0, 0, 0]
    }
- Parameters
dataset – Dataset containing ‘input_ids’ and ‘attention_mask’
max_length – Maximum length of packed sequence
pad_token_id – Token ID used for padding
reset_position_ids – If True, reset position IDs for each sequence in the pack
- Returns
Dataset with packed sequences, attention masks, and position IDs
- Return type
packed_dataset
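The documented output layout can be reproduced with a small pure-Python sketch. `pack_sequences_sketch` is hypothetical and works on plain lists of token IDs rather than a dataset. It also makes these assumptions:

- Packing is greedy and preserves input order.
- Each sequence is truncated to `max_length - 1` tokens so its PAD separator fits.
- A new pack starts whenever the next sequence plus its separator would overflow.

```python
from typing import Dict, List


def _pad_pack(ids: List[int], mask: List[int], pos: List[int],
              max_length: int, pad_token_id: int) -> Dict[str, List[int]]:
    # Right-pad one pack to max_length; padding gets mask 0 and position 0
    n_pad = max_length - len(ids)
    return {
        "input_ids": ids + [pad_token_id] * n_pad,
        "attention_mask": mask + [0] * n_pad,
        "position_ids": pos + [0] * n_pad,
    }


def pack_sequences_sketch(sequences: List[List[int]], max_length: int = 512,
                          pad_token_id: int = 0,
                          reset_position_ids: bool = False) -> List[Dict[str, List[int]]]:
    """Hypothetical greedy packer matching the documented output format."""
    if max_length < 2:
        raise ValueError("max_length must be at least 2")
    packs: List[Dict[str, List[int]]] = []
    ids: List[int] = []
    mask: List[int] = []
    pos: List[int] = []
    for seq in sequences:
        seq = list(seq[: max_length - 1])  # leave room for the PAD separator
        if ids and len(ids) + len(seq) + 1 > max_length:
            packs.append(_pad_pack(ids, mask, pos, max_length, pad_token_id))
            ids, mask, pos = [], [], []
        start = len(ids)
        ids += seq + [pad_token_id]
        mask += [1] * len(seq) + [0]
        if reset_position_ids:
            pos += list(range(len(seq))) + [0]  # restart at 0; separator gets 0
        else:
            pos += list(range(start, start + len(seq) + 1))  # continuous, incl. separator
    if ids:
        packs.append(_pad_pack(ids, mask, pos, max_length, pad_token_id))
    return packs
```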