fix rlhf callback

Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
This commit is contained in:
hiyouga
2023-11-16 03:26:19 +08:00
parent e017266b98
commit de3a84ac59
4 changed files with 19 additions and 12 deletions

View File

@@ -166,7 +166,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if self.reward_model_type == "lora" and self.finetuning_type != "lora":
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Lora reward model only supports lora training.")
def save_to_json(self, json_path: str):