Former-commit-id: 392bdaf1ea9e5baf6289f2d4415a175dd55a479d
This commit is contained in:
hiyouga
2024-09-11 17:36:42 +08:00
parent 588ea95732
commit 7fd0d2fc2f
4 changed files with 12 additions and 22 deletions

View File

@@ -392,7 +392,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
"""
if self.finetuning_args.reward_model_type == "api":
token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
return get_rewards_from_server(self.reward_model, messages)
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
@@ -405,7 +405,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
_, _, values = reward_model(**batch, return_dict=True, use_cache=False)
values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")