add rlhf-v dataset

Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
This commit is contained in:
hiyouga
2024-09-01 22:57:41 +08:00
parent 7621526d22
commit 60cf12727b
12 changed files with 107 additions and 33 deletions

View File

@@ -127,17 +127,16 @@ class CustomKTOTrainer(KTOTrainer):
"input_ids": batch["{}input_ids".format(prefix)],
"attention_mask": batch["{}attention_mask".format(prefix)],
}
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
if "image_grid_thw" in batch:
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
return logps, logps / valid_length