support dpo-ftx

Former-commit-id: 86dfa04f9821556019fa777106787f73eb70b452
hiyouga
2023-12-16 19:21:41 +08:00
parent 9f77e8b025
commit d81ad2d4bc
6 changed files with 103 additions and 25 deletions
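The change mixes a supervised fine-tuning (SFT) term into DPO training: when ftx_gamma is non-zero, the cross-entropy loss of the chosen responses is added, scaled by ftx_gamma, to the per-pair DPO losses before averaging. A minimal sketch of the combined objective (function and variable names below are illustrative, not part of the diff):

    import torch

    def combined_dpo_ftx_loss(
        dpo_losses: torch.Tensor,         # per-pair DPO losses, shape (batch_size,)
        chosen_sft_losses: torch.Tensor,  # per-pair cross-entropy on the chosen responses, shape (batch_size,)
        ftx_gamma: float,
    ) -> torch.Tensor:
        # ftx_gamma = 0 recovers plain DPO; larger values pull training toward imitating the chosen answers
        return (dpo_losses + ftx_gamma * chosen_sft_losses).mean()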

@@ -16,10 +16,11 @@ class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
+        loss_type: Literal["sigmoid", "hinge"],
+        ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
-        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
         **kwargs
     ):
         if disable_dropout:
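With this hunk, loss_type and ftx_gamma become required constructor arguments instead of loss_type defaulting to "sigmoid". A hedged sketch of how the trainer could be instantiated under the new signature (all values and surrounding variables are illustrative; the real call site is not shown in this diff):

    trainer = CustomDPOTrainer(
        beta=0.1,                    # DPO temperature
        loss_type="sigmoid",         # or "hinge"
        ftx_gamma=1.0,               # weight of the auxiliary SFT term; 0 disables it
        model=model,                 # policy model (illustrative variable)
        ref_model=ref_model,         # frozen reference model, or None when an adapter can be disabled instead
        args=training_args,          # standard transformers.Trainer arguments, presumably forwarded via **kwargs
        train_dataset=train_dataset,
        data_collator=data_collator,
    )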
@@ -34,6 +35,8 @@ class CustomDPOTrainer(DPOTrainer):
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
         self.beta = beta
+        self.label_smoothing = 0
+        self.ftx_gamma = ftx_gamma
         self.loss_type = loss_type
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
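Besides storing the new ftx_gamma, the hunk above pins label_smoothing to 0; recent TRL versions read self.label_smoothing inside dpo_loss, and 0 keeps the standard, non-smoothed DPO objective. For reference, TRL's "sigmoid" variant is roughly of the following form (a sketch, not code from this repository):

    import torch.nn.functional as F

    def sigmoid_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                         reference_chosen_logps, reference_rejected_logps,
                         beta, label_smoothing=0.0):
        # reward margin implied by the policy, relative to the reference model
        logits = (policy_chosen_logps - policy_rejected_logps) \
                 - (reference_chosen_logps - reference_rejected_logps)
        # label_smoothing = 0 recovers the original DPO loss
        return (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                - F.logsigmoid(-beta * logits) * label_smoothing)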
@@ -51,10 +54,28 @@ class CustomDPOTrainer(DPOTrainer):
         else:
             self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+    def sft_loss(
+        self,
+        chosen_logits: torch.FloatTensor,
+        chosen_labels: torch.LongTensor
+    ) -> torch.Tensor:
+        r"""
+        Computes the supervised cross-entropy loss of the given labels under the given logits.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
+        """
+        all_logps = self._get_batch_logps(
+            chosen_logits,
+            chosen_labels,
+            average_log_prob=True
+        )
+        return -all_logps
+
     def concatenated_forward(
         self,
-        model: Optional[torch.nn.Module] = None,
-        batch: Optional[Dict[str, torch.Tensor]] = None
+        model: "PreTrainedModel",
+        batch: Dict[str, torch.Tensor]
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
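The new sft_loss reuses DPOTrainer._get_batch_logps with average_log_prob=True, so the returned value is the negative mean token log-probability of each chosen sequence, i.e. a per-sample cross-entropy. A rough, self-contained equivalent for decoder-only models (names and the IGNORE_INDEX value are assumptions; the trainer itself relies on TRL's helper):

    import torch

    IGNORE_INDEX = -100  # assumed to match the trainer's label_pad_token_id

    def per_sample_sft_loss(chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
        # shift so the token at position t is predicted from the logits at position t - 1
        logits = chosen_logits[:, :-1, :]
        labels = chosen_labels[:, 1:].clone()
        mask = labels != IGNORE_INDEX
        labels[~mask] = 0  # any valid index; these positions are masked out below
        logps = torch.log_softmax(logits, dim=-1)
        token_logps = torch.gather(logps, dim=2, index=labels.unsqueeze(2)).squeeze(2)
        return -(token_logps * mask).sum(-1) / mask.sum(-1)  # mean NLL per sequence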
@@ -73,3 +94,61 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
+
+    def get_batch_metrics(
+        self,
+        model: "PreTrainedModel",
+        batch: Dict[str, torch.Tensor],
+        train_eval: Optional[Literal["train", "eval"]] = "train"
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        r"""
+        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
+        """
+        metrics = {}
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+        ) = self.concatenated_forward(model, batch)
+
+        with torch.no_grad():
+            if self.ref_model is None:
+                with self.accelerator.unwrap_model(self.model).disable_adapter():
+                    (
+                        reference_chosen_logps,
+                        reference_rejected_logps,
+                        _,
+                        _,
+                    ) = self.concatenated_forward(self.model, batch)
+            else:
+                (
+                    reference_chosen_logps,
+                    reference_rejected_logps,
+                    _,
+                    _,
+                ) = self.concatenated_forward(self.ref_model, batch)
+
+        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+            policy_chosen_logps,
+            policy_rejected_logps,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        )
+        if self.ftx_gamma > 1e-6:
+            batch_size = batch["input_ids"].size(0) // 2
+            chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
+            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+
+        return losses.mean(), metrics
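For context on the ftx branch above: the collated batch is assumed to stack all chosen examples first and all rejected examples second along the batch dimension, so half the batch size recovers the chosen half whose labels align with policy_chosen_logits. A minimal sketch of that layout (tensor contents are illustrative):

    import torch

    # Two preference pairs; chosen sequences first, rejected second (made-up token IDs).
    input_ids = torch.tensor([
        [1, 5, 6, 2],   # chosen #1
        [1, 7, 8, 2],   # chosen #2
        [1, 5, 9, 2],   # rejected #1
        [1, 7, 3, 2],   # rejected #2
    ])
    labels = input_ids.clone()            # in practice, prompt tokens are set to IGNORE_INDEX

    batch_size = input_ids.size(0) // 2   # = 2 pairs
    chosen_labels, rejected_labels = labels.split(batch_size, dim=0)
    # chosen_labels now aligns row-for-row with policy_chosen_logits,
    # which is what sft_loss consumes inside get_batch_metrics.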