# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import re
import typing as tp
import warnings
from contextlib import contextmanager
from functools import partial
from eformer.pytree import auto_pytree
import chex
import jax
import numpy as np
from jax import numpy as jnp
from ml_collections import ConfigDict
from ml_collections.config_dict import placeholder
from easydel.infra.utils import ProcessingClassType
from easydel.utils import traversals
from easydel.utils.helpers import get_logger
# Module-level logger for this file, created through EasyDeL's `get_logger` helper.
logger = get_logger(__name__)
class JaxDistributedConfig:
    """
    From EasyLM
    Utility class for initializing JAX distributed.

    The configuration is an ``ml_collections.ConfigDict`` whose optional
    fields start out as typed placeholders, so callers only need to fill in
    what their launcher requires.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default distributed-init config, optionally merged with ``updates``.

        Args:
            updates: Optional mapping (or ConfigDict) of overrides. References in it
                are resolved before merging.

        Returns:
            A ``ConfigDict`` with ``initialize_jax_distributed`` (default False),
            ``coordinator_address``, ``num_processes``, ``process_id`` and
            ``local_device_ids`` (a comma-separated string).
        """
        defaults = ConfigDict()
        defaults.initialize_jax_distributed = False
        defaults.coordinator_address = placeholder(str)
        defaults.num_processes = placeholder(int)
        defaults.process_id = placeholder(int)
        defaults.local_device_ids = placeholder(str)
        if updates is None:
            return defaults
        defaults.update(ConfigDict(updates).copy_and_resolve_references())
        return defaults

    @classmethod
    def initialize(cls, config=None):
        """Call ``jax.distributed.initialize`` when the config asks for it.

        Args:
            config: Overrides merged into :meth:`get_default_config`. Nothing happens
                unless ``initialize_jax_distributed`` is truthy.
        """
        resolved = cls.get_default_config(config)
        if not resolved.initialize_jax_distributed:
            return
        device_ids = None
        if resolved.local_device_ids is not None:
            # "0,1,2" -> [0, 1, 2]
            device_ids = list(map(int, resolved.local_device_ids.split(",")))
        jax.distributed.initialize(
            coordinator_address=resolved.coordinator_address,
            num_processes=resolved.num_processes,
            process_id=resolved.process_id,
            local_device_ids=device_ids,
        )
# fmt:off
def create_prompt_creator(processing_class):
    """Return a callable that formats a single-turn ``conversation`` sample.

    The sample is expected to carry ``sample["conversation"][0]["input"]`` and
    ``["output"]``. These are rewritten into user/assistant messages and passed
    through ``conversations_formatting_function`` with
    ``messages_field="conversation"``.
    """
    def to_role_and_content(field):
        first_turn = field["conversation"][0]
        user_message = {"role": "user", "content": first_turn["input"]}
        assistant_message = {"role": "assistant", "content": first_turn["output"]}
        return {"conversation": [user_message, assistant_message]}

    def _pc(sample):
        # NOTE(review): `conversations_formatting_function` is neither defined nor
        # imported in this file as shown; confirm it is provided at runtime.
        formatter = conversations_formatting_function(processing_class, messages_field="conversation")
        return formatter(to_role_and_content(sample))

    return _pc
# fmt:on
def create_constant_length_dataset(
    processing_class,
    dataset,
    dataset_text_field: tp.Optional[str] = None,
    formatting_func: tp.Optional[tp.Callable] = None,
    infinite: bool = False,
    seq_length: int = 1024,
    num_of_sequences: int = 1024,
    chars_per_token: float = 3.6,
    eos_token_id: int = 0,
    shuffle: bool = True,
    append_concat_token: bool = True,
    add_special_tokens: bool = True,
) -> tp.Callable[[], tp.Iterator[tp.Dict[str, jnp.ndarray]]]:
    """
    Creates a generator function that yields constant length chunks of tokens from a stream of text files.

    Texts are buffered until about ``seq_length * chars_per_token * num_of_sequences``
    characters are collected. The buffer is tokenized in one call, concatenated
    (optionally separated by the concat token) and cut into ``seq_length`` windows.
    Trailing tokens that do not fill a complete window are dropped.

    Args:
        processing_class: The processor used for processing the data. It must be callable as
            ``processing_class(text=[...], add_special_tokens=..., truncation=False)`` and return
            ``input_ids`` and ``attention_mask`` per text.
        dataset: Dataset with text files.
        dataset_text_field: Name of the field in the dataset that contains the text.
        formatting_func: Function that formats the text before tokenization.
        infinite: If True the iterator is reset after dataset reaches end else stops. If a reset
            yields no example at all (empty or one-shot dataset), generation stops instead of
            looping forever.
        seq_length: Length of token sequences to return.
        num_of_sequences: Number of token sequences to keep in buffer.
        chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
        eos_token_id: Id of the end of sequence token if the passed processing_class does not have an EOS token.
            The tokenizer's own EOS id is used whenever it is not ``None``, including id 0.
        shuffle: Shuffle the examples before they are returned.
        append_concat_token: If true, appends the EOS/concat token at the end of each sample being packed.
        add_special_tokens: If true, processing_class adds special tokens to each sample being packed.

    Returns:
        A generator function that yields dictionaries containing input_ids and attention_mask as jnp.arrays

    Raises:
        ValueError: If neither ``dataset_text_field`` nor ``formatting_func`` is provided.
    """
    if processing_class.eos_token_id is None:
        warnings.warn(
            "The passed processing_class does not have an EOS token. We will use the passed eos_token_id instead which "
            f"corresponds to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id.",
            stacklevel=1,
        )
        concat_token_id = eos_token_id
    else:
        # Compare against None rather than relying on truthiness, so that a real EOS id of 0 is kept.
        concat_token_id = processing_class.eos_token_id
    max_buffer_size = seq_length * chars_per_token * num_of_sequences

    # Input validation and formatting function setup
    if dataset_text_field is not None and formatting_func is not None:
        warnings.warn(
            "Only one of `dataset_text_field` and `formatting_func` should be provided. "
            "Ignoring `dataset_text_field` and using `formatting_func`.",
            stacklevel=1,
        )
    if formatting_func is not None:
        # Callables such as functools.partial or callable objects have no __code__.
        code = getattr(formatting_func, "__code__", None)
        if code is not None and code.co_argcount > 1:
            warnings.warn(
                "The passed formatting_func has more than one argument. Usually that function should have a single argument "
                "`example` which corresponds to the dictionary returned by each element of the dataset. Make sure you know "
                "what you are doing.",
                stacklevel=1,
            )
    elif dataset_text_field is not None:
        formatting_func = lambda x: x[dataset_text_field]  # noqa
    else:
        raise ValueError(
            "Either `dataset_text_field` or `formatting_func` should be provided."
        )

    def constant_length_generator() -> tp.Iterator[tp.Dict[str, jnp.ndarray]]:
        iterator = iter(dataset)
        # Whether the current pass over `dataset` produced anything. If a reset produces
        # nothing, resetting again would spin forever.
        drew_since_reset = False
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            # Fill the buffer up to the character budget.
            while buffer_len < max_buffer_size:
                try:
                    prompt = formatting_func(next(iterator))
                except StopIteration:
                    if infinite and drew_since_reset:
                        iterator = iter(dataset)
                        drew_since_reset = False
                        warnings.warn(
                            "The dataset reached end and the iterator is reset to the start.",
                            stacklevel=1,
                        )
                        continue
                    more_examples = False
                    break
                drew_since_reset = True
                if isinstance(prompt, list):
                    prompt = "".join(prompt)
                buffer.append(prompt)
                buffer_len += len(prompt)

            if not buffer:
                # The dataset ran out exactly at a buffer boundary; there is nothing to tokenize.
                break
            if shuffle:
                random.shuffle(buffer)

            # Tokenize all texts in the buffer
            tokens = processing_class(
                text=buffer,
                add_special_tokens=add_special_tokens,
                truncation=False,
            )

            # Concatenate all tokens and attention masks
            all_token_ids = []
            all_attention_masks = []
            for token_ids, attention_mask in zip(tokens["input_ids"], tokens["attention_mask"]):
                all_token_ids.extend(token_ids)
                all_attention_masks.extend(attention_mask)
                if append_concat_token:
                    all_token_ids.append(concat_token_id)
                    all_attention_masks.append(1)

            # Keep only complete windows. The step range stops before any partial tail.
            windows = [
                (
                    all_token_ids[start : start + seq_length],
                    all_attention_masks[start : start + seq_length],
                )
                for start in range(0, len(all_token_ids) - seq_length + 1, seq_length)
            ]
            if shuffle:
                # Shuffling (ids, mask) pairs keeps them together, and an empty list is fine.
                random.shuffle(windows)

            for window_ids, window_mask in windows:
                yield {
                    "input_ids": jnp.asarray(window_ids, dtype="i4"),
                    "attention_mask": jnp.asarray(window_mask, dtype="i4"),
                }

    return constant_length_generator
