Former-commit-id: bccc71259e703ca1e1d88169e385a026c4efa92e
This commit is contained in:
hiyouga
2023-11-30 21:02:00 +08:00
parent 664267e050
commit 8ed68301e3
2 changed files with 4 additions and 2 deletions

View File

@@ -40,7 +40,8 @@ class PairwiseTrainer(Trainer):
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if getattr(model.config, "model_type", None) == "chatglm":
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
# Split the inputs and rewards into two parts, chosen and rejected