# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import re
import typing as tp
import warnings
from contextlib import contextmanager
from functools import partial
import chex
import jax
import numpy as np
from jax import numpy as jnp
from ml_collections import ConfigDict
from ml_collections.config_dict import placeholder
from easydel.infra.utils import ProcessingClassType
from easydel.utils import traversals
from easydel.utils import traversals as etr
from easydel.utils.helpers import get_logger
logger = get_logger(__name__)
[docs]class JaxDistributedConfig(object):
"""
From EasyLM
Utility class for initializing JAX distributed.
"""
[docs] @staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.initialize_jax_distributed = False
config.coordinator_address = placeholder(str)
config.num_processes = placeholder(int)
config.process_id = placeholder(int)
config.local_device_ids = placeholder(str)
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
[docs] @classmethod
def initialize(cls, config=None):
config = cls.get_default_config(config)
if config.initialize_jax_distributed:
if config.local_device_ids is not None:
local_device_ids = [int(x) for x in config.local_device_ids.split(",")]
else:
local_device_ids = None
jax.distributed.initialize(
coordinator_address=config.coordinator_address,
num_processes=config.num_processes,
process_id=config.process_id,
local_device_ids=local_device_ids,
)
# fmt:off
[docs]def create_prompt_creator(processing_class):
def to_role_and_content(field):
return {
"conversation": [
{"role": "user", "content": field["conversation"][0]["input"]},
{"role": "assistant", "content": field["conversation"][0]["output"]},
]
}
def _pc(sample):
return conversations_formatting_function(processing_class, messages_field="conversation")(to_role_and_content(sample))
return _pc
# fmt:on
[docs]def create_constant_length_dataset(
processing_class,
dataset,
dataset_text_field: tp.Optional[str] = None,
formatting_func: tp.Optional[tp.Callable] = None,
infinite: bool = False,
seq_length: int = 1024,
num_of_sequences: int = 1024,
chars_per_token: float = 3.6,
eos_token_id: int = 0,
shuffle: bool = True,
append_concat_token: bool = True,
add_special_tokens: bool = True,
) -> tp.Callable[[], tp.Iterator[tp.Dict[str, jnp.ndarray]]]:
"""
Creates a generator function that yields constant length chunks of tokens from a stream of text files.
Args:
processing_class: The processor used for processing the data.
dataset: Dataset with text files.
dataset_text_field: Name of the field in the dataset that contains the text.
formatting_func: Function that formats the text before tokenization.
infinite: If True the iterator is reset after dataset reaches end else stops.
seq_length: Length of token sequences to return.
num_of_sequences: Number of token sequences to keep in buffer.
chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
eos_token_id: Id of the end of sequence token if the passed processing_class does not have an EOS token.
shuffle: Shuffle the examples before they are returned.
append_concat_token: If true, appends eos_token_id at the end of each sample being packed.
add_special_tokens: If true, processing_class adds special tokens to each sample being packed.
Returns:
A generator function that yields dictionaries containing input_ids and attention_mask as jnp.arrays
"""
if processing_class.eos_token_id is None:
warnings.warn(
"The passed processing_class does not have an EOS token. We will use the passed eos_token_id instead which "
f"corresponds to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id.",
stacklevel=1,
)
concat_token_id = (
processing_class.eos_token_id if processing_class.eos_token_id else eos_token_id
)
max_buffer_size = seq_length * chars_per_token * num_of_sequences
# Input validation and formatting function setup
if dataset_text_field is not None and formatting_func is not None:
warnings.warn(
"Only one of `dataset_text_field` and `formatting_func` should be provided. "
"Ignoring `dataset_text_field` and using `formatting_func`.",
stacklevel=1,
)
if formatting_func is not None:
if formatting_func.__code__.co_argcount > 1:
warnings.warn(
"The passed formatting_func has more than one argument. Usually that function should have a single argument "
"`example` which corresponds to the dictionary returned by each element of the dataset. Make sure you know "
"what you are doing.",
stacklevel=1,
)
elif dataset_text_field is not None:
formatting_func = lambda x: x[dataset_text_field] # noqa
else:
raise ValueError(
"Either `dataset_text_field` or `formatting_func` should be provided."
)
def constant_length_generator() -> tp.Iterator[tp.Dict[str, jnp.ndarray]]:
iterator = iter(dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
# Fill the buffer
while True:
if buffer_len >= max_buffer_size:
break
try:
prompt = formatting_func(next(iterator))
if isinstance(prompt, list):
prompt = "".join(p for p in prompt)
buffer.append(prompt)
buffer_len += len(buffer[-1])
except StopIteration:
if infinite:
iterator = iter(dataset)
warnings.warn(
"The dataset reached end and the iterator is reset to the start.",
stacklevel=1,
)
else:
more_examples = False
break
if shuffle:
random.shuffle(buffer)
# Tokenize all texts in the buffer
tokens = processing_class(
buffer,
add_special_tokens=add_special_tokens,
truncation=False,
)
tokenized_inputs = tokens["input_ids"]
attention_masks = tokens["attention_mask"]
# Concatenate all tokens and attention masks
all_token_ids = []
all_attention_masks = []
for tokenized_input, attention_mask in zip(tokenized_inputs, attention_masks):
if append_concat_token:
tokenized_input = tokenized_input + [concat_token_id]
attention_mask = attention_mask + [1]
all_token_ids.extend(tokenized_input)
all_attention_masks.extend(attention_mask)
# Create fixed-length examples
examples = []
examples_attention_masks = []
for i in range(0, len(all_token_ids), seq_length):
input_ids = all_token_ids[i : i + seq_length]
org_attention_masks = all_attention_masks[i : i + seq_length]
if len(input_ids) == seq_length:
examples.append(input_ids)
examples_attention_masks.append(org_attention_masks)
if shuffle:
# Shuffle examples while keeping pairs together
combined = list(zip(examples, examples_attention_masks))
random.shuffle(combined)
examples, examples_attention_masks = zip(*combined)
# Yield examples
for example, example_attention_mask in zip(examples, examples_attention_masks):
yield {
"input_ids": jnp.asarray(example, dtype="i4"),
"attention_mask": jnp.asarray(example_attention_mask, dtype="i4"),
}
return constant_length_generator
def _collate_batch(
examples, processing_class, pad_to_multiple_of: tp.Optional[int] = None
):
if isinstance(examples[0], (list, tuple)):
examples = [jnp.array(e, dtype=jnp.int64) for e in examples]
length_of_first = len(examples[0])
are_tensors_same_length = all(len(x) == length_of_first for x in examples)
if are_tensors_same_length and (
pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0
):
return jnp.stack(examples, axis=0)
if processing_class._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the processing_class you are using"
f" ({processing_class.__class__.__name__}) does not have a pad token."
)
max_length = max(len(x) for x in examples)
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
result = jnp.full(
shape=(len(examples), max_length),
fill_value=processing_class.pad_token_id,
dtype=examples[0].dtype,
)
for i, example in enumerate(examples):
if processing_class.padding_side == "right":
result[i, : example.shape[0]] = example
else:
result[i, -example.shape[0] :] = example
return result
[docs]def tolist(x):
"""from HF
Args:
x:
Returns: X as tp.List
"""
if isinstance(x, list):
return x
elif hasattr(x, "numpy"):
x = x.numpy()
return x.tolist()
[docs]class DataCollatorForCompletionOnlyLM:
"""Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
when they do not come from the assistant. This ensures that the loss is only
calculated on the completion made by the assistant.
"""
def __init__(
self,
processing_class: tp.Union[str, "PreTrainedTokenizerBase"], # type:ignore #noqa
response_template: tp.Union[str, tp.List[int]],
instruction_template: tp.Optional[tp.Union[str, tp.List[int]]] = None,
*args,
mlm: bool = False,
ignore_index: int = -100,
**kwargs,
):
from transformers import AutoTokenizer
if isinstance(processing_class, str):
processing_class = AutoTokenizer.from_pretrained(processing_class)
self.processing_class = processing_class
self.instruction_template = instruction_template
if isinstance(instruction_template, str):
self.instruction_token_ids = self.processing_class.encode(
self.instruction_template, add_special_tokens=False
)
else:
self.instruction_token_ids = instruction_template
self.response_template = response_template
if isinstance(response_template, str):
self.response_token_ids = self.processing_class.encode(
self.response_template, add_special_tokens=False
)
else:
self.response_token_ids = response_template
if (
not mlm
and self.instruction_template
and self.processing_class.pad_token_id == self.processing_class.eos_token_id
):
warnings.warn(
"The pad_token_id and eos_token_id values of this processing_class are identical. "
"If you are planning for multi-turn training, "
"it can result in the model continuously generating questions and answers without eos token. "
"To avoid this, set the pad_token_id to a different value.",
stacklevel=1,
)
self.ignore_index = ignore_index
def _whole_word_mask(self, input_tokens: tp.List[str], max_predictions=512):
from transformers import BertTokenizer, BertTokenizerFast
if not isinstance(self.processing_class, (BertTokenizer, BertTokenizerFast)):
warnings.warn(
"DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
"Please refer to the documentation for more information.",
stacklevel=1,
)
cand_indexes = []
for i, token in enumerate(input_tokens):
if token == "[CLS]" or token == "[SEP]":
continue
if len(cand_indexes) >= 1 and token.startswith("##"):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
random.shuffle(cand_indexes)
num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * 0.15))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_lms.append(index)
if len(covered_indexes) != len(masked_lms):
raise ValueError(
"Length of covered_indexes is not equal to length of masked_lms."
)
mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
return mask_labels
[docs] def jax_mask_tokens(
self, inputs: tp.Any, special_tokens_mask: tp.Optional[tp.Any] = None
) -> tp.Tuple[tp.Any, tp.Any]:
"""Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original."""
labels = np.copy(inputs)
probability_matrix = np.full(labels.shape, 0.15)
if special_tokens_mask is None:
special_tokens_mask = [
self.processing_class.get_special_tokens_mask(
val, already_has_special_tokens=True
)
for val in labels.tolist()
]
special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
else:
special_tokens_mask = special_tokens_mask.astype(bool)
probability_matrix[special_tokens_mask] = 0
masked_indices = np.random.binomial(
1, probability_matrix, size=probability_matrix.shape
).astype(bool)
labels[~masked_indices] = -100
indices_replaced = (
np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
)
inputs[indices_replaced] = self.processing_class.mask_token_id
indices_random = (
np.random.binomial(1, 0.5, size=labels.shape).astype(bool)
& masked_indices
& ~indices_replaced
)
random_words = np.random.randint(
low=0,
high=len(self.processing_class),
size=np.count_nonzero(indices_random),
dtype=np.int64,
)
inputs[indices_random] = random_words
return inputs, labels
[docs] def jax_call(
self, examples: tp.List[tp.Union[tp.List[int], tp.Any, tp.Dict[str, tp.Any]]]
) -> tp.Dict[str, tp.Any]:
if isinstance(examples[0], tp.Mapping):
input_ids = [e["input_ids"] for e in examples]
else:
input_ids = examples
examples = [{"input_ids": e} for e in examples]
batch_input = _collate_batch(
input_ids,
self.processing_class,
)
mask_labels = []
for e in examples:
ref_tokens = []
for ida in tolist(e["input_ids"]):
token = self.processing_class._convert_id_to_token(ida)
ref_tokens.append(token)
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
if "chinese_ref" in e:
ref_pos = tolist(e["chinese_ref"])
len_seq = len(e["input_ids"])
for i in range(len_seq):
if i in ref_pos:
ref_tokens[i] = "##" + ref_tokens[i]
mask_labels.append(self._whole_word_mask(ref_tokens))
batch_mask = _collate_batch(
mask_labels,
self.processing_class,
)
inputs, labels = self.jax_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}
def __call__(
self, examples: tp.List[tp.Union[tp.List[int], tp.Any, tp.Dict[str, tp.Any]]]
) -> tp.Dict[str, tp.Any]:
batch = self.jax_call(examples)
if self.instruction_template is None:
for i in range(len(examples)):
response_token_ids_start_idx = None
for idx in jnp.where(batch["labels"][i] == self.response_token_ids[0])[0]:
if (
self.response_token_ids
== batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist()
):
response_token_ids_start_idx = idx
if response_token_ids_start_idx is None:
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f"following instance: {self.processing_class.decode(batch['input_ids'][i])} "
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`.",
stacklevel=1,
)
batch["labels"][i, :] = self.ignore_index
else:
response_token_ids_end_idx = response_token_ids_start_idx + len(
self.response_token_ids
)
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
else:
for i in range(len(examples)):
response_token_ids_idxs = []
human_token_ids_idxs = []
for assistant_idx in jnp.where(
batch["labels"][i] == self.response_token_ids[0]
)[0]:
if (
self.response_token_ids
== batch["labels"][i][
assistant_idx : assistant_idx + len(self.response_token_ids)
].tolist()
):
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
if len(response_token_ids_idxs) == 0:
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f"following instance: {self.processing_class.decode(batch['input_ids'][i])} "
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`.",
stacklevel=1,
)
batch["labels"][i, :] = self.ignore_index
human_token_ids = self.instruction_token_ids
for human_idx in jnp.where(batch["labels"][i] == human_token_ids[0])[0]:
if (
human_token_ids
== batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist()
):
human_token_ids_idxs.append(human_idx)
if len(human_token_ids_idxs) == 0:
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the "
f"following instance: {self.processing_class.decode(batch['input_ids'][i])} "
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`.",
stacklevel=1,
)
batch["labels"][i, :] = self.ignore_index
if (
len(human_token_ids_idxs) > 0
and len(response_token_ids_idxs) > 0
and human_token_ids_idxs[0] > response_token_ids_idxs[0]
):
human_token_ids_idxs = [0] + human_token_ids_idxs
for idx, (start, end) in enumerate(
zip(human_token_ids_idxs, response_token_ids_idxs)
):
if idx != 0:
batch["labels"][i, start:end] = self.ignore_index
else:
batch["labels"][i, :end] = self.ignore_index
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
return batch
[docs]@etr.auto_pytree
class RewardDataCollatorWithPadding:
r"""
Reward DataCollator class that pads the inputs to the maximum length of the batch.
Args:
tokenizer (`ProcessingClassType`):
The tokenizer used for encoding the data.
padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
padding_strategy to pass to the tokenizer.
max_length (`int` or `None`, `optional`, defaults to `None`):
If set will pad the sequence to a maximum provided value.
"""
tokenizer: ProcessingClassType
padding: tp.Union[bool, str] = "max_length"
max_length: tp.Optional[int] = None
truncation_mode: str = "keep_end"
def __call__(self, features: list[dict[str, tp.Any]]) -> dict[str, tp.Any]:
features_chosen = []
features_rejected = []
margin = []
has_margin = "margin" in features[0]
for feature in features:
if (
"input_ids_chosen" not in feature
or "input_ids_rejected" not in feature
or "attention_mask_chosen" not in feature
or "attention_mask_rejected" not in feature
):
raise ValueError(
"The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
)
features_chosen.append(
{
"input_ids": feature["input_ids_chosen"],
"attention_mask": feature["attention_mask_chosen"],
}
)
features_rejected.append(
{
"input_ids": feature["input_ids_rejected"],
"attention_mask": feature["attention_mask_rejected"],
}
)
if has_margin:
margin.append(feature["margin"])
batch_chosen = self.tokenizer.pad(
features_chosen,
padding=self.padding,
max_length=self.max_length,
return_tensors="jax",
)
batch_rejected = self.tokenizer.pad(
features_rejected,
padding=self.padding,
max_length=self.max_length,
return_tensors="jax",
)
batch = {
"input_ids_chosen": batch_chosen["input_ids"],
"attention_mask_chosen": batch_chosen["attention_mask"],
"input_ids_rejected": batch_rejected["input_ids"],
"attention_mask_rejected": batch_rejected["attention_mask"],
}
if has_margin:
margin = jnp.array(margin, dtype="f4")
batch["margin"] = margin
return batch
[docs]@etr.auto_pytree
class DataCollatorForPreference:
r"""DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch."""
max_prompt_length: int
max_completion_length: int
pad_token_id: int = 0
label_pad_token_id: int = -100
is_encoder_decoder: tp.Optional[bool] = False
def __call__(self, features: tp.List[tp.Dict[str, tp.Any]]) -> tp.Dict[str, tp.Any]:
prompt_input_ids = [jnp.array(feature["prompt_input_ids"]) for feature in features]
prompt_attention_mask = [jnp.ones_like(input_ids) for input_ids in prompt_input_ids]
chosen_input_ids = [jnp.array(feature["chosen_input_ids"]) for feature in features]
chosen_attention_mask = [jnp.ones_like(input_ids) for input_ids in chosen_input_ids]
rejected_input_ids = [
jnp.array(feature["rejected_input_ids"]) for feature in features
]
rejected_attention_mask = [
jnp.ones_like(input_ids) for input_ids in rejected_input_ids
]
pixel_values = None
pixel_attention_mask = None
if "pixel_values" in features[0]:
pixel_values = [jnp.array(feature["pixel_values"]) for feature in features]
if "pixel_attention_mask" in features[0]:
pixel_attention_mask = [
jnp.array(feature["pixel_attention_mask"]) for feature in features
]
ref_chosen_logps = None
ref_rejected_logps = None
if "ref_chosen_logps" in features[0] and "ref_rejected_logps" in features[0]:
ref_chosen_logps = jnp.array(
[feature["ref_chosen_logps"] for feature in features]
)
ref_rejected_logps = jnp.array(
[feature["ref_rejected_logps"] for feature in features]
)
# Pad sequences
output = {
"prompt_input_ids": pad(
prompt_input_ids,
self.max_prompt_length,
padding_value=self.pad_token_id,
padding_side="left",
),
"prompt_attention_mask": pad(
prompt_attention_mask,
self.max_prompt_length,
padding_value=0,
padding_side="left",
),
"chosen_input_ids": pad(
chosen_input_ids,
self.max_completion_length,
padding_value=self.pad_token_id,
),
"chosen_attention_mask": pad(
chosen_attention_mask,
self.max_completion_length,
padding_value=0,
),
"rejected_input_ids": pad(
rejected_input_ids,
self.max_completion_length,
padding_value=self.pad_token_id,
),
"rejected_attention_mask": pad(
rejected_attention_mask,
self.max_completion_length,
padding_value=0,
),
}
# Add optional outputs
if pixel_values is not None:
output["pixel_values"] = pad(
pixel_values,
self.max_prompt_length,
padding_value=0.0,
)
if pixel_attention_mask is not None:
output["pixel_attention_mask"] = pad(
pixel_attention_mask,
self.max_prompt_length,
padding_value=0,
)
if "image_sizes" in features[0]:
output["image_sizes"] = jnp.array(
[feature["image_sizes"] for feature in features]
)
if ref_chosen_logps is not None and ref_rejected_logps is not None:
output["ref_chosen_logps"] = ref_chosen_logps
output["ref_rejected_logps"] = ref_rejected_logps
return output
[docs]@etr.auto_pytree
class DPODataCollatorWithPadding:
r"""
DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
"""
max_prompt_length: int
max_completion_length: int
pad_token_id: int = 0
label_pad_token_id: int = -100
is_encoder_decoder: tp.Optional[bool] = False
output_arrays_only: bool = True
prepadded: bool = True
def __call__(self, features: list[dict[str, tp.Any]]) -> dict[str, tp.Any]:
camax_length = self.max_completion_length + self.max_prompt_length
padded_batch = {}
for k in features[0].keys():
if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")):
match k.split("_")[0]:
case "rejected":
max_length = self.max_completion_length
case "chosen":
max_length = self.max_completion_length
case "prompt":
max_length = self.max_prompt_length
case _:
max_length = camax_length
if self.is_encoder_decoder:
to_pad = [jnp.array(ex[k], dtype="i4") for ex in features]
if (k.startswith("prompt")) and (k.endswith("input_ids")):
if self.pad_token_id is None:
raise ValueError(
"Padding is enabled, but the tokenizer is not configured with a padding token."
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
" before calling the trainer."
)
padding_value = self.pad_token_id
elif k.endswith("_attention_mask"):
padding_value = 0
elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k):
padding_value = self.label_pad_token_id
else:
raise ValueError(f"Unexpected key in batch '{k}'")
padded_batch[k] = pad_sequence(
to_pad,
batch_first=False,
padding_value=padding_value,
max_len=None if self.prepadded else max_length,
)
else:
if k.endswith("_input_ids"):
if self.pad_token_id is None:
raise ValueError(
"Padding is enabled, but the tokenizer is not configured with a padding token."
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
" before calling the trainer."
)
padding_value = self.pad_token_id
elif k.endswith("_labels"):
padding_value = self.label_pad_token_id
elif k.endswith("_attention_mask"):
padding_value = 0
elif k.endswith("_pixel_values"):
padding_value = 0
else:
raise ValueError(f"Unexpected key in batch '{k}'")
if k in ["prompt_input_ids", "prompt_attention_mask"]:
padding_side = "left"
else:
padding_side = "right"
if k.endswith("_pixel_values"):
dtype = jnp.float32
else:
dtype = jnp.int32
to_pad = [jnp.array(ex[k], dtype=dtype) for ex in features]
padded_batch[k] = pad(
to_pad,
None if self.prepadded else max_length,
padding_value=padding_value,
padding_side=padding_side,
)
elif k.endswith("_logps"):
padded_batch[k] = jnp.array([ex[k] for ex in features])
else:
padded_batch[k] = [ex[k] for ex in features]
if self.output_arrays_only:
val = padded_batch.get(k)
if hasattr(val, "dtype"):
if val.dtype not in [
jnp.float64,
jnp.float32,
jnp.float16,
jnp.int32,
jnp.int16,
jnp.int8,
]:
padded_batch.pop(k)
else:
padded_batch.pop(k)
return padded_batch
[docs]def shift_and_pad(mask, *tensors):
for i in range(mask.shape[0]):
first_one_idx = np.nonzero(mask[i])[0][0].item()
mask[i] = np.roll(mask[i], shift=-first_one_idx)
for tensor in tensors:
tensor[i] = np.roll(tensor[i], shift=-first_one_idx)
if not tensors:
return mask
else:
return mask, *tensors
[docs]def pad(
tensors: list[jnp.ndarray],
max_lenght: tp.Optional[int],
padding_value: int = 0,
padding_side: str = "right",
) -> jnp.ndarray:
"""
Pads a list of tensors to the same shape along the first dimension.
"""
output_shape = tensors[0].shape[:-1]
current_max = tensors[0].shape[-1]
if max_lenght is None:
max_lenght = current_max
x_lenght = max(current_max, max_lenght)
output_shape += (x_lenght,)
output = jnp.full(
(len(tensors), *output_shape),
padding_value,
dtype=tensors[0].dtype,
)
for i, t in enumerate(tensors):
if padding_side == "left":
seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0])
elif padding_side == "right":
seq_slice = slice(0, t.shape[0])
else:
raise ValueError("padding_side must be 'left' or 'right'")
slices = (i,) + (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
output = output.at[slices].set(t)
if padding_side == "left":
output = output[..., -max_lenght:]
elif padding_side == "right":
output = output[..., :max_lenght]
else:
raise ValueError("padding_side must be 'left' or 'right'")
return output
[docs]def pad_to_length(
tensor: chex.Array,
length: int,
pad_value: tp.Union[int, float],
axis: int = -1,
) -> chex.Array:
if tensor.shape[axis] >= length:
if tensor.ndim == 2:
tensor = tensor[:, :length]
return tensor
else:
pad_size = list(tensor.shape)
pad_size[axis] = length - tensor.shape[axis]
return jax.numpy.concatenate(
[
tensor,
pad_value * jax.numpy.ones(pad_size, dtype=tensor.dtype),
],
axis=axis,
)
[docs]def pad_sequence(
sequences,
batch_first=False,
padding_value=0,
max_len: int | None = None,
):
max_len = max(seq.shape[-1] for seq in sequences) if max_len is None else max_len
padding_value = jnp.array(padding_value).reshape(1)
if batch_first:
padded_seqs = [
(
jnp.concatenate(
[
seq.reshape(1, -1),
jnp.ones((1, max_len - seq.shape[-1])) * padding_value,
],
axis=1,
)
if seq.shape[-1] < max_len
else seq.reshape(1, -1)
)
for seq in sequences
]
else:
padded_seqs = [
(
jnp.concatenate(
[
jnp.ones((1, max_len - seq.shape[-1])) * padding_value,
seq.reshape(1, -1),
],
axis=1,
)
if seq.shape[-1] < max_len
else seq.reshape(1, -1)
)
for seq in sequences
]
return jnp.array(padded_seqs)
[docs]@contextmanager
def leave_alone_context_manager():
# Perform setup actions (none in this case)
yield
[docs]def add_bos_token_if_needed(
bos_token_id: tp.Optional[int],
prompt_len_input_ids: int,
prompt_tokens: tp.Dict[str, tp.List[int]],
chosen_prompt_len_input_ids: int,
chosen_tokens: tp.Dict[str, tp.List[int]],
rejected_prompt_len_input_ids: int,
rejected_tokens: tp.Dict[str, tp.List[int]],
):
if bos_token_id is not None:
if (
prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]
):
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens[
"prompt_input_ids"
]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens[
"prompt_attention_mask"
]
if (
chosen_prompt_len_input_ids == 0
or bos_token_id != chosen_tokens["prompt_input_ids"][0]
):
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens[
"prompt_input_ids"
]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens[
"prompt_attention_mask"
]
if (
rejected_prompt_len_input_ids == 0
or bos_token_id != rejected_tokens["prompt_input_ids"][0]
):
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens[
"prompt_input_ids"
]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens[
"prompt_attention_mask"
]
return prompt_tokens, chosen_tokens, rejected_tokens
[docs]def add_eos_token_if_needed(
eos_token_id: int,
chosen_tokens: tp.Dict[str, tp.List[int]],
rejected_tokens: tp.Dict[str, tp.List[int]],
):
if (
len(chosen_tokens["input_ids"]) == 0
or eos_token_id != chosen_tokens["input_ids"][-1]
):
chosen_tokens["input_ids"].append(eos_token_id)
chosen_tokens["attention_mask"].append(1)
if (
len(rejected_tokens["input_ids"]) == 0
or eos_token_id != rejected_tokens["input_ids"][-1]
):
rejected_tokens["input_ids"].append(eos_token_id)
rejected_tokens["attention_mask"].append(1)
return chosen_tokens, rejected_tokens
[docs]def first_true_indices(bools, dtype=jnp.int32):
"""
Takes an N-dimensional bool array and returns an (N-1)-dimensional array of integers giving
the position of the first True in each "row".
"""
row_len = bools.shape[-1]
zero_or_index = row_len * (~bools).astype(dtype) + jnp.arange(row_len, dtype=dtype)
return jnp.min(zero_or_index, axis=-1)
[docs]def truncate_right(input_ids, stop_token_id, pad_token_id):
"""
Truncates the input array from the right side after the first occurrence of the stop token.
"""
trunc_idxs = first_true_indices(input_ids == stop_token_id).reshape((-1, 1))
idxs = jnp.arange(input_ids.shape[1]).reshape((1, -1))
output_ids = jnp.where(idxs > trunc_idxs, pad_token_id, input_ids)
mask = jnp.where(idxs > trunc_idxs, 0, 1)
return output_ids, mask
[docs]@partial(jax.jit, static_argnums=(1,))
def compute_weight_stats(params, repattern: str):
"""Compute statistics for model weights in a JIT-compatible way.
Args:
params: Model parameters
repattern: parameters to analyze
Returns:
Dictionary of weight statistics
"""
stats = {}
for path, weight in traversals.flatten_dict(params).items():
weight = weight.value
pattern_search = ".".join([str(p) for p in path])
path = "/".join([str(p) for p in path])
if bool(re.match(repattern, pattern_search)):
stats[f"{path}/values"] = weight.flatten()
stats[f"{path}/mean"] = jnp.mean(weight)
stats[f"{path}/std"] = jnp.std(weight)
stats[f"{path}/min"] = jnp.min(weight)
stats[f"{path}/max"] = jnp.max(weight)
return stats