fix bug in PPO training

Former-commit-id: 2e99f0e53ce6de0acbcab85dd50aef874e8c6336
hiyouga
2023-11-16 02:32:54 +08:00
parent f81a8a5e5c
commit e017266b98
3 changed files with 7 additions and 4 deletions


@@ -158,9 +158,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
        self.additional_target = split_arg(self.additional_target)
        self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
        self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("Reward model is necessary for PPO training.")

        if self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("Lora reward model only supports lora training.")
@@ -175,4 +180,5 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
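
The second hunk touches the JSON loading helper shown above. For reference, a minimal, hypothetical round-trip in the same style; TrainArgs and save_to_json are illustrative names, not a claim about the project's full API:

    import json
    from dataclasses import dataclass, asdict

    @dataclass
    class TrainArgs:  # hypothetical example class
        learning_rate: float = 5e-5
        stage: str = "ppo"

        def save_to_json(self, json_path: str):
            # Dump the dataclass fields as pretty-printed JSON.
            with open(json_path, "w", encoding="utf-8") as f:
                f.write(json.dumps(asdict(self), indent=2, sort_keys=True) + "\n")

        @classmethod
        def load_from_json(cls, json_path: str):
            # Same pattern as the hunk above: read the file and expand the dict into the constructor.
            with open(json_path, "r", encoding="utf-8") as f:
                text = f.read()
            return cls(**json.loads(text))

Expanding the parsed dict with cls(**...) keeps loading symmetric with the dataclass fields, but it will raise a TypeError if the JSON contains keys the dataclass does not define.
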