Former-commit-id: ff9a3f73961a362d0ddc22079f80a85465fffda8
This commit is contained in:
hiyouga
2024-04-01 22:53:52 +08:00
parent 85726c91ce
commit 1dc963caa6
4 changed files with 23 additions and 15 deletions

View File

@@ -353,7 +353,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
values = torch.transpose(values, 0, 1)