easydel.trainers.supervised_fine_tuning_trainer.sft_trainer


class easydel.trainers.supervised_fine_tuning_trainer.sft_trainer.SFTTrainer(arguments: SFTConfig, processing_class: ProcessingClassType, model: EasyDeLBaseModule | EasyDeLState | None = None, train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None, eval_dataset: Dataset | IterableDataset | ShardedDataSource | dict[str, Dataset] | None = None, formatting_func: tp.Callable | None = None, data_collator: DataCollatorForCompletionOnlyLM | None = None)

Bases: Trainer

Supervised Fine-Tuning trainer for language models.

Implements standard supervised fine-tuning for both base and instruction-tuned models. It accepts several data formats, including conversational datasets, and supports training modes such as completion-only loss and sequence packing.

Key features:

- Automatic dataset formatting and tokenization via lazy transforms
- Support for conversational/chat templates
- Sequence packing for improved efficiency
- Completion-only loss (prompt tokens are ignored)
- Multi-turn conversation handling
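Two of the features above, completion-only loss and sequence packing, can be illustrated with plain Python on already-tokenized IDs. This is a minimal sketch, not EasyDeL's implementation: the helper names completion_only_labels and pack_sequences are hypothetical, and the -100 "ignore" label is a common convention assumed here rather than taken from this page.

```python
from typing import List, Tuple

# Assumed convention: label positions set to -100 contribute nothing to the loss.
IGNORE_INDEX = -100


def completion_only_labels(
    prompt_ids: List[int], completion_ids: List[int]
) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and completion, masking prompt positions out of the loss."""
    input_ids = prompt_ids + completion_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


def pack_sequences(sequences: List[List[int]], max_length: int, eos_id: int) -> List[List[int]]:
    """Concatenate examples (each followed by EOS) and cut the stream into fixed-length rows."""
    stream: List[int] = []
    for seq in sequences:
        stream.extend(seq)
        stream.append(eos_id)
    # Keep only full rows so every packed row has exactly max_length tokens.
    n_full = len(stream) // max_length
    return [stream[i * max_length:(i + 1) * max_length] for i in range(n_full)]
```

For example, packing [[1, 2, 3], [4, 5], [6, 7, 8, 9]] with max_length=4 and eos_id=0 yields three rows, with the second example spanning a row boundary. Production packers usually also track segment boundaries so attention and position IDs do not leak across examples; this sketch omits that and simply drops any trailing partial row.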

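The trainer preprocesses data with lazy transforms that are applied as examples are iterated, instead of eagerly transforming the entire dataset up front with HF datasets .map() calls. This avoids a separate full preprocessing pass before training can start.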
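The difference between lazy and eager preprocessing can be sketched with a small wrapper that defers a transform until iteration. LazyMapped and tokenize are hypothetical names used only for illustration; this is not EasyDeL's transform API, and the character-code "tokenizer" stands in for a real one.

```python
from typing import Callable, Dict, Iterable, Iterator, List

Example = Dict[str, object]


class LazyMapped:
    """Wrap an iterable and apply fn to each example only when it is iterated."""

    def __init__(self, source: Iterable[Example], fn: Callable[[Example], Example]):
        self.source = source
        self.fn = fn

    def __iter__(self) -> Iterator[Example]:
        # A fresh generator per pass: fn runs on demand, one example at a time.
        for example in self.source:
            yield self.fn(example)


calls: List[str] = []


def tokenize(example: Example) -> Example:
    # Hypothetical stand-in tokenizer: maps each character to its code point.
    calls.append(example["text"])
    return {"input_ids": [ord(c) for c in example["text"]]}


data = [{"text": "hi"}, {"text": "yo"}]
lazy = LazyMapped(data, tokenize)
# Constructing the wrapper transforms nothing; calls is still empty here.
first = next(iter(lazy))
# Only the first example has been tokenized at this point.
```

Because nothing is cached, every new pass over a LazyMapped object re-runs the transform; the trade-off is less up-front work and memory in exchange for repeated per-iteration cost.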

arguments

SFTConfig with training hyperparameters

Type

easydel.trainers.training_configurations.TrainingArguments

tokenizer

Tokenizer for text processing

formatting_func

Optional function that converts a raw example into the text used for training

Example

>>> config = SFTConfig(
...     per_device_train_batch_size=4,
...     learning_rate=2e-5,
...     packing=True,
...     max_sequence_length=2048
... )
>>> trainer = SFTTrainer(
...     arguments=config,
...     model=model,
...     train_dataset=dataset,
...     processing_class=tokenizer,
...     formatting_func=lambda x: x["text"]  # Optional
... )
>>> trainer.train()

Note

For conversational datasets, the trainer expects one of the following:

- A 'messages' column in chat format
- A custom formatting_func that extracts the text
- A dataset_text_field pointing to the text column
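As an illustration of the formatting_func option, the sketch below flattens a 'messages' column of role/content dicts into a single training string. The function name format_messages and the <|role|> markers are hypothetical, invented for this example; they are not a template defined by EasyDeL.

```python
from typing import Dict, List


def format_messages(example: Dict[str, List[Dict[str, str]]]) -> str:
    """Join each chat turn as '<|role|>' followed by its content, one turn after another."""
    return "\n".join(
        f"<|{m['role']}|>\n{m['content']}" for m in example["messages"]
    )


# Example record with a 'messages' column in chat format.
record = {
    "messages": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
}
text = format_messages(record)
```

Assuming the trainer calls formatting_func once per example, as the lambda in the example above suggests, this function could be passed as formatting_func=format_messages. With a Hugging Face tokenizer, tokenizer.apply_chat_template(example["messages"], tokenize=False) is usually preferable to a hand-written template, because it reproduces the format the model was trained with.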