update trainers

Former-commit-id: b7f6c4a171293cf4f3e88f15a811f847342f84ee
This commit is contained in:
hiyouga
2024-06-06 18:45:49 +08:00
parent f4acd81e2f
commit d5559461c1
4 changed files with 12 additions and 21 deletions

View File

@@ -187,13 +187,7 @@ class CustomDPOTrainer(DPOTrainer):
ref_context = nullcontext()
with torch.no_grad(), ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(ref_model, batch)
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
return reference_chosen_logps, reference_rejected_logps