fix rlhf callback
Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
@@ -45,6 +45,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        self.args = training_args
        self.model_args = model_args
        self.finetuning_args = finetuning_args
        self.reward_model = reward_model

        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
@@ -72,8 +73,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                self.reward_model = self._prepare_deepspeed(self.reward_model)
            else:
                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
        else:
            self.reward_model = None

    def ppo_train(self) -> None:
        r"""
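For context, the second hunk sits inside the trainer's reward-model setup. A minimal sketch of that branch structure, assuming a plain `reward_model is not None` guard and an `is_deepspeed_enabled` flag (neither is shown in this diff); only the two preparation calls and the `None` fallback come from the hunk itself:

    # Sketch only: mirrors the branching shown in the second hunk.
    # The guard condition and the `is_deepspeed_enabled` flag are assumptions.
    def _setup_reward_model(self, reward_model):
        if reward_model is None:
            self.reward_model = None  # no reward model configured for this run
        elif self.is_deepspeed_enabled:  # assumed flag
            # DeepSpeed path: wrap via the helper shown in the hunk
            self.reward_model = self._prepare_deepspeed(reward_model)
        else:
            # otherwise let Accelerate place the model in evaluation mode
            self.reward_model = self.accelerator.prepare_model(reward_model, evaluation_mode=True)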