[assets] update wechat (#8962)
@@ -111,6 +111,7 @@ class CustomDPOTrainer(DPOTrainer):
        if self.bco_gemma >= 1e-6:
            from trl.trainer import RunningMoments

            self.running = RunningMoments(self.accelerator)

    @override
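Note (not part of the commit): trl's RunningMoments keeps a running mean/std of whatever tensors are passed to update(), synchronized through the Accelerator, and it is only created here when bco_gemma >= 1e-6. Its .mean is read back later as the reward baseline delta in bco_loss. A minimal standalone sketch of that pattern, assuming only torch, accelerate, and trl are installed:

import torch
from accelerate import Accelerator
from trl.trainer import RunningMoments

running = RunningMoments(Accelerator())  # same construction as self.running above

# Each step feeds the detached batch-mean reward into the tracker,
# mirroring self.running.update(rewards) inside bco_loss below.
for batch_rewards in (torch.tensor([0.4, -0.1]), torch.tensor([0.7, 0.2])):
    running.update(batch_rewards.mean().detach())

delta = running.mean  # the baseline later subtracted from beta * log-ratios
print(delta)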
@@ -161,14 +162,14 @@ class CustomDPOTrainer(DPOTrainer):
        chosen_logps: "torch.Tensor",
        rejected_logps: "torch.Tensor",
        reference_chosen_logps: "torch.Tensor",
-        reference_rejected_logps: "torch.Tensor"
+        reference_rejected_logps: "torch.Tensor",
    ) -> "torch.Tensor":
        chosen_logratios = chosen_logps - reference_chosen_logps
        rejected_logratios = rejected_logps - reference_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
        self.running.update(rewards)  # update baseline
        delta = self.running.mean
        bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
            -(self.beta * rejected_logratios - delta)
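Note (not part of the commit): the hunk above ends mid-expression, but the visible lines implement a BCO-style pairwise loss that centers both reward terms on the running-mean baseline delta. A hedged sketch of the complete expression, assuming the truncated line simply closes the second logsigmoid call:

import torch
import torch.nn.functional as F

def bco_loss_sketch(
    chosen_logratios: torch.Tensor,
    rejected_logratios: torch.Tensor,
    beta: float,
    delta: torch.Tensor,
) -> torch.Tensor:
    # -log sigmoid(beta * chosen_logratio - delta): push chosen rewards above the baseline;
    # -log sigmoid(-(beta * rejected_logratio - delta)): push rejected rewards below it.
    return -F.logsigmoid(beta * chosen_logratios - delta) - F.logsigmoid(
        -(beta * rejected_logratios - delta)
    )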
@@ -195,15 +196,12 @@ class CustomDPOTrainer(DPOTrainer):
            rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
        else:
            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
            )
            if self.bco_gemma > 1e-6:
                bco_losses = self.bco_loss(
-                    policy_chosen_logps,
-                    policy_rejected_logps,
-                    reference_chosen_logps,
-                    reference_rejected_logps
+                    policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
                )
                losses += bco_losses * self.bco_gemma
@@ -288,7 +286,7 @@ class CustomDPOTrainer(DPOTrainer):
            losses += self.ftx_gamma * sft_loss

        if self.bco_gemma > 1e-6:
            # re-weighting for MPO
-            losses /= (self.ftx_gamma + self.bco_gemma + 1.0)
+            losses /= self.ftx_gamma + self.bco_gemma + 1.0

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
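Note (not part of the commit): with the combined objective losses = dpo + ftx_gamma * sft + bco_gemma * bco from the earlier hunks, dividing by ftx_gamma + bco_gemma + 1.0 rescales the weighted sum back to the magnitude of a single loss term. A worked example with illustrative numbers only (none of these values come from the commit):

# Hypothetical per-example losses and coefficients, chosen for illustration.
dpo_loss, sft_loss, bco_loss = 0.60, 1.20, 0.45
ftx_gamma, bco_gemma = 0.1, 0.5

losses = dpo_loss + ftx_gamma * sft_loss + bco_gemma * bco_loss  # 0.60 + 0.12 + 0.225 = 0.945
losses /= ftx_gamma + bco_gemma + 1.0                            # 0.945 / 1.6 = 0.590625
print(losses)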