easydel.trainers.utils

class easydel.trainers.utils.DPODataCollatorWithPadding(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: Optional[bool] = False, output_arrays_only: bool = True, prepadded: bool = True)

Bases: object

DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.

is_encoder_decoder: Optional[bool] = False
label_pad_token_id: int = -100
max_completion_length: int
max_prompt_length: int
output_arrays_only: bool = True
pad_token_id: int = 0
prepadded: bool = True

class easydel.trainers.utils.DataCollatorForCompletionOnlyLM(processing_class: Union[str, PreTrainedTokenizerBase], response_template: Union[str, List[int]], instruction_template: Optional[Union[str, List[int]]] = None, *args, mlm: bool = False, ignore_index: int = -100, **kwargs)

Bases: object

Data collator used for completion tasks. It sets every label token that does not come from the assistant to ignore_index, so that the loss is computed only on the assistant's completion.
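The masking described above can be sketched in plain Python. This is a minimal illustration, not the collator's actual implementation: `find_subsequence` and `mask_non_completion` are hypothetical helper names, plain lists stand in for arrays, and the handling of a missing response template (ignore the whole example) is an assumption.

```python
from typing import List

IGNORE_INDEX = -100  # matches the collator's default ignore_index


def find_subsequence(haystack: List[int], needle: List[int]) -> int:
    """Return the start index of the first occurrence of needle, or -1."""
    n = len(needle)
    for start in range(len(haystack) - n + 1):
        if haystack[start:start + n] == needle:
            return start
    return -1


def mask_non_completion(input_ids: List[int], response_template_ids: List[int]) -> List[int]:
    """Copy input_ids into labels and ignore everything up to and including the template."""
    labels = list(input_ids)
    start = find_subsequence(input_ids, response_template_ids)
    if start == -1:
        # Assumption: when the template is absent, no token contributes to the loss.
        return [IGNORE_INDEX] * len(labels)
    completion_start = start + len(response_template_ids)
    for i in range(completion_start):
        labels[i] = IGNORE_INDEX
    return labels


# Tokens 90, 91 play the role of the tokenized response template.
labels = mask_non_completion([5, 6, 7, 90, 91, 11, 12, 2], response_template_ids=[90, 91])
```

Only the last three positions keep their token ids as labels; the prompt and the template itself are ignored by the loss.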

jax_call(examples: List[Union[List[int], Any, Dict[str, Any]]]) → Dict[str, Any]
jax_mask_tokens(inputs: Any, special_tokens_mask: Optional[Any] = None) → Tuple[Any, Any]

Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
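A minimal sketch of the 80/10/10 rule follows, written with the standard library rather than JAX so that it stays dependency-free. The function name `mask_tokens_sketch` and the parameters `mask_token_id`, `vocab_size`, `mlm_probability`, and `seed` are assumptions made for illustration; they are not taken from the `jax_mask_tokens` signature.

```python
import random
from typing import List, Tuple

IGNORE_INDEX = -100


def mask_tokens_sketch(
    input_ids: List[int],
    special_tokens_mask: List[bool],
    mask_token_id: int,
    vocab_size: int,
    mlm_probability: float = 0.15,
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Return (masked_inputs, labels) for masked language modeling."""
    rng = random.Random(seed)
    inputs = list(input_ids)
    labels = [IGNORE_INDEX] * len(input_ids)
    for i, (token, is_special) in enumerate(zip(input_ids, special_tokens_mask)):
        # Special tokens are never selected; others are selected with mlm_probability.
        if is_special or rng.random() >= mlm_probability:
            continue
        labels[i] = token  # the loss is computed only on selected positions
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_token_id  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return inputs, labels


masked_inputs, masked_labels = mask_tokens_sketch(
    [101, 7, 8, 9, 102],
    [True, False, False, False, True],
    mask_token_id=103,
    vocab_size=1000,
    mlm_probability=1.0,  # select every non-special token, for demonstration
)
```

With `mlm_probability=1.0` every non-special position receives a label, while the two special tokens stay untouched and unlabeled.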

class easydel.trainers.utils.DataCollatorForPreference(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: Optional[bool] = False)

Bases: object

DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.

is_encoder_decoder: Optional[bool] = False
label_pad_token_id: int = -100
max_completion_length: int
max_prompt_length: int
pad_token_id: int = 0

class easydel.trainers.utils.JaxDistributedConfig

Bases: object

Utility class for initializing JAX distributed (from EasyLM).

static get_default_config(updates=None)
classmethod initialize(config=None)

class easydel.trainers.utils.RewardDataCollatorWithPadding(tokenizer: Any, padding: Union[bool, str] = 'max_length', max_length: Optional[int] = None, truncation_mode: str = 'keep_end')

Bases: object

Reward DataCollator class that pads the inputs to the maximum length of the batch.

Parameters
  • tokenizer (ProcessingClassType) – The tokenizer used for encoding the data.

  • padding (Union[bool, str, PaddingStrategy], optional, defaults to 'max_length') – Padding strategy to pass to the tokenizer.

  • max_length (int or None, optional, defaults to None) – If set, sequences are padded to this maximum length.

max_length: Optional[int] = None
padding: Union[bool, str] = 'max_length'
tokenizer: Any
truncation_mode: str = 'keep_end'

easydel.trainers.utils.add_bos_token_if_needed(bos_token_id: Optional[int], prompt_len_input_ids: int, prompt_tokens: Dict[str, List[int]], chosen_prompt_len_input_ids: int, chosen_tokens: Dict[str, List[int]], rejected_prompt_len_input_ids: int, rejected_tokens: Dict[str, List[int]])
easydel.trainers.utils.add_eos_token_if_needed(eos_token_id: int, chosen_tokens: Dict[str, List[int]], rejected_tokens: Dict[str, List[int]])
easydel.trainers.utils.conversations_formatting_function(processing_class: AutoTokenizer, messages_field: Literal['messages', 'conversations'], tools: Optional[list] = None)

Returns a callable that takes a “messages” dataset and returns a formatted dataset by applying the chat template of processing_class to it.
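The shape of such a callable can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `StubTokenizer` is a hypothetical stand-in that exposes an `apply_chat_template` method with a made-up template, and `make_formatter` is a hypothetical name. The batched-versus-single handling is also assumed.

```python
from typing import Any, Callable, Dict, List


class StubTokenizer:
    """Hypothetical stand-in for a tokenizer; the template format is made up."""

    def apply_chat_template(self, messages: List[Dict[str, str]], tokenize: bool = False) -> str:
        return "".join(f"<|{m['role']}|>{m['content']}\n" for m in messages)


def make_formatter(processing_class: Any, messages_field: str) -> Callable:
    """Build a function that renders conversations with the chat template."""

    def format_dataset(examples: Dict[str, Any]):
        conversations = examples[messages_field]
        if isinstance(conversations[0], list):
            # Batched input: one rendered string per conversation.
            return [processing_class.apply_chat_template(c, tokenize=False) for c in conversations]
        # Single conversation: one rendered string.
        return processing_class.apply_chat_template(conversations, tokenize=False)

    return format_dataset


formatter = make_formatter(StubTokenizer(), "messages")
conversation = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]
text = formatter({"messages": conversation})
batch_texts = formatter({"messages": [conversation, conversation]})
```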

easydel.trainers.utils.create_constant_length_dataset(processing_class, dataset, dataset_text_field: Optional[str] = None, formatting_func: Optional[Callable] = None, infinite: bool = False, seq_length: int = 1024, num_of_sequences: int = 1024, chars_per_token: float = 3.6, eos_token_id: int = 0, shuffle: bool = True, append_concat_token: bool = True, add_special_tokens: bool = True) → Callable[[], Iterator[Dict[str, Array]]]

Creates a generator function that yields constant length chunks of tokens from a stream of text files.

Parameters
  • processing_class – The processor used for processing the data.

  • dataset – Dataset with text files.

  • dataset_text_field – Name of the field in the dataset that contains the text.

  • formatting_func – Function that formats the text before tokenization.

  • infinite – If True, the iterator is reset when the dataset reaches its end; otherwise iteration stops.

  • seq_length – Length of token sequences to return.

  • num_of_sequences – Number of token sequences to keep in buffer.

  • chars_per_token – Number of characters per token used to estimate number of tokens in text buffer.

  • eos_token_id – Id of the end of sequence token if the passed processing_class does not have an EOS token.

  • shuffle – Shuffle the examples before they are returned.

  • append_concat_token – If true, appends eos_token_id at the end of each sample being packed.

  • add_special_tokens – If true, processing_class adds special tokens to each sample being packed.

Returns

A generator function that yields dictionaries containing input_ids and attention_mask as jnp arrays.
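The packing idea behind constant-length chunks can be sketched in plain Python. This is a simplified illustration: it takes already-tokenized samples, skips the character-based buffering and shuffling, and assumes that a trailing remainder shorter than seq_length is dropped. `pack_constant_length` is a hypothetical name, and lists stand in for arrays.

```python
from typing import Dict, Iterable, Iterator, List


def pack_constant_length(
    token_lists: Iterable[List[int]],
    seq_length: int,
    eos_token_id: int,
    append_concat_token: bool = True,
) -> Iterator[Dict[str, List[int]]]:
    """Concatenate tokenized samples and yield fixed-size chunks."""
    buffer: List[int] = []
    for tokens in token_lists:
        buffer.extend(tokens)
        if append_concat_token:
            buffer.append(eos_token_id)  # marks the boundary between packed samples
        while len(buffer) >= seq_length:
            chunk, buffer = buffer[:seq_length], buffer[seq_length:]
            yield {"input_ids": chunk, "attention_mask": [1] * seq_length}
    # Assumption: whatever remains in the buffer (fewer than seq_length tokens) is dropped.


chunks = list(
    pack_constant_length([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_length=4, eos_token_id=0)
)
```

Because samples are concatenated before chunking, a single chunk may contain the end of one sample and the start of the next, separated by the EOS token.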

easydel.trainers.utils.create_prompt_creator(processing_class)
easydel.trainers.utils.first_true_indices(bools, dtype=jnp.int32)

Takes an N-dimensional bool array and returns an (N-1)-dimensional array of integers giving the position of the first True in each “row”.
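For the two-dimensional case the behavior can be illustrated with nested lists. The sketch below is not the library's implementation, and it assumes that a row containing no True yields the row length, as the similarly named TRL helper does; `first_true_indices_2d` is a hypothetical name.

```python
from typing import List


def first_true_indices_2d(bools: List[List[bool]]) -> List[int]:
    """Return, for each row, the index of its first True value."""
    result = []
    for row in bools:
        # Assumption: rows without any True map to len(row).
        result.append(next((i for i, flag in enumerate(row) if flag), len(row)))
    return result


indices = first_true_indices_2d(
    [
        [False, True, True],
        [True, False, False],
        [False, False, False],
    ]
)
# indices == [1, 0, 3]
```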

easydel.trainers.utils.get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], processing_class: AutoTokenizer, tools: Optional[list] = None) → Optional[Callable]
easydel.trainers.utils.instructions_formatting_function(processing_class: AutoTokenizer)

From TRL. Returns a callable that takes an “instructions” dataset and returns a formatted dataset by applying the chat template of processing_class to it.

easydel.trainers.utils.leave_alone_context_manager()
easydel.trainers.utils.pad(tensors: list[jax.Array], max_lenght: Optional[int], padding_value: int = 0, padding_side: str = 'right') → Array

Pads a list of tensors to the same shape along the first dimension.
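A one-dimensional sketch of this padding is shown below, using lists in place of JAX arrays. `pad_sequences_sketch` is a hypothetical name, and raising an error when a sequence exceeds the requested length is an assumption made for the illustration rather than documented behavior.

```python
from typing import List, Optional


def pad_sequences_sketch(
    sequences: List[List[int]],
    max_length: Optional[int] = None,
    padding_value: int = 0,
    padding_side: str = "right",
) -> List[List[int]]:
    """Pad 1-D sequences to a common length on the chosen side."""
    longest = max(len(seq) for seq in sequences)
    target = longest if max_length is None else max_length
    if target < longest:
        # Assumption: this sketch refuses to truncate.
        raise ValueError("max_length is shorter than the longest sequence")
    padded = []
    for seq in sequences:
        filler = [padding_value] * (target - len(seq))
        if padding_side == "right":
            padded.append(list(seq) + filler)
        elif padding_side == "left":
            padded.append(filler + list(seq))
        else:
            raise ValueError("padding_side must be 'left' or 'right'")
    return padded


left_padded = pad_sequences_sketch([[1, 2, 3], [4]], padding_side="left")
right_padded = pad_sequences_sketch([[1, 2, 3], [4]], max_length=4, padding_value=9)
```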

easydel.trainers.utils.pad_sequence(sequences, batch_first=False, padding_value=0, max_len: int | None = None)
easydel.trainers.utils.pad_to_length(tensor: Union[Array, ndarray, bool, number], length: int, pad_value: Union[int, float], axis: int = -1) → Union[Array, ndarray, bool, number]
easydel.trainers.utils.shift_and_pad(mask, *tensors)
easydel.trainers.utils.tolist(x)

From HF.

Parameters
  • x – The object to convert.

Returns

x as a tp.List

easydel.trainers.utils.truncate_right(input_ids, stop_token_id, pad_token_id)

Truncates the input array from the right side after the first occurrence of the stop token.
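The effect on a single row can be sketched as follows. Several details here are assumptions modeled on the TRL helper of the same name rather than confirmed by this page: the stop token itself is kept, everything after it is replaced by pad_token_id, and a matching attention mask is returned alongside the ids. `truncate_right_row` is a hypothetical name.

```python
from typing import List, Tuple


def truncate_right_row(
    input_ids: List[int], stop_token_id: int, pad_token_id: int
) -> Tuple[List[int], List[int]]:
    """Replace everything after the first stop token with padding."""
    # If no stop token is present, nothing is truncated.
    stop_idx = next(
        (i for i, tok in enumerate(input_ids) if tok == stop_token_id), len(input_ids)
    )
    output = [tok if i <= stop_idx else pad_token_id for i, tok in enumerate(input_ids)]
    mask = [1 if i <= stop_idx else 0 for i in range(len(input_ids))]
    return output, mask


truncated, attention = truncate_right_row([5, 6, 2, 7, 8], stop_token_id=2, pad_token_id=0)
# truncated == [5, 6, 2, 0, 0], attention == [1, 1, 1, 0, 0]
```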