Former-commit-id: 3126687c4820c34daa6a2e9e3bf9065ad59e92dc
This commit is contained in:
hiyouga
2023-11-28 20:57:24 +08:00
parent 670ee3934f
commit 0e6f4f981e
2 changed files with 4 additions and 3 deletions

View File

@@ -39,7 +39,8 @@ class PairwiseTrainer(Trainer):
"""
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
if getattr(model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
# Split the inputs and rewards into two parts, chosen and rejected