improve KTO impl., replace datasets
Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
This commit is contained in:
@@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
def __call__(self, features, return_tensors=None):
|
||||
concatenated_features = []
|
||||
kl_concatenated_features = []
|
||||
tags = []
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
for feature in features:
|
||||
concatenated_features.append(
|
||||
target_features.append(
|
||||
{
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
)
|
||||
kl_concatenated_features.append(
|
||||
kl_features.append(
|
||||
{
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
)
|
||||
tags.append(feature["tag"])
|
||||
batch = super().__call__(concatenated_features)
|
||||
kl_batch = super().__call__(kl_concatenated_features)
|
||||
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
|
||||
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
|
||||
kto_tags.append(feature["kto_tags"])
|
||||
|
||||
batch = super().__call__(target_features)
|
||||
kl_batch = super().__call__(kl_features)
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
batch["tag"] = torch.tensor(tags)
|
||||
return batch
|
||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user