use pre-commit
Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
This commit is contained in:
@@ -129,11 +129,11 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
"""
|
||||
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||
model_inputs = {
|
||||
"input_ids": batch["{}input_ids".format(prefix)],
|
||||
"attention_mask": batch["{}attention_mask".format(prefix)],
|
||||
"input_ids": batch[f"{prefix}input_ids"],
|
||||
"attention_mask": batch[f"{prefix}attention_mask"],
|
||||
}
|
||||
if "{}token_type_ids".format(prefix) in batch:
|
||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||
if f"{prefix}token_type_ids" in batch:
|
||||
model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
|
||||
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
|
||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
|
||||
return logps, logps / valid_length
|
||||
|
||||
@override
|
||||
|
||||
Reference in New Issue
Block a user