use pre-commit

Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
This commit is contained in:
hiyouga
2024-10-29 09:07:46 +00:00
parent 8f5921692e
commit 248d5daaff
66 changed files with 1028 additions and 1044 deletions

View File

@@ -129,11 +129,11 @@ class CustomKTOTrainer(KTOTrainer):
"""
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
model_inputs = {
"input_ids": batch["{}input_ids".format(prefix)],
"attention_mask": batch["{}attention_mask".format(prefix)],
"input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch[f"{prefix}attention_mask"],
}
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
if f"{prefix}token_type_ids" in batch:
model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
return logps, logps / valid_length
@override