# FROM HF TRL
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import typing as tp
from datasets import Dataset, DatasetDict
from easydel.infra.utils import ProcessingClassType
# Type variable constrained to the two dataset containers, so helpers that
# take a dataset are annotated as returning the same container kind.
DatasetType = tp.TypeVar("DatasetType", Dataset, DatasetDict)

# Loosely structured input records: a single string-to-string dict, a list of
# such dicts, or a list of lists of them.
InputDict = tp.Dict[str, str]
InputListDict = tp.List[InputDict]
InputListListDict = tp.List[tp.List[InputDict]]
InputType = tp.Union[InputListListDict, InputListDict, InputDict]

# OpenAI-style message: each value is either a plain string or a list of
# content parts (string-to-string dicts such as {"type": "text", "text": ...}).
OpenAIMessageContentPart = tp.Dict[str, str]
OpenAIMessage = tp.Dict[str, tp.Union[str, tp.List[OpenAIMessageContentPart]]]

# Output aliases: one dict, a list of dicts, or None.
OutputDict = tp.Dict[str, str]
OutputListDict = tp.List[OutputDict]
OutputType = tp.Union[OutputDict, OutputListDict, None]

# A whole conversation expressed as a list of OpenAI-style messages.
OpenAIMessageList = tp.List[OpenAIMessage]
def _is_valid_openai_message_list(data: tp.Any) -> bool:
    """
    Tell whether ``data`` is a list of messages whose content is a list of parts.

    A message qualifies when it is a dict holding both ``"role"`` and
    ``"content"``, the role is a string, and the content is a list in which
    every element is a dict carrying a ``"type"`` key
    (e.g. ``[{"type": "text", ...}]``). An empty list is accepted.

    Args:
        data: Arbitrary object to inspect.

    Returns:
        ``True`` when every message matches the shape above, else ``False``.
    """
    if not isinstance(data, list):
        return False

    def _message_ok(message: tp.Any) -> bool:
        # Shape checks run cheapest-first and stop at the first violation.
        if not isinstance(message, dict):
            return False
        if not ("role" in message and "content" in message):
            return False
        if not isinstance(message["role"], str):
            return False
        parts = message["content"]
        if not isinstance(parts, list):
            return False
        return all(isinstance(part, dict) and "type" in part for part in parts)

    # all() over an empty list is True, which keeps "[] is valid".
    return all(_message_ok(message) for message in data)
def _convert_single_dict(source_dict: InputDict) -> tp.Optional[OpenAIMessage]:
    """
    Convert one loosely structured record into an OpenAI-style message.

    The role is read from the first key that equals ``"role"`` ignoring case,
    and is lower-cased. The text is read from the first string value found
    under ``"content"``, ``"text"`` or ``"message"`` (in that priority order,
    keys also compared ignoring case).

    The input is only read, never copied or mutated. Keys that are not strings
    are ignored instead of raising, since they cannot name a role or content
    field.

    Args:
        source_dict: Record to convert.

    Returns:
        ``{"role": <role>, "content": [{"type": "text", "text": <text>}]}``,
        with role defaulting to ``"user"`` and text to ``""``; ``None`` (after
        printing a warning) when ``source_dict`` is not a dict.
    """
    if not isinstance(source_dict, dict):
        print(f"Warning: Expected a dictionary, but got {type(source_dict)}. Skipping.")
        return None

    # Lower-case each string key once; both lookups below reuse the pairs.
    # Non-string keys are skipped here so `.lower()` can never fail.
    keyed = [(key, key.lower()) for key in source_dict if isinstance(key, str)]

    role = "user"
    for key, lowered in keyed:
        if lowered != "role":
            continue
        role_value = source_dict[key]
        if isinstance(role_value, str):
            role = role_value.lower()
            if role not in ("user", "assistant", "system", "tool"):
                print(f"Warning: Non-standard role '{role}' found. Using it.")
        else:
            print(
                f"Warning: 'role' value is not a string ({role_value}). Using default 'user'."
            )
        # Only the first role-like key is consulted, valid or not.
        break

    content_text = ""
    content_found = False
    for priority_key in ("content", "text", "message"):
        for key, lowered in keyed:
            if lowered != priority_key:
                continue
            content_value = source_dict[key]
            if isinstance(content_value, str):
                content_text = content_value
                content_found = True
                break
            # Keep scanning: another key of the same or lower priority may
            # still hold a usable string.
            print(
                f"Warning: Found content key '{key}' but value is not a string ({content_value}). Trying other keys or defaulting to empty."
            )
        if content_found:
            break

    return {
        "role": role,
        "content": [{"type": "text", "text": content_text}],
    }
