add rlhf-v dataset

Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
This commit is contained in:
hiyouga
2024-09-01 22:57:41 +08:00
parent 7621526d22
commit 60cf12727b
12 changed files with 107 additions and 33 deletions

View File

@@ -142,15 +142,15 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
}
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
if "pixel_values" in feature: # image data are same for chosen and rejected
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
concatenated_features.append(target_feature)
return super().__call__(concatenated_features)
@@ -177,16 +177,16 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])