improve KTO impl., replace datasets
Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
This commit is contained in:
@@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
@@ -111,102 +111,11 @@ def preprocess_supervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["tag"].append(examples["tag"])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
|
||||
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
|
||||
examples['kl_response'] = examples['response'][::-1]
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
|
||||
input_ids, labels = [], []
|
||||
kl_input_ids, kl_labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
kl_input_ids += source_ids + target_ids
|
||||
kl_labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
kl_input_ids += [tokenizer.eos_token_id]
|
||||
kl_labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["tag"].append(examples["tag"][i])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
|
||||
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
|
||||
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
|
||||
if desirable == 0 or undesirable == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
return model_inputs
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
@@ -352,6 +261,90 @@ def preprocess_pairwise_dataset(
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
if examples["response"][i][0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
if kl_response[i][0]["content"]:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
|
||||
else:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
@@ -380,7 +373,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
|
||||
Reference in New Issue
Block a user