easydel.trainers.supervised_fine_tuning_trainer.sft_config#

class easydel.trainers.supervised_fine_tuning_trainer.sft_config.SFTConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 
'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)[source]#

Bases: TrainingArguments

Configuration class for the [SFTTrainer].

Parameters
  • model_name (str) – The name of the model. Defaults to “SFTTrainer”.

  • dataset_text_field (str, optional) – Name of the text field of the dataset. If provided, the trainer will automatically create a [ConstantLengthDataset] based on dataset_text_field. Defaults to None.

  • packing (bool, optional) – Controls whether the [ConstantLengthDataset] packs the sequences of the dataset. Defaults to False.

  • learning_rate (float, optional) – Initial learning rate for [AdamW] optimizer. The default value replaces that of [~transformers.TrainingArguments]. Defaults to 2e-5.

  • dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Only used when packing=False. Defaults to None.

  • dataset_batch_size (int, optional) – Number of examples to tokenize per batch. If dataset_batch_size <= 0 or dataset_batch_size is None, tokenizes the full dataset as a single batch. Defaults to 1000.

  • dataset_kwargs (dict[str, Any], optional) – Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. Defaults to None.

  • eval_packing (bool, optional) – Whether to pack the eval dataset. If None, uses the same value as packing. Defaults to None.

  • num_of_sequences (int, optional) – Number of sequences to use for the [ConstantLengthDataset]. Defaults to 1024.

  • chars_per_token (float, optional) – Number of characters per token to use for the [ConstantLengthDataset]. See [chars_token_ratio](huggingface/trl) for more details. Defaults to 3.6.

add_special_tokens: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to add special tokens.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
chars_per_token: float = Field(name=None,type=None,default=3.6,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of characters per token to use for the ConstantLengthDataset.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
dataset_batch_size: int = Field(name=None,type=None,default=1000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of examples to tokenize per batch.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
dataset_kwargs: Optional[dict[str, Any]] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dictionary of optional keyword arguments to pass when creating datasets.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
dataset_num_proc: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of processes to use for processing the dataset.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
dataset_text_field: Optional[str] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Name of the text field of the dataset.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
eval_packing: Optional[bool] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to pack the eval dataset. If None, uses the same value as packing.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
extra_optimizer_kwargs: dict#
ids_to_pop_from_dataset: tp.Optional[tp.List[str]]#
learning_rate: float = Field(name=None,type=None,default=2e-05,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Initial learning rate for the AdamW optimizer.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
model_name: str = Field(name=None,type=None,default='SFTTrainer',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The name of the model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
num_of_sequences: int = Field(name=None,type=None,default=1024,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of sequences to use for the ConstantLengthDataset.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
packing: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Controls whether the sequences of the dataset are packed.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
replace(**kwargs)#