fix dpo metrics
Former-commit-id: 57029280da825a39fbf5a05097921b861f126669
@@ -131,7 +131,7 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Runs forward pass and computes the log probabilities.
         """
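Note (not part of the diff): the change above only widens `forward`'s return value to include the raw logits. For orientation, a rough stand-in for what `get_batch_logps` is expected to compute is sketched below, assuming the usual convention that label positions equal to -100 are excluded from the loss; the project's real implementation may differ in details.

```python
import torch

def batch_logps_sketch(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    """Illustrative stand-in for get_batch_logps: per-sequence log-prob sum and valid-token count."""
    labels = labels[:, 1:].clone()               # next-token targets
    logits = logits[:, :-1, :]                   # drop the prediction for the final position
    mask = labels != ignore_index                # tokens that actually contribute
    labels = labels.masked_fill(~mask, 0)        # placeholder index; zeroed out by the mask below
    token_logps = torch.gather(logits.log_softmax(dim=-1), 2, labels.unsqueeze(2)).squeeze(2)
    logps = (token_logps * mask).sum(dim=-1)     # summed log-probability per sequence
    valid_length = mask.sum(dim=-1)              # so logps / valid_length is the average
    return logps, valid_length
```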
@@ -151,23 +151,25 @@ class CustomKTOTrainer(KTOTrainer):
 
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
         logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
-        return logps, logps / valid_length
+        return logits, logps, logps / valid_length
 
     @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        target_logps, target_logps_avg = self.forward(model, batch)
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
-            kl_logps, _ = self.forward(model, batch, prefix="kl_")
+            _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
 
         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")
 
+        chosen_logits = target_logits[batch["kto_tags"]]
         chosen_logps = target_logps[batch["kto_tags"]]
+        rejected_logits = target_logits[~batch["kto_tags"]]
         rejected_logps = target_logps[~batch["kto_tags"]]
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
-        return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg
 
     @override
     def compute_reference_log_probs(
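Note (illustration only, toy values): `kto_tags` is a boolean mask over the concatenated batch marking the desirable ("chosen") examples, so indexing with the mask and its negation splits every per-example tensor the same way, exactly as the new chosen_logits / rejected_logits lines do for the logits.

```python
import torch

target_logps = torch.tensor([-1.2, -3.4, -0.8, -2.0])
kto_tags = torch.tensor([True, False, True, False])   # True = "chosen" example

chosen_logps = target_logps[kto_tags]     # tensor([-1.2000, -0.8000])
rejected_logps = target_logps[~kto_tags]  # tensor([-3.4000, -2.0000])
```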
@@ -184,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
             ref_context = nullcontext()
 
         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+            reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
                 ref_model, batch
             )
 
@@ -200,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """
         metrics = {}
-        policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
-            self.concatenated_forward(model, batch)
-        )
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_kl_logps,
+            policy_chosen_logps_avg,
+        ) = self.concatenated_forward(model, batch)
         reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
             model, batch
         )
@@ -220,24 +227,21 @@ class CustomKTOTrainer(KTOTrainer):
             sft_loss = -policy_chosen_logps_avg
             losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
 
-        num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
-        num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
-
-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
-
-        if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["count/chosen"] = all_num_chosen
-
-        if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["count/rejected"] = all_num_rejected
+        num_chosen = len(chosen_rewards)
+        num_rejected = len(rejected_rewards)
+        if num_chosen > 0:
+            metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
+            metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
+            metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
+            metrics["count/chosen"] = float(num_chosen)
+
+        if num_rejected > 0:
+            metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
+            metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
+            metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
+            metrics["count/rejected"] = float(num_rejected)
 
         metrics["kl"] = kl.item()
 
         return losses, metrics
 
     @override
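Note (not part of the diff): the substance of this hunk is that each rank now records raw sums plus counts per step and defers both the cross-rank reduction and the division to `log()`. With made-up numbers for two ranks, a weighted mean reconstructed from summed numerators and denominators is correct even when ranks hold different numbers of chosen examples, whereas averaging per-rank means would skew toward the smaller group:

```python
# Hypothetical per-rank chosen rewards.
rank0 = [2.0, 4.0, 6.0]
rank1 = [10.0]

# Sum of sums over sum of counts, as log() reconstructs: (12 + 10) / (3 + 1) = 5.5
global_mean = (sum(rank0) + sum(rank1)) / (len(rank0) + len(rank1))

# Averaging means instead would over-weight the smaller group: (4 + 10) / 2 = 7.0
mean_of_means = (sum(rank0) / len(rank0) + sum(rank1) / len(rank1)) / 2
```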
@@ -248,6 +252,48 @@ class CustomKTOTrainer(KTOTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs)
         if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
 
         return loss
+
+    @override
+    def log(self, logs: Dict[str, float]) -> None:
+        r"""
+        Log `logs` on the various objects watching training, including stored metrics.
+        """
+        # logs either has "loss" or "eval_loss"
+        train_eval = "train" if "loss" in logs else "eval"
+        prefix = "eval_" if train_eval == "eval" else ""
+        # Add averaged stored metrics to logs
+        key_list, metric_list = [], []
+        for key, metrics in self._stored_metrics[train_eval].items():
+            key_list.append(key)
+            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())
+
+        del self._stored_metrics[train_eval]
+        if len(metric_list) < 9:  # pad to for all reduce
+            for i in range(9 - len(metric_list)):
+                key_list.append(f"dummy_{i}")
+                metric_list.append(0.0)
+
+        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+        metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
+        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
+            if f"count/{split}" in metric_dict:
+                for key in ("rewards", "logps", "logits"):
+                    logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
+                    del metric_dict[f"{key}/{split}_sum"]
+
+                del metric_dict[f"count/{split}"]
+
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:  # calculate reward margin
+            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+
+        for key, metric in metric_dict.items():  # add remaining items
+            if not key.startswith("dummy_"):
+                logs[key] = metric
+
+        return Trainer.log(self, logs)
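Note (rough sketch, invented numbers and a hypothetical `pad_metrics` helper): the new `log()` pads each rank's stored metrics to a fixed length of nine entries before `accelerator.reduce`, because a rank that recorded fewer metric keys in a logging window would otherwise hand a shorter tensor to the reduction. A single-process emulation of that idea:

```python
import torch

def pad_metrics(keys, vals, size=9):
    # hypothetical helper mirroring the dummy-padding in log()
    keys = keys + [f"dummy_{i}" for i in range(size - len(vals))]
    vals = vals + [0.0] * (size - len(vals))
    return keys, torch.tensor(vals)

# Rank 0 stored chosen-side sums in this window; rank 1 stored nothing, so without
# padding the two ranks would contribute tensors of different lengths.
keys, rank0 = pad_metrics(
    ["rewards/chosen_sum", "logps/chosen_sum", "logits/chosen_sum", "count/chosen"],
    [6.0, -12.0, 30.0, 3.0],
)
_, rank1 = pad_metrics([], [])

reduced = rank0 + rank1  # stands in for accelerator.reduce(metric_list, "sum")
metric_dict = dict(zip(keys, reduced.tolist()))
print(metric_dict["rewards/chosen_sum"] / metric_dict["count/chosen"])  # 6.0 / 3.0 = 2.0
```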