fix rlhf callback

Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
This commit is contained in:
hiyouga
2023-11-16 03:26:19 +08:00
parent e017266b98
commit de3a84ac59
4 changed files with 19 additions and 12 deletions

View File

@@ -45,6 +45,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
@@ -72,8 +73,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
else:
self.reward_model = None
def ppo_train(self) -> None:
r"""