Support LongLoRA for the main branch

Former-commit-id: f869501ad4c368df26534c41f62c6d63c6be17dd
hiyouga
2024-01-20 19:25:22 +08:00
parent 8efc055511
commit 80637fc06d
7 changed files with 168 additions and 204 deletions


@@ -1,4 +1,5 @@
 import torch
+from contextlib import nullcontext
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
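
The new nullcontext import is the standard-library no-op context manager; the refactor below relies on it to collapse two near-identical branches into one with-block. A minimal standalone sketch of the idiom (illustrative, not code from this repo):

import torch
from contextlib import nullcontext

def forward_maybe_frozen(model, inputs, frozen: bool):
    # nullcontext() does nothing on enter/exit, so both branches can
    # share a single `with` block instead of duplicating the call.
    ctx = torch.no_grad() if frozen else nullcontext()
    with ctx:
        return model(inputs)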
@@ -93,7 +94,8 @@ class CustomDPOTrainer(DPOTrainer):
         all_logps = self.get_batch_logps(
             all_logits,
             batch["labels"],
-            average_log_prob=False
+            average_log_prob=False,
+            label_pad_token_id=self.label_pad_token_id,
         )
         batch_size = batch["input_ids"].size(0) // 2
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
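
Passing label_pad_token_id through explicitly matters because get_batch_logps sums per-token log-probabilities only over positions whose label is not the pad id; relying on the upstream default could silently diverge from the trainer's own setting. A simplified sketch of that masking logic, assuming the usual -100 ignore index (not the exact TRL implementation):

import torch

def batch_logps(logits: torch.Tensor, labels: torch.Tensor, label_pad_token_id: int = -100) -> torch.Tensor:
    # Tokens at position t are predicted by the logits at position t - 1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    mask = labels != label_pad_token_id
    # gather() needs valid indices, so map padded labels to 0 before lookup.
    labels[~mask] = 0
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    # Sum log-probs over real (non-padded) tokens only.
    return (per_token_logps * mask).sum(-1)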
@@ -118,20 +120,19 @@ class CustomDPOTrainer(DPOTrainer):
         ) = self.concatenated_forward(model, batch)
         with torch.no_grad():
             if self.ref_model is None:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    (
-                        reference_chosen_logps,
-                        reference_rejected_logps,
-                        _,
-                        _,
-                    ) = self.concatenated_forward(self.model, batch)
+                ref_model = self.model
+                ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
             else:
+                ref_model = self.ref_model
+                ref_context = nullcontext()
+
+            with ref_context:
                 (
                     reference_chosen_logps,
                     reference_rejected_logps,
                     _,
                     _,
-                ) = self.concatenated_forward(self.ref_model, batch)
+                ) = self.concatenated_forward(ref_model, batch)
         losses, chosen_rewards, rejected_rewards = self.dpo_loss(
             policy_chosen_logps,
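
Why disable_adapter() can stand in for a reference model: with LoRA, the frozen base weights underneath the adapter are exactly the pre-finetuning policy, so turning the adapter off recovers the DPO reference distribution without holding a second model in memory. Rendered as a hypothetical free function (names follow the trainer above; a condensed sketch, not the repo's API):

from contextlib import nullcontext

def reference_forward(trainer, batch):
    if trainer.ref_model is None:
        # No explicit reference model: reuse the policy with its LoRA
        # adapter disabled, which exposes the frozen base model.
        ref_model = trainer.model
        ref_context = trainer.accelerator.unwrap_model(trainer.model).disable_adapter()
    else:
        ref_model = trainer.ref_model
        ref_context = nullcontext()

    with ref_context:
        return trainer.concatenated_forward(ref_model, batch)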