def is_conversational(example: dict[str, tp.Any]) -> bool:
    """
    Check if the example is in a conversational format.

    The first supported key present in ``example`` (checked in the order
    ``prompt``, ``chosen``, ``rejected``, ``completion``, ``messages``) is
    inspected. The example counts as conversational when that value is a
    non-empty list whose first element is a dict containing both ``"role"``
    and ``"content"``.

    Args:
        example: A single dataset row.

    Returns:
        ``True`` if the inspected field looks like a list of chat messages,
        ``False`` otherwise (including when no supported key is present or the
        list is empty).
    """
    supported_keys = ("prompt", "chosen", "rejected", "completion", "messages")

    # Pick the key in a fixed priority order so the outcome is deterministic
    # even when several supported columns are present.
    key = next((name for name in supported_keys if name in example), None)
    if key is None:
        return False

    maybe_messages = example[key]
    # An empty list has no first message to look at; treat it as not
    # conversational rather than failing on the index below.
    if not isinstance(maybe_messages, list) or not maybe_messages:
        return False

    maybe_message = maybe_messages[0]
    return (
        isinstance(maybe_message, dict)
        and "role" in maybe_message
        and "content" in maybe_message
    )
def apply_chat_template(
    example: dict[str, list[dict[str, str]]],
    tokenizer: ProcessingClassType,
    tools: tp.Optional[list[tp.Union[dict, tp.Callable]]] = None,
) -> dict[str, str]:
    r"""
    Render the conversational fields of ``example`` to text with the tokenizer's chat template.

    Accepted key layouts (other keys are ignored): ``messages``; ``prompt``;
    ``prompt`` + ``completion``; ``prompt`` + ``chosen`` + ``rejected``;
    ``chosen`` + ``rejected``; ``prompt`` + ``completion`` + ``label``.

    When a ``prompt`` is present it is rendered with
    ``add_generation_prompt=True``. Each of ``chosen`` / ``rejected`` /
    ``completion`` is rendered together with the prompt, and the rendered
    prompt's length is cut off the front to obtain the continuation text.
    Without a prompt, ``chosen`` and ``rejected`` are rendered on their own.

    Args:
        example: Row holding lists of chat messages under the keys above.
        tokenizer: Object exposing ``apply_chat_template``.
        tools: Optional tool schemas forwarded to every template call.

    Returns:
        Dict with ``text`` (from ``messages``), ``prompt``, ``chosen``,
        ``rejected``, ``completion`` and ``label`` for whichever inputs exist.

    Raises:
        KeyError: The set of supported keys present is not an accepted layout.
        ValueError: A prompt + continuation rendering does not start with the
            rendering of the prompt alone.
    """
    recognised = ["prompt", "chosen", "rejected", "completion", "messages", "label"]
    present = {name for name in example.keys() if name in recognised}
    accepted_layouts = [
        {"messages"},
        {"prompt"},
        {"prompt", "completion"},
        {"prompt", "chosen", "rejected"},
        {"chosen", "rejected"},
        {"prompt", "completion", "label"},
    ]
    if present not in accepted_layouts:
        raise KeyError(f"Invalid keys in the example: {present}")

    def _render(conversation, **extra):
        # Every template call shares the same tools / tokenize settings.
        return tokenizer.apply_chat_template(
            conversation, tools=tools, tokenize=False, **extra
        )

    has_prompt = "prompt" in example
    continuation_fields = ("chosen", "rejected", "completion")
    continuations = {}

    if "messages" in example:
        messages = _render(example["messages"])

    if has_prompt:
        prompt = _render(example["prompt"], add_generation_prompt=True)

        # Render prompt + continuation, then strip the prompt-length prefix.
        combined = {}
        for field in continuation_fields:
            if field in example:
                combined[field] = _render(example["prompt"] + example[field])
                continuations[field] = combined[field][len(prompt) :]

        # The prefix strip is only meaningful if the combined text really
        # begins with the prompt text; verify after all renderings are done.
        error_message = (
            "The chat template applied to the prompt + completion does not start with the chat template applied to "
            "the prompt alone."
            "\n**Prompt**:\n{}\n\n**Prompt + Completion**:\n{}"
        )
        for field in continuation_fields:
            if field in combined and not combined[field].startswith(prompt):
                raise ValueError(error_message.format(prompt, combined[field]))
    else:
        for field in ("chosen", "rejected"):
            if field in example:
                continuations[field] = _render(example[field])

    output = {}
    if "messages" in example:
        output["text"] = messages
    if has_prompt:
        output["prompt"] = prompt
    for field in continuation_fields:
        if field in continuations:
            output[field] = continuations[field]
    if "label" in example:
        output["label"] = example["label"]
    return output
