improve KTO impl., replace datasets
Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||
from .loader import get_dataset
|
||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"get_dataset",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import Features
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import Role
|
||||
|
||||
|
||||
@@ -14,7 +15,13 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
||||
r"""
|
||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||
"""
|
||||
outputs = []
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for image in images:
|
||||
@@ -29,7 +36,10 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
|
||||
r"""
|
||||
Converts alpaca format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
@@ -45,23 +55,33 @@ def convert_alpaca(
|
||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
content.append(examples[dataset_attr.query][i])
|
||||
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
||||
|
||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag], bool): # kto example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else:
|
||||
if examples[dataset_attr.kto_tag]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], str)
|
||||
and isinstance(examples[dataset_attr.rejected][i], str)
|
||||
): # pairwise example
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else: # unsupervised
|
||||
response = []
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -69,6 +89,9 @@ def convert_alpaca(
|
||||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Converts sharegpt format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
tag_mapping = {
|
||||
@@ -88,21 +111,62 @@ def convert_sharegpt(
|
||||
else:
|
||||
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||
|
||||
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
aligned_messages = []
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
outputs["prompt"].append(aligned_messages[:-1])
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning("Invalid message count in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
if examples[dataset_attr.kto_tag][i]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], dict)
|
||||
and isinstance(examples[dataset_attr.rejected][i], dict)
|
||||
): # pairwise example
|
||||
chosen = examples[dataset_attr.chosen][i]
|
||||
rejected = examples[dataset_attr.rejected][i]
|
||||
if (
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
response = [
|
||||
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
||||
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
||||
]
|
||||
else: # normal example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
continue
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
@@ -138,7 +202,6 @@ def align_dataset(
|
||||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
"images": [{"_type": "Image"}],
|
||||
"tag": {"dtype": "bool", "_type": "Value"},
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
|
||||
@@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
def __call__(self, features, return_tensors=None):
|
||||
concatenated_features = []
|
||||
kl_concatenated_features = []
|
||||
tags = []
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
for feature in features:
|
||||
concatenated_features.append(
|
||||
target_features.append(
|
||||
{
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
)
|
||||
kl_concatenated_features.append(
|
||||
kl_features.append(
|
||||
{
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
)
|
||||
tags.append(feature["tag"])
|
||||
batch = super().__call__(concatenated_features)
|
||||
kl_batch = super().__call__(kl_concatenated_features)
|
||||
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
|
||||
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
|
||||
kto_tags.append(feature["kto_tags"])
|
||||
|
||||
batch = super().__call__(target_features)
|
||||
kl_batch = super().__call__(kl_features)
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
batch["tag"] = torch.tensor(tags)
|
||||
return batch
|
||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||
return batch
|
||||
|
||||
@@ -57,7 +57,7 @@ def load_single_dataset(
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
raise ValueError("File {} not found.".format(local_path))
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
@@ -116,7 +116,7 @@ def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
|
||||
@@ -25,21 +25,22 @@ class DatasetAttr:
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
""" columns """
|
||||
""" common columns """
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
tag: Optional[bool] = None
|
||||
""" columns for the alpaca format """
|
||||
""" rlhf columns """
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
""" alpaca columns """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
chosen: Optional[str] = "chosen"
|
||||
rejected: Optional[str] = "rejected"
|
||||
history: Optional[str] = None
|
||||
""" columns for the sharegpt format """
|
||||
""" sharegpt columns """
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
""" tags for the sharegpt format """
|
||||
""" sharegpt tags """
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
@@ -107,11 +108,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "images", "tag"]
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names.extend(["messages", "tools"])
|
||||
column_names.extend(["messages"])
|
||||
|
||||
for column_name in column_names:
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
|
||||
@@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
@@ -111,102 +111,11 @@ def preprocess_supervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["tag"].append(examples["tag"])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
|
||||
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
|
||||
examples['kl_response'] = examples['response'][::-1]
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
|
||||
input_ids, labels = [], []
|
||||
kl_input_ids, kl_labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
kl_input_ids += source_ids + target_ids
|
||||
kl_labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
kl_input_ids += [tokenizer.eos_token_id]
|
||||
kl_labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["tag"].append(examples["tag"][i])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
|
||||
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
|
||||
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
|
||||
if desirable == 0 or undesirable == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
return model_inputs
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
@@ -352,6 +261,90 @@ def preprocess_pairwise_dataset(
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
if examples["response"][i][0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
if kl_response[i][0]["content"]:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
|
||||
else:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
@@ -380,7 +373,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
|
||||
Reference in New Issue
Block a user