support ORPO

Author: hiyouga
Date: 2024-03-31 18:29:50 +08:00
Commit: d764cd8736 (parent 526111a303)
Former-commit-id: f44a4c27e2461cdaa1b16865f597a31033c0e6d9

22 changed files with 395 additions and 47 deletions


@@ -74,7 +74,7 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
-    def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
+    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
         r"""
         Computes supervised cross-entropy loss of given labels under the given logits.
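
(For context, a minimal sketch of a per-sequence cross-entropy over the chosen responses; the one-position shift, the IGNORE_INDEX padding value, and the per-sequence averaging are assumptions for illustration, not code from this commit.)

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed label padding value

def sft_loss_sketch(chosen_logits: torch.Tensor, chosen_labels: torch.Tensor) -> torch.Tensor:
    # Predict token t from the logits at position t - 1.
    logits = chosen_logits[..., :-1, :].contiguous()
    labels = chosen_labels[..., 1:].contiguous()
    # Token-level cross-entropy, keeping one value per (sequence, position).
    token_loss = F.cross_entropy(
        logits.transpose(1, 2), labels, ignore_index=IGNORE_INDEX, reduction="none"
    )
    mask = labels.ne(IGNORE_INDEX).float()
    # Average over valid tokens, yielding one loss value per sequence.
    return (token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1)
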
@@ -85,8 +85,8 @@ class CustomDPOTrainer(DPOTrainer):
         return -all_logps
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
         all_logits = model(
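
(The pairwise batch stacks the chosen examples ahead of the rejected ones, so a single forward pass covers both and the outputs are split in half afterwards. A rough sketch of that split; the get_batch_logps helper and the batch layout are assumptions for illustration.)

batch_size = batch["input_ids"].size(0) // 2  # first half chosen, second half rejected (assumed layout)
all_logps = get_batch_logps(all_logits, batch["labels"])  # hypothetical helper returning per-sequence log-probs
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
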
@@ -107,9 +107,9 @@ class CustomDPOTrainer(DPOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, torch.Tensor],
+        batch: Dict[str, "torch.Tensor"],
         train_eval: Literal["train", "eval"] = "train",
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
         r"""
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """
@@ -142,21 +142,22 @@ class CustomDPOTrainer(DPOTrainer):
             reference_chosen_logps,
             reference_rejected_logps,
         )
+        batch_loss = losses.mean()
         if self.ftx_gamma > 1e-6:
             batch_size = batch["input_ids"].size(0) // 2
             chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
-            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
+            batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
+        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
+        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
+        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
+        metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
+        metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
+        metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
+        metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
-        return losses.mean(), metrics
+        return batch_loss, metrics
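
(For the returned scalar, this aggregation change appears value-preserving: adding the gamma-weighted SFT term per example and then averaging equals averaging each term separately, by linearity of the mean. A quick illustrative check with made-up numbers.)

import torch

losses = torch.tensor([0.8, 1.2])   # illustrative preference losses
sft = torch.tensor([2.0, 4.0])      # illustrative per-sequence SFT losses
gamma = 0.1

old_style = (losses + gamma * sft).mean()        # accumulate per example, then average
new_style = losses.mean() + gamma * sft.mean()   # average each term, then combine
assert torch.isclose(old_style, new_style)
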


@@ -2,13 +2,12 @@
 from typing import TYPE_CHECKING, List, Optional
-from ...data import get_dataset, split_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
 from ..utils import create_modelcard_and_push, create_ref_model
-from .collator import DPODataCollatorWithPadding
 from .trainer import CustomDPOTrainer
@@ -29,7 +28,7 @@ def run_dpo(
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
-    data_collator = DPODataCollatorWithPadding(
+    data_collator = PairwiseDataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
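
(The shared PairwiseDataCollatorWithPadding keeps the layout the trainer expects: chosen sequences in the first half of the batch, rejected sequences in the second. A minimal sketch of that behavior follows; the feature field names such as chosen_input_ids are assumptions for illustration, not the library's actual schema.)

from typing import Dict, List, Sequence

import torch

def pairwise_collate_sketch(
    features: Sequence[Dict[str, List[int]]],
    pad_token_id: int,
    label_pad_token_id: int = -100,
) -> Dict[str, torch.Tensor]:
    # Stack all chosen examples first, then all rejected ones.
    ordered = [
        (feat["{}_input_ids".format(key)], feat["{}_labels".format(key)])
        for key in ("chosen", "rejected")
        for feat in features
    ]
    max_len = max(len(ids) for ids, _ in ordered)
    input_ids, attention_mask, labels = [], [], []
    for ids, labs in ordered:
        pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        labels.append(labs + [label_pad_token_id] * pad)
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }
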
@@ -64,7 +63,7 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
     # Evaluation
     if training_args.do_eval: