improve KTO impl., replace datasets

Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
This commit is contained in:
hiyouga
2024-05-18 03:44:56 +08:00
parent e4570e28a8
commit 2bff90719b
53 changed files with 448 additions and 330 deletions

View File

@@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features, return_tensors=None):
concatenated_features = []
kl_concatenated_features = []
tags = []
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
concatenated_features.append(
target_features.append(
{
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
)
kl_concatenated_features.append(
kl_features.append(
{
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
)
tags.append(feature["tag"])
batch = super().__call__(concatenated_features)
kl_batch = super().__call__(kl_concatenated_features)
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
batch["tag"] = torch.tensor(tags)
return batch
batch["kto_tags"] = torch.tensor(kto_tags)
return batch