def _collate_batch(
    examples,
    processing_class,
    pad_to_multiple_of: tp.Optional[int] = None,
):
    """Stack a list of 1-D token sequences into a 2-D JAX array, padding if needed.

    Args:
        examples: Sequences of token ids (lists/tuples are converted to integer arrays).
        processing_class: Tokenizer providing ``_pad_token``, ``pad_token_id`` and
            ``padding_side`` ("right" pads after the tokens, anything else pads before).
        pad_to_multiple_of: If given, the padded length is rounded up to a multiple of it.

    Returns:
        A ``(len(examples), length)`` array.

    Raises:
        ValueError: If padding is required but the tokenizer has no pad token.
    """
    if isinstance(examples[0], (list, tuple)):
        examples = [jnp.array(e, dtype=jnp.int64) for e in examples]

    length_of_first = len(examples[0])
    are_tensors_same_length = all(len(x) == length_of_first for x in examples)
    if are_tensors_same_length and (
        pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0
    ):
        return jnp.stack(examples, axis=0)

    if processing_class._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the processing_class you are using"
            f" ({processing_class.__class__.__name__}) does not have a pad token."
        )

    max_length = max(len(x) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

    result = jnp.full(
        shape=(len(examples), max_length),
        fill_value=processing_class.pad_token_id,
        dtype=examples[0].dtype,
    )
    for i, example in enumerate(examples):
        length = example.shape[0]
        # JAX arrays are immutable, so use functional `.at[].set` updates.
        if processing_class.padding_side == "right":
            result = result.at[i, :length].set(example)
        else:
            # Compute an explicit start index: `-0:` would select the whole row when length == 0.
            result = result.at[i, max_length - length :].set(example)
    return result
def tolist(x):
    """Convert ``x`` to a plain Python list (adapted from Hugging Face).

    Lists are returned unchanged (same object). Objects exposing ``.numpy()``
    are converted through it first. Everything else must provide ``.tolist()``.

    Args:
        x: A list, an array-like with ``tolist``, or a tensor with ``numpy``.

    Returns:
        X as tp.List
    """
    if not isinstance(x, list):
        as_array = x.numpy() if hasattr(x, "numpy") else x
        return as_array.tolist()
    return x
class DataCollatorForCompletionOnlyLM:
    """Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
    when they do not come from the assistant. This ensures that the loss is only
    calculated on the completion made by the assistant.

    With ``mlm=False`` (the default), batches are built causal-LM style:
    ``input_ids`` are padded with the tokenizer's pad token. ``labels`` start as a
    copy of ``input_ids`` with pad-token positions set to ``ignore_index``. Then
    everything up to the response template (or every instruction turn, when an
    instruction template is given) is set to ``ignore_index``.

    With ``mlm=True``, the whole-word-masking batch from :meth:`jax_call` is used
    as the starting point instead.
    """

    def __init__(
        self,
        processing_class: tp.Union[str, "PreTrainedTokenizerBase"],  # type:ignore #noqa
        response_template: tp.Union[str, tp.List[int]],
        instruction_template: tp.Optional[tp.Union[str, tp.List[int]]] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        if isinstance(processing_class, str):
            # transformers is only needed to resolve a checkpoint name into a tokenizer.
            from transformers import AutoTokenizer

            processing_class = AutoTokenizer.from_pretrained(processing_class)
        self.processing_class = processing_class

        self.instruction_template = instruction_template
        if isinstance(instruction_template, str):
            self.instruction_token_ids = self.processing_class.encode(
                self.instruction_template, add_special_tokens=False
            )
        else:
            self.instruction_token_ids = instruction_template

        self.response_template = response_template
        if isinstance(response_template, str):
            self.response_token_ids = self.processing_class.encode(
                self.response_template, add_special_tokens=False
            )
        else:
            self.response_token_ids = response_template

        if (
            not mlm
            and self.instruction_template
            and self.processing_class.pad_token_id == self.processing_class.eos_token_id
        ):
            warnings.warn(
                "The pad_token_id and eos_token_id values of this processing_class are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value.",
                stacklevel=1,
            )
        self.mlm = mlm
        self.ignore_index = ignore_index

    def _whole_word_mask(self, input_tokens: tp.List[str], max_predictions=512):
        """Pick about 15% of whole words (``##`` continuations are grouped) to mask.

        Returns a 0/1 list aligned with ``input_tokens``. ``[CLS]``/``[SEP]`` are never chosen.
        """
        from transformers import BertTokenizer, BertTokenizerFast

        if not isinstance(self.processing_class, (BertTokenizer, BertTokenizerFast)):
            warnings.warn(
                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
                "Please refer to the documentation for more information.",
                stacklevel=1,
            )

        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue
            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
        random.shuffle(cand_indexes)

        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * 0.15))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError(
                "Length of covered_indexes is not equal to length of masked_lms."
            )
        return [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]

    def jax_mask_tokens(
        self, inputs: tp.Any, special_tokens_mask: tp.Optional[tp.Any] = None
    ) -> tp.Tuple[tp.Any, tp.Any]:
        """Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: 2-D token ids (e.g. as produced by :meth:`jax_call`). JAX arrays are
                accepted and copied to a writable NumPy array.
            special_tokens_mask: Optional mask of positions that must never be selected.
                When omitted, it is computed with ``get_special_tokens_mask``.

        Returns:
            ``(inputs, labels)`` as NumPy arrays. ``labels`` is ``ignore_index`` at every
            position not selected for prediction.
        """
        # JAX arrays are immutable; the masking below assigns in place.
        inputs = np.array(inputs, copy=True)
        labels = np.copy(inputs)
        probability_matrix = np.full(labels.shape, 0.15)
        if special_tokens_mask is None:
            special_tokens_mask = np.array(
                [
                    self.processing_class.get_special_tokens_mask(
                        val, already_has_special_tokens=True
                    )
                    for val in labels.tolist()
                ],
                dtype=bool,
            )
        else:
            special_tokens_mask = np.asarray(special_tokens_mask).astype(bool)
        probability_matrix[special_tokens_mask] = 0
        masked_indices = np.random.binomial(
            1, probability_matrix, size=probability_matrix.shape
        ).astype(bool)
        labels[~masked_indices] = self.ignore_index

        indices_replaced = (
            np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
        )
        inputs[indices_replaced] = self.processing_class.mask_token_id

        indices_random = (
            np.random.binomial(1, 0.5, size=labels.shape).astype(bool)
            & masked_indices
            & ~indices_replaced
        )
        random_words = np.random.randint(
            low=0,
            high=len(self.processing_class),
            size=np.count_nonzero(indices_random),
            dtype=np.int64,
        )
        inputs[indices_random] = random_words
        return inputs, labels

    def jax_call(
        self, examples: tp.List[tp.Union[tp.List[int], tp.Any, tp.Dict[str, tp.Any]]]
    ) -> tp.Dict[str, tp.Any]:
        """Whole-word-masking MLM batch: ``{"input_ids": ..., "labels": ...}``."""
        if isinstance(examples[0], tp.Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _collate_batch(input_ids, self.processing_class)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for ida in tolist(e["input_ids"]):
                token = self.processing_class._convert_id_to_token(ida)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        batch_mask = _collate_batch(mask_labels, self.processing_class)
        inputs, labels = self.jax_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    @staticmethod
    def _pad_rows(rows, pad_value, padding_side):
        """Stack 1-D integer rows into an int32 NumPy matrix, padding on ``padding_side``."""
        width = max(len(row) for row in rows)
        out = np.full((len(rows), width), pad_value, dtype=np.int32)
        for i, row in enumerate(rows):
            if padding_side == "left":
                out[i, width - len(row) :] = row
            else:
                out[i, : len(row)] = row
        return out

    def _causal_lm_batch(self, examples):
        """Pad examples and build labels as a copy of ``input_ids``, with pad tokens ignored."""
        if isinstance(examples[0], tp.Mapping):
            id_rows = [np.asarray(e["input_ids"]).reshape(-1) for e in examples]
            if "attention_mask" in examples[0]:
                mask_rows = [np.asarray(e["attention_mask"]).reshape(-1) for e in examples]
            else:
                mask_rows = None
        else:
            id_rows = [np.asarray(e).reshape(-1) for e in examples]
            mask_rows = None
        if mask_rows is None:
            mask_rows = [np.ones(len(row), dtype=np.int32) for row in id_rows]

        pad_id = getattr(self.processing_class, "pad_token_id", None)
        if len({len(row) for row in id_rows}) > 1 and pad_id is None:
            raise ValueError(
                "You are attempting to pad samples but the processing_class you are using"
                f" ({self.processing_class.__class__.__name__}) does not have a pad token."
            )
        side = getattr(self.processing_class, "padding_side", "right")
        input_ids = self._pad_rows(id_rows, 0 if pad_id is None else pad_id, side)
        attention_mask = self._pad_rows(mask_rows, 0, side)
        labels = input_ids.copy()
        if pad_id is not None:
            labels[input_ids == pad_id] = self.ignore_index
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

    @staticmethod
    def _match_starts(row, pattern):
        """Return every index where ``pattern`` occurs as a contiguous run in ``row``."""
        target = [int(t) for t in pattern]
        width = len(target)
        return [
            int(idx)
            for idx in np.where(row == target[0])[0]
            if row[idx : idx + width].tolist() == target
        ]

    def _mask_before_response(self, labels, input_ids, i):
        """Single-turn: ignore everything up to the end of the last response template."""
        starts = self._match_starts(labels[i], self.response_token_ids)
        if not starts:
            warnings.warn(
                f"Could not find response key `{self.response_template}` in the "
                f"following instance: {self.processing_class.decode(input_ids[i].tolist())} "
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`.",
                stacklevel=1,
            )
            labels[i, :] = self.ignore_index
        else:
            labels[i, : starts[-1] + len(self.response_token_ids)] = self.ignore_index

    def _mask_non_assistant_turns(self, labels, input_ids, i):
        """Multi-turn: ignore each span from an instruction start to the next response end."""
        response_ends = [
            start + len(self.response_token_ids)
            for start in self._match_starts(labels[i], self.response_token_ids)
        ]
        if not response_ends:
            warnings.warn(
                f"Could not find response key `{self.response_template}` in the "
                f"following instance: {self.processing_class.decode(input_ids[i].tolist())} "
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`.",
                stacklevel=1,
            )
            labels[i, :] = self.ignore_index

        human_starts = self._match_starts(labels[i], self.instruction_token_ids)
        if not human_starts:
            warnings.warn(
                f"Could not find instruction key `{self.instruction_template}` in the "
                f"following instance: {self.processing_class.decode(input_ids[i].tolist())} "
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`.",
                stacklevel=1,
            )
            labels[i, :] = self.ignore_index

        # If the sequence opens with a response (e.g. a system prompt precedes the first
        # instruction), treat position 0 as the start of the first human turn.
        if human_starts and response_ends and human_starts[0] > response_ends[0]:
            human_starts = [0] + human_starts

        for turn, (start, end) in enumerate(zip(human_starts, response_ends)):
            if turn == 0:
                labels[i, :end] = self.ignore_index
            else:
                labels[i, start:end] = self.ignore_index

        # A trailing instruction with no answer is ignored entirely.
        if len(response_ends) < len(human_starts):
            labels[i, human_starts[-1] :] = self.ignore_index

    def __call__(
        self, examples: tp.List[tp.Union[tp.List[int], tp.Any, tp.Dict[str, tp.Any]]]
    ) -> tp.Dict[str, tp.Any]:
        """Collate ``examples`` and mask every non-assistant label with ``ignore_index``.

        Returns:
            Dict of JAX arrays: ``input_ids`` and ``labels``, plus ``attention_mask``
            when ``mlm`` is False.
        """
        if getattr(self, "mlm", False):
            batch = self.jax_call(examples)
        else:
            batch = self._causal_lm_batch(examples)

        input_ids = np.asarray(batch["input_ids"])
        labels = np.array(batch["labels"], copy=True)
        for i in range(len(examples)):
            if self.instruction_template is None:
                self._mask_before_response(labels, input_ids, i)
            else:
                self._mask_non_assistant_turns(labels, input_ids, i)
        batch["labels"] = labels
        return {key: jnp.asarray(value) for key, value in batch.items()}
@auto_pytree
class RewardDataCollatorWithPadding:
    r"""
    Reward DataCollator that pads the chosen and rejected inputs with ``tokenizer.pad``.

    Args:
        tokenizer (`ProcessingClassType`):
            The tokenizer used for encoding the data. Its ``pad`` method does the padding.
        padding (`Union[bool, str]`, `optional`, defaults to `"max_length"`):
            Padding strategy forwarded to ``tokenizer.pad``. With the default ``"max_length"``
            the batch is padded to ``max_length``, not to the longest sequence in the batch.
            Pass ``True``/``"longest"`` for per-batch dynamic padding.
        max_length (`int` or `None`, `optional`, defaults to `None`):
            Forwarded to ``tokenizer.pad`` as ``max_length``.
        truncation_mode (`str`, `optional`, defaults to `"keep_end"`):
            Stored on the collator but not read by ``__call__``.
    """

    tokenizer: ProcessingClassType
    padding: tp.Union[bool, str] = "max_length"
    max_length: tp.Optional[int] = None
    truncation_mode: str = "keep_end"

    def __call__(self, features: list[dict[str, tp.Any]]) -> dict[str, tp.Any]:
        """Pad chosen/rejected features into one batch.

        Each feature must contain ``input_ids_chosen``, ``attention_mask_chosen``,
        ``input_ids_rejected`` and ``attention_mask_rejected``. An optional ``margin``
        (detected from the first feature) is collected into a float32 array.

        Raises:
            ValueError: If a feature lacks one of the required keys.
        """
        required_keys = (
            "input_ids_chosen",
            "input_ids_rejected",
            "attention_mask_chosen",
            "attention_mask_rejected",
        )
        has_margin = "margin" in features[0]
        features_chosen = []
        features_rejected = []
        margin = []
        for feature in features:
            if any(key not in feature for key in required_keys):
                raise ValueError(
                    "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
                )
            features_chosen.append(
                {
                    "input_ids": feature["input_ids_chosen"],
                    "attention_mask": feature["attention_mask_chosen"],
                }
            )
            features_rejected.append(
                {
                    "input_ids": feature["input_ids_rejected"],
                    "attention_mask": feature["attention_mask_rejected"],
                }
            )
            if has_margin:
                margin.append(feature["margin"])

        def _pad(rows):
            return self.tokenizer.pad(
                rows,
                padding=self.padding,
                max_length=self.max_length,
                return_tensors="jax",
            )

        batch_chosen = _pad(features_chosen)
        batch_rejected = _pad(features_rejected)
        batch = {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
        }
        if has_margin:
            batch["margin"] = jnp.array(margin, dtype="f4")
        return batch
@auto_pytree
class DataCollatorForPreference:
    r"""DPO DataCollator that pads tokenized prompts and completions to fixed lengths.

    Prompts are left-padded to ``max_prompt_length``. Chosen/rejected completions are
    right-padded to ``max_completion_length``. Sequences longer than the target are cut
    by ``pad``: left-padded prompts keep their last tokens, and right-padded completions
    keep their first tokens. Attention masks are all-ones over the real tokens.

    Optional per-feature keys: ``pixel_values``, ``pixel_attention_mask`` (padded only to
    the largest item in the batch), ``image_sizes``, and ``ref_chosen_logps`` /
    ``ref_rejected_logps`` (stacked as-is when both are present).
    """

    max_prompt_length: int
    max_completion_length: int
    pad_token_id: int = 0
    label_pad_token_id: int = -100  # not read by __call__
    is_encoder_decoder: tp.Optional[bool] = False  # not read by __call__

    def __call__(self, features: tp.List[tp.Dict[str, tp.Any]]) -> tp.Dict[str, tp.Any]:
        prompt_input_ids = [jnp.array(feature["prompt_input_ids"]) for feature in features]
        prompt_attention_mask = [jnp.ones_like(input_ids) for input_ids in prompt_input_ids]
        chosen_input_ids = [jnp.array(feature["chosen_input_ids"]) for feature in features]
        chosen_attention_mask = [jnp.ones_like(input_ids) for input_ids in chosen_input_ids]
        rejected_input_ids = [jnp.array(feature["rejected_input_ids"]) for feature in features]
        rejected_attention_mask = [jnp.ones_like(input_ids) for input_ids in rejected_input_ids]

        pixel_values = None
        pixel_attention_mask = None
        if "pixel_values" in features[0]:
            pixel_values = [jnp.array(feature["pixel_values"]) for feature in features]
        if "pixel_attention_mask" in features[0]:
            pixel_attention_mask = [
                jnp.array(feature["pixel_attention_mask"]) for feature in features
            ]

        ref_chosen_logps = None
        ref_rejected_logps = None
        if "ref_chosen_logps" in features[0] and "ref_rejected_logps" in features[0]:
            ref_chosen_logps = jnp.array([feature["ref_chosen_logps"] for feature in features])
            ref_rejected_logps = jnp.array([feature["ref_rejected_logps"] for feature in features])

        # Pad sequences
        output = {
            "prompt_input_ids": pad(
                prompt_input_ids,
                self.max_prompt_length,
                padding_value=self.pad_token_id,
                padding_side="left",
            ),
            "prompt_attention_mask": pad(
                prompt_attention_mask,
                self.max_prompt_length,
                padding_value=0,
                padding_side="left",
            ),
            "chosen_input_ids": pad(
                chosen_input_ids,
                self.max_completion_length,
                padding_value=self.pad_token_id,
            ),
            "chosen_attention_mask": pad(
                chosen_attention_mask,
                self.max_completion_length,
                padding_value=0,
            ),
            "rejected_input_ids": pad(
                rejected_input_ids,
                self.max_completion_length,
                padding_value=self.pad_token_id,
            ),
            "rejected_attention_mask": pad(
                rejected_attention_mask,
                self.max_completion_length,
                padding_value=0,
            ),
        }

        # Image tensors are padded only to the batch maximum. `max_prompt_length` is a
        # token budget; passing it to `pad` would clip or zero-pad the image's last
        # (spatial) axis to that many pixels.
        if pixel_values is not None:
            output["pixel_values"] = pad(pixel_values, None, padding_value=0.0)
        if pixel_attention_mask is not None:
            output["pixel_attention_mask"] = pad(pixel_attention_mask, None, padding_value=0)
        if "image_sizes" in features[0]:
            output["image_sizes"] = jnp.array([feature["image_sizes"] for feature in features])
        if ref_chosen_logps is not None and ref_rejected_logps is not None:
            output["ref_chosen_logps"] = ref_chosen_logps
            output["ref_rejected_logps"] = ref_rejected_logps
        return output
@auto_pytree
class DPODataCollatorWithPadding:
    r"""
    DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.

    Keys ending in ``_input_ids``, ``_attention_mask``, ``_labels`` or ``_pixel_values``
    are padded. Their target length depends on the key prefix: ``chosen``/``rejected``
    use ``max_completion_length``, ``prompt`` uses ``max_prompt_length``, and anything
    else uses the sum of both. The target is ignored when ``prepadded`` is True.
    ``_logps`` keys are stacked. Other keys are kept as Python lists. With
    ``output_arrays_only`` only arrays of the whitelisted dtypes survive.
    """

    max_prompt_length: int
    max_completion_length: int
    pad_token_id: int = 0
    label_pad_token_id: int = -100
    is_encoder_decoder: tp.Optional[bool] = False
    output_arrays_only: bool = True
    prepadded: bool = True

    def __call__(self, features: list[dict[str, tp.Any]]) -> dict[str, tp.Any]:
        combined_length = self.max_completion_length + self.max_prompt_length
        length_for_prefix = {
            "rejected": self.max_completion_length,
            "chosen": self.max_completion_length,
            "prompt": self.max_prompt_length,
        }
        allowed_dtypes = (
            jnp.float64,
            jnp.float32,
            jnp.float16,
            jnp.int32,
            jnp.int16,
            jnp.int8,
        )
        no_pad_token_message = (
            "Padding is enabled, but the tokenizer is not configured with a padding token."
            " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
            " before calling the trainer."
        )

        collated = {}
        for key in features[0].keys():
            if key.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")):
                target_length = length_for_prefix.get(key.split("_")[0], combined_length)
                limit = None if self.prepadded else target_length

                if self.is_encoder_decoder:
                    sequences = [jnp.array(example[key], dtype="i4") for example in features]
                    if key.startswith("prompt") and key.endswith("input_ids"):
                        if self.pad_token_id is None:
                            raise ValueError(no_pad_token_message)
                        fill_value = self.pad_token_id
                    elif key.endswith("_attention_mask"):
                        fill_value = 0
                    elif key.startswith(("chosen", "rejected", "completion")) or ("decoder" in key):
                        fill_value = self.label_pad_token_id
                    else:
                        raise ValueError(f"Unexpected key in batch '{key}'")
                    collated[key] = pad_sequence(
                        sequences,
                        batch_first=False,
                        padding_value=fill_value,
                        max_len=limit,
                    )
                else:
                    if key.endswith("_input_ids"):
                        if self.pad_token_id is None:
                            raise ValueError(no_pad_token_message)
                        fill_value = self.pad_token_id
                    elif key.endswith("_labels"):
                        fill_value = self.label_pad_token_id
                    elif key.endswith(("_attention_mask", "_pixel_values")):
                        fill_value = 0
                    else:
                        raise ValueError(f"Unexpected key in batch '{key}'")
                    side = "left" if key in ("prompt_input_ids", "prompt_attention_mask") else "right"
                    element_dtype = jnp.float32 if key.endswith("_pixel_values") else jnp.int32
                    sequences = [jnp.array(example[key], dtype=element_dtype) for example in features]
                    collated[key] = pad(
                        sequences,
                        limit,
                        padding_value=fill_value,
                        padding_side=side,
                    )
            elif key.endswith("_logps"):
                collated[key] = jnp.array([example[key] for example in features])
            else:
                collated[key] = [example[key] for example in features]

            if self.output_arrays_only:
                value = collated.get(key)
                if not hasattr(value, "dtype") or value.dtype not in allowed_dtypes:
                    collated.pop(key)
        return collated
def shift_and_pad(mask, *tensors):
    """Roll each row left so its first nonzero ``mask`` entry lands at index 0.

    The same per-row roll is applied to every tensor in ``tensors``. Arrays are
    modified in place through row assignment, so they must be writable (e.g. NumPy).
    Rows whose mask has no nonzero entry are left untouched.

    Returns:
        ``mask`` alone when no extra tensors are given, otherwise ``(mask, *tensors)``.
    """
    for row in range(mask.shape[0]):
        nonzero = np.nonzero(mask[row])[0]
        if nonzero.size == 0:
            # An all-padding row has nothing to align (previously an IndexError).
            continue
        shift = -int(nonzero[0])
        mask[row] = np.roll(mask[row], shift=shift)
        for tensor in tensors:
            tensor[row] = np.roll(tensor[row], shift=shift)
    if not tensors:
        return mask
    return (mask, *tensors)
def pad(
    tensors: list[jnp.ndarray],
    max_lenght: tp.Optional[int],
    padding_value: int = 0,
    padding_side: str = "right",
) -> jnp.ndarray:
    """
    Pads a list of tensors to the same shape along the first dimension and stacks them.

    Each tensor is placed at the start (``"right"``) or the end (``"left"``) of axis 0.
    Other axes are placed from index 0. The output shape is the element-wise maximum
    over all input shapes. If ``max_lenght`` is given, the last axis is first grown to
    at least ``max_lenght`` and then cut to exactly ``max_lenght``: ``"left"`` keeps the
    trailing entries and ``"right"`` keeps the leading ones. For 1-D inputs, axis 0 and
    the last axis coincide.

    Args:
        tensors: Non-empty list of arrays with the same rank.
        max_lenght: Target length of the last axis, or ``None`` for the batch maximum.
        padding_value: Fill value for padded positions.
        padding_side: ``"left"`` or ``"right"``.

    Returns:
        Array of shape ``(len(tensors), *padded_shape)`` with the dtype of ``tensors[0]``.

    Raises:
        ValueError: On an unknown ``padding_side`` or an empty ``tensors`` list.
    """
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")
    if len(tensors) == 0:
        raise ValueError("`pad` requires at least one tensor.")

    # Size every axis from ALL tensors, not just the first one.
    output_shape = [max(t.shape[axis] for t in tensors) for axis in range(tensors[0].ndim)]
    target = output_shape[-1] if max_lenght is None else max_lenght
    output_shape[-1] = max(output_shape[-1], target)

    output = jnp.full(
        (len(tensors), *output_shape),
        padding_value,
        dtype=tensors[0].dtype,
    )
    for i, t in enumerate(tensors):
        if padding_side == "left":
            seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0])
        else:
            seq_slice = slice(0, t.shape[0])
        slices = (i, seq_slice) + tuple(slice(0, s) for s in t.shape[1:])
        output = output.at[slices].set(t)

    if padding_side == "left":
        return output[..., output.shape[-1] - target :]
    return output[..., :target]
def pad_to_length(
    tensor: "chex.Array",
    length: int,
    pad_value: tp.Union[int, float],
    axis: int = -1,
) -> "chex.Array":
    """Pad ``tensor`` along ``axis`` up to ``length``.

    When the tensor is already at least ``length`` long on ``axis``, 2-D tensors are
    truncated to ``length`` on that same axis. Other ranks are returned unchanged; that
    restriction is kept for backward compatibility.

    Args:
        tensor: Array to pad.
        length: Desired size of ``axis``.
        pad_value: Value used for new positions (multiplied into a ones array of
            ``tensor.dtype``, so the usual JAX type promotion applies).
        axis: Axis to pad/truncate (negative values allowed).

    Returns:
        The padded (or, for 2-D inputs, truncated) array.
    """
    current = tensor.shape[axis]
    if current >= length:
        if tensor.ndim == 2:
            # Truncate along the requested axis (previously always axis 1).
            index = [slice(None)] * tensor.ndim
            index[axis] = slice(0, length)
            tensor = tensor[tuple(index)]
        return tensor

    pad_size = list(tensor.shape)
    pad_size[axis] = length - current
    return jnp.concatenate(
        [tensor, pad_value * jnp.ones(pad_size, dtype=tensor.dtype)],
        axis=axis,
    )
[docs]def pad_sequence(
sequences,
batch_first=False,
padding_value=0,
max_len: int | None = None,
):
max_len = max(seq.shape[-1] for seq in sequences) if max_len is None else max_len
padding_value = jnp.array(padding_value).reshape(1)
if batch_first:
padded_seqs = [
(
jnp.concatenate(
[
seq.reshape(1, -1),
jnp.ones((1, max_len - seq.shape[-1])) * padding_value,
],
axis=1,
)
if seq.shape[-1] < max_len
else seq.reshape(1, -1)
)
for seq in sequences
]
else:
padded_seqs = [
(
jnp.concatenate(
[
jnp.ones((1, max_len - seq.shape[-1])) * padding_value,
seq.reshape(1, -1),
],
axis=1,
)
if seq.shape[-1] < max_len
else seq.reshape(1, -1)
)
for seq in sequences
]
return jnp.array(padded_seqs)
@contextmanager
def leave_alone_context_manager():
    """No-op context manager: entering and exiting do nothing, and ``as`` binds ``None``."""
    yield None
def add_bos_token_if_needed(
    bos_token_id: tp.Optional[int],
    prompt_len_input_ids: int,
    prompt_tokens: tp.Dict[str, tp.List[int]],
    chosen_prompt_len_input_ids: int,
    chosen_tokens: tp.Dict[str, tp.List[int]],
    rejected_prompt_len_input_ids: int,
    rejected_tokens: tp.Dict[str, tp.List[int]],
):
    """Prepend ``bos_token_id`` to each prompt that does not already start with it.

    For the prompt, chosen and rejected dicts (processed in that order), a BOS is added
    to ``prompt_input_ids`` (with a matching ``1`` in ``prompt_attention_mask``) when the
    given length is 0 or the first id differs from ``bos_token_id``. New lists are
    assigned into the dicts. Nothing happens when ``bos_token_id`` is None.

    Returns:
        ``(prompt_tokens, chosen_tokens, rejected_tokens)``.
    """
    if bos_token_id is None:
        return prompt_tokens, chosen_tokens, rejected_tokens
    for known_length, tokens in (
        (prompt_len_input_ids, prompt_tokens),
        (chosen_prompt_len_input_ids, chosen_tokens),
        (rejected_prompt_len_input_ids, rejected_tokens),
    ):
        if known_length == 0 or bos_token_id != tokens["prompt_input_ids"][0]:
            tokens["prompt_input_ids"] = [bos_token_id] + tokens["prompt_input_ids"]
            tokens["prompt_attention_mask"] = [1] + tokens["prompt_attention_mask"]
    return prompt_tokens, chosen_tokens, rejected_tokens
def add_eos_token_if_needed(
    eos_token_id: int,
    chosen_tokens: tp.Dict[str, tp.List[int]],
    rejected_tokens: tp.Dict[str, tp.List[int]],
):
    """Append ``eos_token_id`` to chosen/rejected ``input_ids`` that do not end with it.

    The lists are mutated in place with ``append``, and ``attention_mask`` gains a ``1``
    alongside. The chosen dict is processed before the rejected one.

    Returns:
        ``(chosen_tokens, rejected_tokens)``.
    """
    for tokens in (chosen_tokens, rejected_tokens):
        ids = tokens["input_ids"]
        if len(ids) == 0 or eos_token_id != ids[-1]:
            ids.append(eos_token_id)
            tokens["attention_mask"].append(1)
    return chosen_tokens, rejected_tokens
def first_true_indices(bools, dtype=jnp.int32):
    """
    Takes an N-dimensional bool array and returns an (N-1)-dimensional array of integers giving
    the position of the first True in each "row" (the last axis). Rows with no True
    yield the row length.
    """
    width = bools.shape[-1]
    positions = jnp.arange(width, dtype=dtype)
    # False entries are pushed past the end, so they never win the minimum.
    candidates = width * (~bools).astype(dtype) + positions
    return candidates.min(axis=-1)
def truncate_right(input_ids, stop_token_id, pad_token_id):
    """
    Truncates the input array from the right side after the first occurrence of the stop token.

    The stop token itself is kept. Every later position becomes ``pad_token_id`` and
    gets ``0`` in the returned mask. Rows without a stop token are left intact.

    Returns:
        ``(output_ids, mask)``, both the same shape as the 2-D ``input_ids``.
    """
    cutoff = first_true_indices(input_ids == stop_token_id).reshape((-1, 1))
    column = jnp.arange(input_ids.shape[1]).reshape((1, -1))
    past_stop = column > cutoff
    return jnp.where(past_stop, pad_token_id, input_ids), jnp.where(past_stop, 0, 1)
@partial(jax.jit, static_argnums=(1,))
def compute_weight_stats(params, repattern: str):
    """Compute statistics for model weights in a JIT-compatible way.

    Each flattened parameter path is joined with ``.`` and matched against
    ``repattern`` using ``re.match`` (anchored at the start). Matching entries
    contribute ``<path>/values`` (flattened), ``/mean``, ``/std``, ``/min`` and
    ``/max``, where ``<path>`` is joined with ``/``.

    NOTE(review): each flattened leaf is read via ``.value``, so leaves presumably
    are variable wrappers rather than raw arrays; confirm against callers.

    Args:
        params: Model parameters
        repattern: parameters to analyze (static under jit)
    Returns:
        Dictionary of weight statistics
    """
    stats = {}
    for key_path, leaf in traversals.flatten_dict(params).items():
        array = leaf.value
        dotted_name = ".".join(str(part) for part in key_path)
        if re.match(repattern, dotted_name) is None:
            continue
        prefix = "/".join(str(part) for part in key_path)
        stats[f"{prefix}/values"] = array.flatten()
        stats[f"{prefix}/mean"] = jnp.mean(array)
        stats[f"{prefix}/std"] = jnp.std(array)
        stats[f"{prefix}/min"] = jnp.min(array)
        stats[f"{prefix}/max"] = jnp.max(array)
    return stats