def maybe_apply_chat_template(
    example: dict[str, list[dict[str, str]]],
    tokenizer: ProcessingClassType,
    tools: tp.Optional[list[tp.Union[dict, tp.Callable]]] = None,
) -> dict[str, str]:
    """
    Apply the chat template only when the example is conversational.

    Args:
        example: A single dataset row.
        tokenizer: Object forwarded to `apply_chat_template`.
        tools: Optional tool schemas forwarded to `apply_chat_template`.

    Returns:
        The result of `apply_chat_template` when `is_conversational(example)`
        is true; otherwise the very same ``example`` object, untouched.
    """
    # Non-conversational rows pass straight through.
    if not is_conversational(example):
        return example
    return apply_chat_template(example, tokenizer, tools)
def _unpair_row(
    examples: list[dict[str, list[dict[str, str]]]],
) -> list[dict[str, list[dict[str, str]]]]:
    """
    Turn a batch of chosen/rejected pairs into labelled completions.

    NOTE(review): despite the ``list[...]`` annotations, the body indexes
    ``examples`` by column name, so a mapping of column name to list of values
    is presumably what gets passed in. Confirm against the caller.

    Args:
        examples: Batch with ``"chosen"`` and ``"rejected"`` columns and an
            optional ``"prompt"`` column.

    Returns:
        Dict with ``"completion"`` (all chosen values followed by all rejected
        values), ``"label"`` (``True`` for each chosen, then ``False`` for
        each rejected) and, when the input has one, ``"prompt"`` repeated
        twice so it lines up with the completions.
    """
    chosen_batch = examples["chosen"]
    pair_count = len(chosen_batch)

    unpaired = {
        "completion": chosen_batch + examples["rejected"],
        "label": [True] * pair_count + [False] * pair_count,
    }
    if "prompt" in examples:
        prompts = examples["prompt"]
        unpaired["prompt"] = prompts + prompts
    return unpaired
def unpair_preference_dataset(
    dataset: DatasetType,
    num_proc: tp.Optional[int] = None,
    desc: tp.Optional[str] = None,
) -> DatasetType:
    """
    Unpair a preference dataset.

    Maps `_unpair_row` over the dataset in batches and drops the ``"chosen"``
    and ``"rejected"`` columns from the result.

    Args:
        dataset: Dataset (or dataset dict) with ``"chosen"`` and
            ``"rejected"`` columns.
        num_proc: Forwarded to ``dataset.map``.
        desc: Forwarded to ``dataset.map``.

    Returns:
        Whatever ``dataset.map`` returns for the unpairing transform.
    """
    paired_columns = ["chosen", "rejected"]
    unpaired = dataset.map(
        _unpair_row,
        batched=True,
        remove_columns=paired_columns,
        num_proc=num_proc,
        desc=desc,
    )
    return unpaired
def maybe_unpair_preference_dataset(
    dataset: DatasetType,
    num_proc: tp.Optional[int] = None,
    desc: tp.Optional[str] = None,
) -> DatasetType:
    """
    Unpair a preference dataset if it is paired.

    A dataset counts as paired when its columns include both ``"chosen"`` and
    ``"rejected"``. For a ``DatasetDict`` only the first split's columns are
    inspected.

    Args:
        dataset: Dataset or dataset dict to inspect.
        num_proc: Forwarded to `unpair_preference_dataset`.
        desc: Forwarded to `unpair_preference_dataset`.

    Returns:
        The unpaired dataset when both columns exist, otherwise ``dataset``
        itself, unchanged.
    """
    if isinstance(dataset, DatasetDict):
        # Use the first split as representative of the column layout.
        first_split = list(dataset.keys())[0]
        column_names = dataset[first_split].column_names
    else:
        column_names = dataset.column_names

    is_paired = "chosen" in column_names and "rejected" in column_names
    if not is_paired:
        return dataset
    return unpair_preference_dataset(dataset, num_proc=num_proc, desc=desc)