Support LongLoRA for the main branch
Former-commit-id: f869501ad4c368df26534c41f62c6d63c6be17dd
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
from transformers import BatchEncoding, Trainer
|
||||
@@ -93,7 +94,8 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
all_logps = self.get_batch_logps(
|
||||
all_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False
|
||||
average_log_prob=False,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
@@ -118,20 +120,19 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
) = self.concatenated_forward(model, batch)
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.model, batch)
|
||||
ref_model = self.model
|
||||
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
|
||||
else:
|
||||
ref_model = self.ref_model
|
||||
ref_context = nullcontext()
|
||||
|
||||
with ref_context:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.ref_model, batch)
|
||||
) = self.concatenated_forward(ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
|
||||
policy_chosen_logps,
|
||||
|
||||
Reference in New Issue
Block a user