improve KTO impl., replace datasets

Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
hiyouga
2024-05-18 03:44:56 +08:00
parent e4570e28a8
commit 2bff90719b
53 changed files with 448 additions and 330 deletions

View File

@@ -1,12 +1,12 @@
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = [
"PairwiseDataCollatorWithPadding",
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from ..extras.logging import get_logger
from .utils import Role
@@ -14,7 +15,13 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
r"""
Optionally concatenates the image path with the dataset directory when loading from local disk.
"""
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
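The helper is truncated in this hunk; a hedged illustration of its intent follows, with hypothetical values standing in for data_args.dataset_dir and the image column:

import os

# Hedged illustration: when loading from a local script or file, relative image
# paths are resolved against the dataset directory (values are hypothetical).
dataset_dir = "data"
images = ["mllm_demo_data/1.jpg", "mllm_demo_data/2.jpg"]
outputs = [os.path.join(dataset_dir, image) for image in images]
# -> ["data/mllm_demo_data/1.jpg", "data/mllm_demo_data/2.jpg"] on POSIX systems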
@@ -29,7 +36,10 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
@@ -45,23 +55,33 @@ def convert_alpaca(
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
response = [
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
return outputs
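To make the new KTO encoding concrete, here is a minimal sketch (the row contents are hypothetical): the boolean label is carried by the position of an empty assistant message, so a desirable answer aligns to [answer, ""] and an undesirable one to ["", answer]:

example = {"instruction": "What is 2 + 2?", "output": "4", "kto_tag": True}  # hypothetical row

prompt = [{"role": "user", "content": example["instruction"]}]
answer = {"role": "assistant", "content": example["output"]}
padding = {"role": "assistant", "content": ""}

# Desirable example -> real answer first; undesirable -> empty message first.
response = [answer, padding] if example["kto_tag"] else [padding, answer]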
@@ -69,6 +89,9 @@ def convert_alpaca(
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
@@ -88,21 +111,62 @@ def convert_sharegpt(
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
@@ -138,7 +202,6 @@ def align_dataset(
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
"tag": {"dtype": "bool", "_type": "Value"},
}
)
kwargs = {}

View File

@@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features, return_tensors=None):
concatenated_features = []
kl_concatenated_features = []
tags = []
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
concatenated_features.append(
target_features.append(
{
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
)
kl_concatenated_features.append(
kl_features.append(
{
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
)
tags.append(feature["tag"])
batch = super().__call__(concatenated_features)
kl_batch = super().__call__(kl_concatenated_features)
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
batch["tag"] = torch.tensor(tags)
return batch
batch["kto_tags"] = torch.tensor(kto_tags)
return batch
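A hedged sketch of the collator contract after the rename, under the assumption that the base DataCollatorForSeq2Seq right-pads each group: the matched features and the mismatched kl_* features are padded independently and merged into one batch together with the stacked boolean tags.

import torch

def pad_to(seq, length, pad_id=0):
    # Right-padding, as DataCollatorForSeq2Seq does for inputs by default.
    return seq + [pad_id] * (length - len(seq))

features = [
    {"input_ids": [1, 2, 3], "kl_input_ids": [1, 9], "kto_tags": True},
    {"input_ids": [1, 5], "kl_input_ids": [1, 7, 8], "kto_tags": False},
]

# The matched and mismatched groups are padded independently, then merged.
max_len = max(len(f["input_ids"]) for f in features)
kl_max_len = max(len(f["kl_input_ids"]) for f in features)
batch = {
    "input_ids": torch.tensor([pad_to(f["input_ids"], max_len) for f in features]),
    "kl_input_ids": torch.tensor([pad_to(f["kl_input_ids"], kl_max_len) for f in features]),
    "kto_tags": torch.tensor([f["kto_tags"] for f in features]),
}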

View File

@@ -57,7 +57,7 @@ def load_single_dataset(
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
raise ValueError("File {} not found.".format(local_path))
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
@@ -116,7 +116,7 @@ def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
stage: Literal["pt", "sft", "rm", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:

View File

@@ -25,21 +25,22 @@ class DatasetAttr:
folder: Optional[str] = None
ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
""" common columns """
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
tag: Optional[bool] = None
""" columns for the alpaca format """
""" rlhf columns """
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
""" alpaca columns """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
chosen: Optional[str] = "chosen"
rejected: Optional[str] = "rejected"
history: Optional[str] = None
""" columns for the sharegpt format """
""" sharegpt columns """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
""" tags for the sharegpt format """
""" sharegpt tags """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
@@ -107,11 +108,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system", "images", "tag"]
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages", "tools"])
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
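With chosen, rejected, and kto_tag promoted to common columns, a dataset_info entry can map them for either format. A hedged example, expressed as the Python dict that dataset_info[name] would hold (the dataset name, file name, and column names are hypothetical):

dataset_info = {
    "kto_demo": {                          # hypothetical dataset name
        "file_name": "kto_feedback.json",  # hypothetical file
        "formatting": "alpaca",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "kto_tag": "label",            # boolean column consumed by convert_alpaca
        },
    },
}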

View File

@@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
@@ -111,102 +111,11 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["tag"].append(examples["tag"])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
examples['kl_response'] = examples['response'][::-1]
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
input_ids, labels = [], []
kl_input_ids, kl_labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
kl_input_ids += source_ids + target_ids
kl_labels += source_mask + target_ids
if template.efficient_eos:
kl_input_ids += [tokenizer.eos_token_id]
kl_labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["tag"].append(examples["tag"][i])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
if desirable == 0 or undesirable == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
@@ -352,6 +261,90 @@ def preprocess_pairwise_dataset(
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
if examples["response"][i][0]["content"]: # desired example
kto_tag = True
messages = examples["prompt"][i] + [examples["response"][i][0]]
else: # undesired example
kto_tag = False
messages = examples["prompt"][i] + [examples["response"][i][1]]
if kl_response[i][0]["content"]:
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
else:
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, kl_response_ids = template.encode_oneturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
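The flipping trick at the top of the new function deserves a concrete look: reversing the response list pairs each prompt with another example's completion, which is exactly the unrelated pairing the KL term needs. A minimal sketch with placeholder strings:

prompts = ["p0", "p1", "p2"]
responses = ["r0", "r1", "r2"]
kl_responses = responses[::-1]  # ["r2", "r1", "r0"]

matched = list(zip(prompts, responses))        # ("p0", "r0"), ("p1", "r1"), ("p2", "r2")
mismatched = list(zip(prompts, kl_responses))  # ("p0", "r2"), ("p1", "r1"), ("p2", "r0")
# With an odd count the middle pair still matches; for realistic dataset sizes
# this is negligible noise in the KL estimate.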
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@@ -380,7 +373,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
def get_preprocess_and_print_func(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
stage: Literal["pt", "sft", "rm", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],

View File

@@ -137,21 +137,21 @@ class RLHFArguments:
default=0.1,
metadata={"help": "The beta parameter for the KTO loss."},
)
kto_chosen_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the desirable losses in KTO training."},
)
kto_rejected_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
)
kto_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
)
kto_desirable_weight: float = field(
default=1.0,
metadata={"help": "The desirable weight for the KTO loss."},
)
kto_undesirable_weight: float = field(
default=1.0,
metadata={"help": "The undesirable weight for the KTO loss."},
)
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field(
default=1,
@@ -307,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
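With the renamed hyperparameters, a KTO run can be configured as an args dict for the run_exp entry point shown in tuner.py below. A hedged sketch; the model, dataset, and output path are placeholders:

args = {
    "stage": "kto",
    "do_train": True,
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # placeholder model
    "dataset": "kto_demo",                             # placeholder dataset
    "finetuning_type": "lora",
    "output_dir": "saves/kto-demo",                    # placeholder path
    "kto_beta": 0.1,
    "kto_chosen_weight": 1.0,    # weight of desirable losses
    "kto_rejected_weight": 1.0,  # weight of undesirable losses
    "kto_ftx": 0.0,              # optional supervised fine-tuning coefficient
}
# run_exp(args)  # entry point shown in tuner.py below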

View File

@@ -47,11 +47,13 @@ class CustomDPOTrainer(DPOTrainer):
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# dpo hyperparams
self.beta = finetuning_args.dpo_beta
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.loss_type = finetuning_args.dpo_loss
self.ftx_gamma = finetuning_args.dpo_ftx
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
@@ -143,6 +145,7 @@ class CustomDPOTrainer(DPOTrainer):
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
ref_model = self.model

View File

@@ -1,7 +1,7 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import torch
from transformers import Trainer
@@ -13,7 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import PreTrainedModel
from transformers import PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments
@@ -24,6 +24,7 @@ class CustomKTOTrainer(KTOTrainer):
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
disable_dropout: bool = True,
**kwargs,
):
@@ -33,6 +34,7 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
self.processor = processor
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
@@ -43,15 +45,15 @@ class CustomKTOTrainer(KTOTrainer):
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# KTO parameter
# kto hyperparams
self.beta = finetuning_args.kto_beta
self.desirable_weight = finetuning_args.kto_chosen_weight
self.undesirable_weight = finetuning_args.kto_rejected_weight
self.ftx_gamma = finetuning_args.kto_ftx
self.desirable_weight = finetuning_args.kto_desirable_weight
self.undesirable_weight = finetuning_args.kto_undesirable_weight
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
@@ -82,78 +84,85 @@ class CustomKTOTrainer(KTOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
"""
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps.nanmean()
return -all_logps
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
with torch.no_grad():
KL_logits = model(
batch["KL_completion_input_ids"],
attention_mask=batch["KL_completion_attention_mask"],
).logits
kl_logits = model(
input_ids=batch["kl_input_ids"],
attention_mask=batch["kl_attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
completion_logits = model(
batch["input_ids"],
target_logits = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
).logits
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
completion_logps = self.get_batch_logps(
completion_logits,
batch["labels"],
target_logps = self.get_batch_logps(
logits=target_logits,
labels=batch["labels"],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
KL_logps = self.get_batch_logps(
KL_logits,
batch["kl_labels"],
kl_logps = self.get_batch_logps(
logits=kl_logits,
labels=batch["kl_labels"],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
if completion_logps.shape[0] != len(batch["tag"]):
raise ValueError(
"There is a mismatch between the number of examples in this batch and the number of "
"examples for which an output sequence was predicted."
)
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]]
rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]]
if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.")
chosen_logps = completion_logps[chosen_idx, ...]
rejected_logps = completion_logps[rejected_idx, ...]
chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
chosen_logits = completion_logits[chosen_idx, ...]
rejected_logits = completion_logits[rejected_idx, ...]
chosen_logps = target_logps[chosen_idx, ...]
rejected_logps = target_logps[rejected_idx, ...]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
chosen_logits = target_logits[chosen_idx, ...]
rejected_logits = target_logits[rejected_idx, ...]
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
):
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the KTO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_KL_logps,
_,
policy_kl_logps,
) = self.forward(model, batch)
with torch.no_grad():
@@ -163,27 +172,29 @@ class CustomKTOTrainer(KTOTrainer):
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_KL_logps,
reference_kl_logps,
) = self.forward(ref_model, batch)
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
policy_kl_logps,
reference_chosen_logps,
reference_rejected_logps,
reference_KL_logps,
reference_kl_logps,
)
losses = losses.nanmean()
if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0:
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']])
if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale
sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]])
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"])
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
@@ -203,4 +214,4 @@ class CustomKTOTrainer(KTOTrainer):
metrics["kl"] = kl.item()
return losses, metrics
return losses, metrics
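How the pieces combine: kto_loss (inherited from trl's KTOTrainer) returns per-example losses, and because the new sft_loss returns unreduced per-sample NLL over the chosen subset only, the ftx term is rescaled by batch size over chosen count, as the "remember to rescale" comment indicates. A hedged sketch of that arithmetic:

import torch

def combine_losses(kto_losses: torch.Tensor, sft_nll: torch.Tensor,
                   ftx_gamma: float, num_chosen: int, batch_size: int) -> torch.Tensor:
    # kto_losses: per-example KTO losses; sft_nll: per-sample NLL over the
    # chosen (desirable) subset only.
    total = kto_losses.nanmean()
    if ftx_gamma > 1e-6 and num_chosen > 0:
        # Rescale by batch_size / num_chosen so the coefficient stays comparable
        # even when only part of the batch is desirable.
        total = total + ftx_gamma * sft_nll.nanmean() / num_chosen * batch_size
    return total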

View File

@@ -48,9 +48,9 @@ def run_kto(
ref_model=ref_model,
args=training_args,
finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
)

View File

@@ -29,7 +29,7 @@ def run_ppo(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training

View File

@@ -9,12 +9,13 @@ from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .dpo import run_dpo
from .kto import run_kto
from .orpo import run_orpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .kto import run_kto
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -37,10 +38,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")