easydel.trainers.utils#
- class easydel.trainers.utils.DPODataCollatorWithPadding(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: Optional[bool] = False, output_arrays_only: bool = True, prepadded: bool = True)[source]#
Bases: object
DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- is_encoder_decoder: Optional[bool] = False#
- label_pad_token_id: int = -100#
- max_completion_length: int#
- max_prompt_length: int#
- output_arrays_only: bool = True#
- pad_token_id: int = 0#
- prepadded: bool = True#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.DataCollatorForCompletionOnlyLM(processing_class: Union[str, PreTrainedTokenizerBase], response_template: Union[str, List[int]], instruction_template: Optional[Union[str, List[int]]] = None, *args, mlm: bool = False, ignore_index: int = -100, **kwargs)[source]#
Bases: object
Data collator used for completion tasks. It sets every label token that does not come from the assistant to ignore_index, so the loss is computed only on the assistant's completion.
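The label masking described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not EasyDeL's implementation: `find_subsequence` and `mask_non_completion_tokens` are hypothetical helper names, the response template is assumed to be given as token ids, and only the first occurrence of the template is handled.

```python
from typing import List

IGNORE_INDEX = -100  # matches the documented default of ignore_index


def find_subsequence(sequence: List[int], pattern: List[int]) -> int:
    """Return the start index of the first occurrence of pattern, or -1."""
    for start in range(len(sequence) - len(pattern) + 1):
        if sequence[start:start + len(pattern)] == pattern:
            return start
    return -1


def mask_non_completion_tokens(
    input_ids: List[int],
    response_template_ids: List[int],
    ignore_index: int = IGNORE_INDEX,
) -> List[int]:
    """Copy input_ids into labels and mask everything up to and including the template."""
    labels = list(input_ids)
    start = find_subsequence(input_ids, response_template_ids)
    if start == -1:
        # Assumption: with no template found, the whole sequence is masked
        # so that it contributes nothing to the loss.
        return [ignore_index] * len(labels)
    completion_start = start + len(response_template_ids)
    for i in range(completion_start):
        labels[i] = ignore_index
    return labels


# Tokens 7, 8 stand in for the tokenized response template.
labels = mask_non_completion_tokens([5, 6, 7, 8, 9, 10], [7, 8])
```

Only the tokens after the template (here 9 and 10) keep their ids as labels, so only they take part in the loss.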
- class easydel.trainers.utils.DataCollatorForPreference(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: Optional[bool] = False)[source]#
Bases: object
DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- is_encoder_decoder: Optional[bool] = False#
- label_pad_token_id: int = -100#
- max_completion_length: int#
- max_prompt_length: int#
- pad_token_id: int = 0#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.JaxDistributedConfig[source]#
Bases: object
From EasyLM. Utility class for initializing JAX distributed.
- class easydel.trainers.utils.RewardDataCollatorWithPadding(tokenizer: Any, padding: Union[bool, str] = 'max_length', max_length: Optional[int] = None, truncation_mode: str = 'keep_end')[source]#
Bases: object
Reward DataCollator class that pads the inputs to the maximum length of the batch.
- Parameters
tokenizer (ProcessingClassType) – The tokenizer used for encoding the data.
padding (Union[bool, str, PaddingStrategy], optional, defaults to 'max_length') – Padding strategy to pass to the tokenizer.
max_length (int or None, optional, defaults to None) – If set, pads the sequence to the provided maximum length.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- max_length: Optional[int] = None#
- padding: Union[bool, str] = 'max_length'#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- tokenizer: Any#
- truncation_mode: str = 'keep_end'#
- easydel.trainers.utils.add_bos_token_if_needed(bos_token_id: Optional[int], prompt_len_input_ids: int, prompt_tokens: Dict[str, List[int]], chosen_prompt_len_input_ids: int, chosen_tokens: Dict[str, List[int]], rejected_prompt_len_input_ids: int, rejected_tokens: Dict[str, List[int]])[source]#
- easydel.trainers.utils.add_eos_token_if_needed(eos_token_id: int, chosen_tokens: Dict[str, List[int]], rejected_tokens: Dict[str, List[int]])[source]#
- easydel.trainers.utils.compute_weight_stats(params, repattern: str)[source]#
Compute statistics for model weights in a JIT-compatible way.
- Parameters
params – Model parameters
repattern – Pattern string (presumably a regular expression) that selects which parameters to analyze
- Returns
Dictionary of weight statistics
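A rough sketch of pattern-filtered weight statistics follows. It is an illustration only and makes several assumptions not confirmed by this page: that repattern is a regular expression matched against parameter names, that parameters are available as a flat name-to-values mapping, and that mean, standard deviation, and maximum absolute value are the statistics of interest. The function name `weight_stats_sketch` and the output key names are hypothetical, and the plain-Python loop does not attempt to reproduce the JIT compatibility of the real function.

```python
import math
import re
from typing import Dict, List


def weight_stats_sketch(params: Dict[str, List[float]], repattern: str) -> Dict[str, float]:
    """Compute mean, std, and max absolute value for parameters whose name matches."""
    stats: Dict[str, float] = {}
    for name, values in params.items():
        # Assumption: repattern is a regular expression searched in the parameter name.
        if not values or re.search(repattern, name) is None:
            continue
        mean = sum(values) / len(values)
        variance = sum((v - mean) ** 2 for v in values) / len(values)
        stats[f"{name}/mean"] = mean
        stats[f"{name}/std"] = math.sqrt(variance)
        stats[f"{name}/max_abs"] = max(abs(v) for v in values)
    return stats


example_params = {"layer0/kernel": [1.0, 3.0], "layer0/bias": [0.5]}
kernel_stats = weight_stats_sketch(example_params, r"kernel")
```

Here only `layer0/kernel` matches the pattern, so the bias entry is left out of the result.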
- easydel.trainers.utils.conversations_formatting_function(processing_class: AutoTokenizer, messages_field: Literal['messages', 'conversations'], tools: Optional[list] = None)[source]#
Returns a callable that takes a “messages” dataset and returns a formatted dataset by applying the chat template of processing_class.
- easydel.trainers.utils.create_constant_length_dataset(processing_class, dataset, dataset_text_field: Optional[str] = None, formatting_func: Optional[Callable] = None, infinite: bool = False, seq_length: int = 1024, num_of_sequences: int = 1024, chars_per_token: float = 3.6, eos_token_id: int = 0, shuffle: bool = True, append_concat_token: bool = True, add_special_tokens: bool = True) Callable[[], Iterator[Dict[str, Array]]][source]#
Creates a generator function that yields constant length chunks of tokens from a stream of text files.
- Parameters
processing_class – The processor used for processing the data.
dataset – Dataset with text files.
dataset_text_field – Name of the field in the dataset that contains the text.
formatting_func – Function that formats the text before tokenization.
infinite – If True, the iterator is reset when the dataset reaches its end; otherwise it stops.
seq_length – Length of token sequences to return.
num_of_sequences – Number of token sequences to keep in buffer.
chars_per_token – Number of characters per token used to estimate number of tokens in text buffer.
eos_token_id – Id of the end of sequence token if the passed processing_class does not have an EOS token.
shuffle – Shuffle the examples before they are returned.
append_concat_token – If true, appends eos_token_id at the end of each sample being packed.
add_special_tokens – If true, processing_class adds special tokens to each sample being packed.
- Returns
A generator function that yields dictionaries containing input_ids and attention_mask as jnp.arrays
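The core packing idea behind the parameters above can be sketched as follows. This is a simplified plain-Python illustration, not the library's code: `pack_constant_length` and `toy_tokenize` are hypothetical names, the output uses Python lists rather than jnp arrays, and shuffling, infinite iteration, and the chars_per_token buffer estimate are left out.

```python
from typing import Callable, Dict, Iterable, Iterator, List


def pack_constant_length(
    texts: Iterable[str],
    tokenize: Callable[[str], List[int]],
    seq_length: int,
    eos_token_id: int = 0,
    append_concat_token: bool = True,
) -> Iterator[Dict[str, List[int]]]:
    """Yield fixed-size chunks cut from the concatenated token stream of all texts."""
    buffer: List[int] = []
    for text in texts:
        buffer.extend(tokenize(text))
        if append_concat_token:
            buffer.append(eos_token_id)  # marks the boundary between packed samples
        # Emit every full chunk currently available in the buffer.
        while len(buffer) >= seq_length:
            chunk, buffer = buffer[:seq_length], buffer[seq_length:]
            yield {"input_ids": chunk, "attention_mask": [1] * seq_length}
    # Assumption: leftover tokens shorter than seq_length are dropped in this sketch.


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical tokenizer: one id per character (its code point)."""
    return [ord(ch) for ch in text]


chunks = list(pack_constant_length(["ab", "cde"], toy_tokenize, seq_length=3))
```

Because samples are concatenated before chunking, a chunk may contain the end of one sample and the start of the next, separated by eos_token_id.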
- easydel.trainers.utils.first_true_indices(bools, dtype=<class 'jax.numpy.int32'>)[source]#
Takes an N-dimensional bool array and returns an (N-1)-dimensional array of integers giving the position of the first True in each “row”.
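For the two-dimensional case, the behaviour can be illustrated with a small plain-Python sketch. `first_true_indices_2d` is a hypothetical name, nested lists stand in for JAX arrays, and returning the row length when a row has no True is an assumption (it mirrors the similarly named helper in TRL) rather than something this page states.

```python
from typing import List


def first_true_indices_2d(bools: List[List[bool]]) -> List[int]:
    """Return, for each row, the index of the first True (row length if there is none)."""
    result = []
    for row in bools:
        row_len = len(row)
        # False positions get the sentinel row_len, True positions keep their index,
        # so the minimum over the row is the position of the first True.
        candidates = [idx if flag else row_len for idx, flag in enumerate(row)]
        result.append(min(candidates, default=row_len))
    return result


mask = [
    [False, True, True],
    [False, False, False],
    [True, False, False],
]
positions = first_true_indices_2d(mask)
```

The sentinel-and-minimum formulation avoids data-dependent branching, which is why the same idea carries over to array code.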
- easydel.trainers.utils.get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], processing_class: AutoTokenizer, tools: Optional[list] = None) Optional[Callable][source]#
- easydel.trainers.utils.instructions_formatting_function(processing_class: AutoTokenizer)[source]#
From TRL. Returns a callable that takes an “instructions” dataset and returns a formatted dataset by applying the chat template of processing_class.
- easydel.trainers.utils.pad(tensors: list[jax.Array], max_lenght: Optional[int], padding_value: int = 0, padding_side: str = 'right') Array[source]#
Pads a list of tensors to the same shape along the first dimension.
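A minimal sketch of this kind of padding for one-dimensional sequences is shown below. It uses plain Python lists instead of jax.Array, `pad_lists` is a hypothetical name, and two behaviours are assumptions: a length of None is treated as "pad to the longest sequence", and sequences longer than the target raise an error instead of being truncated.

```python
from typing import List, Optional


def pad_lists(
    sequences: List[List[int]],
    max_length: Optional[int] = None,
    padding_value: int = 0,
    padding_side: str = "right",
) -> List[List[int]]:
    """Pad every sequence to a common length on the chosen side."""
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")
    # Assumption: None means "use the longest sequence in the batch".
    if max_length is None:
        target = max((len(seq) for seq in sequences), default=0)
    else:
        target = max_length
    padded = []
    for seq in sequences:
        if len(seq) > target:
            raise ValueError("sequence is longer than the target length")
        filler = [padding_value] * (target - len(seq))
        padded.append(seq + filler if padding_side == "right" else filler + seq)
    return padded


right_padded = pad_lists([[1, 2, 3], [4]])
left_padded = pad_lists([[1, 2, 3], [4]], padding_value=9, padding_side="left")
```

Left padding is commonly used for prompts in generation so that the real tokens end at the same position in every row, while right padding is the usual choice for completions.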
- easydel.trainers.utils.pad_sequence(sequences, batch_first=False, padding_value=0, max_len: int | None = None)[source]#