Source code for easydel.trainers.supervised_fine_tuning_trainer.sft_trainer

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
import warnings

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.base_state import EasyDeLState
from easydel.infra.utils import ProcessingClassType
from easydel.utils.helpers import get_logger

from ..trainer import Trainer
from ..utils import (
	DataCollatorForCompletionOnlyLM,
	create_constant_length_dataset,
	get_formatting_func_from_dataset,
)
from .sft_config import SFTConfig

if tp.TYPE_CHECKING:
	from datasets import Dataset
else:
	Dataset = tp.Any
logger = get_logger(__name__)


[docs]class SFTTrainer(Trainer): """ Trainer class for Supervised Fine-Tuning (SFT) of language models. This trainer extends the `Trainer` and provides functionalities specific to supervised fine-tuning tasks. """ def __init__( self, arguments: SFTConfig, processing_class: ProcessingClassType, model: tp.Optional[tp.Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: tp.Optional[Dataset] = None, eval_dataset: tp.Optional[tp.Union[Dataset, tp.Dict[str, Dataset]]] = None, formatting_func: tp.Optional[tp.Callable] = None, data_collator: tp.Optional[DataCollatorForCompletionOnlyLM] = None, ): if getattr(processing_class, "pad_token", None) is None: processing_class.pad_token = processing_class.eos_token assert isinstance(arguments, SFTConfig), "passed argument must be a `SFTConfig`." if formatting_func is None and arguments.dataset_text_field is None: formatting_func = get_formatting_func_from_dataset( train_dataset, processing_class, ) if not arguments.packing: if data_collator: raise ValueError( "You passed `packing=False` to the SFTTrainer, but you didn't pass a " "`dataset_text_field` or `formatting_func` argument." ) self.dataset_num_proc = arguments.dataset_num_proc self.dataset_batch_size = arguments.dataset_batch_size self.arguments = arguments if arguments.dataset_kwargs is None: arguments.dataset_kwargs = {} if train_dataset is not None: train_dataset = self._prepare_dataset( train_dataset, processing_class, arguments.packing, arguments.dataset_text_field, arguments.max_sequence_length, formatting_func, arguments.num_of_sequences, arguments.chars_per_token, remove_unused_columns=arguments.remove_unused_columns, add_special_tokens=arguments.add_special_tokens, **arguments.dataset_kwargs, ) if eval_dataset is not None: _multiple = isinstance(eval_dataset, dict) _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} eval_packing = ( arguments.packing if arguments.eval_packing is None else arguments.eval_packing ) for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): _eval_datasets[_eval_dataset_name] = self._prepare_dataset( _eval_dataset, processing_class, eval_packing, arguments.dataset_text_field, arguments.max_sequence_length, formatting_func, arguments.num_of_sequences, arguments.chars_per_token, remove_unused_columns=arguments.remove_unused_columns, add_special_tokens=arguments.add_special_tokens, **arguments.dataset_kwargs, ) if not _multiple: eval_dataset = _eval_datasets["singleton"] if ( processing_class.padding_side is not None and processing_class.padding_side != "left" ): warnings.warn( "You passed a processing_class with `padding_side` not equal to `left` to the SFTTrainer. This might lead " "to some unexpected behaviour due to overflow issues when training a model in half-precision. " "You might consider adding `processing_class.padding_side = 'left'` to your code.", stacklevel=1, ) if not isinstance(model, EasyDeLState): model = model.to_state() super().__init__( arguments=arguments, dataset_train=train_dataset, dataset_eval=eval_dataset, model_state=model, data_collator=data_collator, ) def _prepare_dataset( self, dataset, processing_class, packing, dataset_text_field, max_sequence_length, formatting_func, num_of_sequences, chars_per_token, remove_unused_columns=True, append_concat_token=True, add_special_tokens=True, ): """ Prepares the dataset for training by applying tokenization and packing (if enabled). Args: dataset (Dataset): The dataset to prepare. processing_class (ProcessingClassType): The processing_class to use. packing (bool): Whether to pack multiple sequences into a single sample. dataset_text_field (str): The name of the text field in the dataset. max_sequence_length (int): The maximum sequence length. formatting_func (tp.Callable): A formatting function to apply to each sample. num_of_sequences (int): Number of sequences to pack in each sample (if packing is enabled). chars_per_token (float): Average number of characters per token. remove_unused_columns (bool, optional): Whether to remove unused columns. Defaults to True. append_concat_token (bool, optional): Whether to append a concat token for packing. Defaults to True. add_special_tokens (bool, optional): Whether to add special tokens during tokenization. Defaults to True. Returns: Dataset: The processed dataset ready for training. Raises: ValueError: If the dataset is None or if packing is enabled without a `dataset_text_field` or `formatting_func`. """ if dataset is None: raise ValueError("The dataset should not be None") if not packing: return self._prepare_non_packed_dataloader( processing_class, dataset, dataset_text_field, max_sequence_length, formatting_func, add_special_tokens, remove_unused_columns, ) else: return self._prepare_packed_dataloader( processing_class, dataset, dataset_text_field, max_sequence_length, num_of_sequences, chars_per_token, formatting_func, append_concat_token, add_special_tokens, ) def _prepare_non_packed_dataloader( self, processing_class: ProcessingClassType, dataset, dataset_text_field, max_sequence_length, formatting_func=None, add_special_tokens=True, remove_unused_columns=True, ): """ Prepares a non-packed dataloader from the given dataset. This method tokenizes the text data in the dataset, truncates or pads sequences to a fixed length, and removes unused columns as specified. It's suitable for datasets where each sample represents a single sequence. Args: processing_class: The processing_class to use for text encoding. dataset (Dataset): The dataset to prepare. dataset_text_field (str): The name of the text field in the dataset. max_sequence_length (int): The maximum sequence length. formatting_func (tp.Callable, optional): A formatting function to apply to each sample before tokenization. Defaults to None. add_special_tokens (bool, optional): Whether to add special tokens during tokenization. Defaults to True. remove_unused_columns (bool, optional): Whether to remove unused columns from the dataset. Defaults to True. Returns: Dataset: The processed dataset ready for training. """ from datasets import Dataset def tokenize(element): inputs = ( element[dataset_text_field] if formatting_func is None else formatting_func(element) ) outputs = processing_class( inputs, add_special_tokens=add_special_tokens, truncation=True, padding="max_length", max_length=max_sequence_length, return_overflowing_tokens=False, return_attention_mask=True, return_length=False, ) if formatting_func is not None and not isinstance(formatting_func(element), list): raise ValueError( "The `formatting_func` should return a list of processed strings since it can lead to silent bugs." ) return { "input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"], } signature_columns = ["input_ids", "labels", "attention_mask"] if dataset.column_names is not None: extra_columns = list(set(dataset.column_names) - set(signature_columns)) else: extra_columns = [] if not remove_unused_columns and len(extra_columns) > 0: warnings.warn( "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with " "the default collator and yield to errors. If you want to inspect dataset other columns (in this " f"case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the " "default collator and create your own data collator in order to inspect the unused dataset columns.", UserWarning, stacklevel=1, ) map_kwargs = { "batched": True, "remove_columns": dataset.column_names if remove_unused_columns else None, "batch_size": self.dataset_batch_size, } if isinstance(dataset, Dataset): map_kwargs["num_proc"] = self.dataset_num_proc tokenized_dataset = dataset.map(tokenize, **map_kwargs) return tokenized_dataset @staticmethod def _prepare_packed_dataloader( processing_class, dataset, dataset_text_field, max_sequence_length, num_of_sequences, chars_per_token, formatting_func=None, append_concat_token=True, add_special_tokens=True, ): """ Prepares a packed dataloader from the given dataset. This method is designed for efficient training of language models by packing multiple sequences from the dataset into a single sample. This can be particularly beneficial for handling long sequences and optimizing GPU/TPU utilization. Args: processing_class: The processing_class used for text encoding. dataset (Dataset): The dataset to prepare. dataset_text_field (str): The name of the text field in the dataset. max_sequence_length (int): The maximum length of each packed sequence. num_of_sequences (int): The number of sequences to pack into a single sample. chars_per_token (float): The average number of characters per token, used for estimating the number of tokens in a text sequence. formatting_func (tp.Callable, optional): A function to format each sample from the dataset before packing. It should take a sample as input and return a dictionary with a "text" key containing the processed text. Defaults to None. append_concat_token (bool, optional): Whether to append a special concatenation token between packed sequences. Defaults to True. add_special_tokens (bool, optional): Whether to add special tokens (like BOS, EOS) during tokenization. Defaults to True. Returns: Dataset: The processed dataset with packed sequences. Raises: ValueError: If both `dataset_text_field` and `formatting_func` are None, or if there's an error during dataset packing. """ if dataset_text_field is not None or formatting_func is not None: if processing_class is None: raise ValueError( "You need to pass a processing_class when using `dataset_text_field` with `SFTTrainer`." ) constant_length_iterator = create_constant_length_dataset( processing_class=processing_class, dataset=dataset, dataset_text_field=dataset_text_field, formatting_func=formatting_func, seq_length=max_sequence_length, infinite=False, num_of_sequences=num_of_sequences, chars_per_token=chars_per_token, eos_token_id=processing_class.eos_token_id, append_concat_token=append_concat_token, add_special_tokens=add_special_tokens, ) def data_generator(inner_constant_length_iterator): for d in inner_constant_length_iterator(): yield d # Import Only and Only when needed, don't dst the runtime. try: from datasets import Dataset from datasets.arrow_writer import SchemaInferenceError from datasets.builder import DatasetGenerationError except ImportError as exc: raise ImportError( "Could not import `datasets` from Hugging Face. Make sure to install the library using `pip install datasets`." ) from exc try: packed_dataset = Dataset.from_generator( data_generator, gen_kwargs={"inner_constant_length_iterator": constant_length_iterator}, ) except (DatasetGenerationError, SchemaInferenceError) as exc: raise ValueError( "Error occurred while packing the dataset. " "Make sure that your dataset has enough samples to at least yield one packed sequence.\n" "External Information : {}".format(exc) ) from exc return packed_dataset else: raise ValueError( "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want " "to use the `ConstantLengthDataset`." )