support SimPO #3900

Former-commit-id: 6b954ce60155cf8334150b795cfc4bb63ca74c8b
This commit is contained in:
hiyouga
2024-05-26 23:46:33 +08:00
parent 26f293d587
commit b0d9966663
19 changed files with 145 additions and 339 deletions

View File

@@ -4,6 +4,7 @@ from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
@@ -50,10 +51,11 @@ class CustomDPOTrainer(DPOTrainer):
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# dpo hyperparams
self.beta = finetuning_args.dpo_beta
self.beta = finetuning_args.pref_beta
self.loss_type = finetuning_args.pref_loss
self.ftx_gamma = finetuning_args.pref_ftx
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.loss_type = finetuning_args.dpo_loss
self.ftx_gamma = finetuning_args.dpo_ftx
self.simpo_gamma = finetuning_args.simpo_gamma
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
@@ -90,15 +92,66 @@ class CustomDPOTrainer(DPOTrainer):
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
def sft_loss(self, batch: Dict[str, "torch.Tensor"], chosen_logits: "torch.FloatTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
"""
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
chosen_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -chosen_logps
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
"""
log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
sft_loss = -chosen_logps
odds_ratio_loss = -F.logsigmoid(log_odds)
orpo_loss = sft_loss + self.beta * odds_ratio_loss
return orpo_loss
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes SimPO loss for batched log probabilities of the policy model.
"""
pi_logratios = chosen_logps - rejected_logps
gamma_logratios = self.simpo_gamma / self.beta
logits = pi_logratios - gamma_logratios
simpo_loss = -F.logsigmoid(self.beta * logits)
return simpo_loss
def compute_preference_loss(
self,
policy_chosen_logps: "torch.Tensor",
policy_rejected_logps: "torch.Tensor",
reference_chosen_logps: Optional["torch.Tensor"],
reference_rejected_logps: Optional["torch.Tensor"],
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes loss for preference learning.
"""
if not self.finetuning_args.use_ref_model:
if self.loss_type == "orpo":
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
elif self.loss_type == "simpo":
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
else:
raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
else:
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
return losses, chosen_rewards, rejected_rewards
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
@@ -108,13 +161,15 @@ class CustomDPOTrainer(DPOTrainer):
Otherwise the average log probabilities.
"""
batch_copied = {k: v.detach().clone() for k, v in batch.items()} # avoid error
all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)
if self.finetuning_args.use_ref_model:
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps = self.get_batch_logps(
logits=all_logits,
labels=batch_copied["labels"],
average_log_prob=(self.loss_type == "ipo"),
labels=batch["labels"],
average_log_prob=(self.loss_type in ["ipo", "orpo", "simpo"]),
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
@@ -123,6 +178,32 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
def compute_reference_log_probs(
self, batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Computes log probabilities of the reference model.
"""
if not self.finetuning_args.use_ref_model:
return None, None
if self.ref_model is None:
ref_model = self.model
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
with torch.no_grad(), ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(ref_model, batch)
return reference_chosen_logps, reference_rejected_logps
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
@@ -140,32 +221,16 @@ class CustomDPOTrainer(DPOTrainer):
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
ref_model = self.model
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(ref_model, batch)
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(batch)
losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
sft_loss = self.sft_loss(batch, policy_chosen_logits) # compute chosen_logps with masks
if self.ftx_gamma > 1e-6:
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
losses += self.ftx_gamma * sft_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
@@ -178,5 +243,8 @@ class CustomDPOTrainer(DPOTrainer):
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
if self.loss_type == "orpo":
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
return losses.mean(), metrics