update trainers

Former-commit-id: b7f6c4a171293cf4f3e88f15a811f847342f84ee
This commit is contained in:
hiyouga
2024-06-06 18:45:49 +08:00
parent f4acd81e2f
commit d5559461c1
4 changed files with 12 additions and 21 deletions

View File

@@ -146,15 +146,8 @@ class CustomKTOTrainer(KTOTrainer):
if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.")
chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
chosen_logps = target_logps[chosen_idx, ...]
rejected_logps = target_logps[rejected_idx, ...]
chosen_logits = target_logits[chosen_idx, ...]
rejected_logits = target_logits[rejected_idx, ...]
chosen_logps, rejected_logps = target_logps[batch["kto_tags"]], target_logps[~batch["kto_tags"]]
chosen_logits, rejected_logits = target_logits[batch["kto_tags"]], target_logits[~batch["kto_tags"]]
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
def compute_reference_log_probs(