add rlhf-v dataset
Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
This commit is contained in:
@@ -127,17 +127,16 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
"input_ids": batch["{}input_ids".format(prefix)],
|
||||
"attention_mask": batch["{}attention_mask".format(prefix)],
|
||||
}
|
||||
if "{}token_type_ids".format(prefix) in batch:
|
||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in batch:
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
if "{}token_type_ids".format(prefix) in batch:
|
||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||
|
||||
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
|
||||
return logps, logps / valid_length
|
||||
|
||||
|
||||
Reference in New Issue
Block a user