support lora for llama pro

Former-commit-id: f74c78ba95f0545aae89e603e466f494705ad024
This commit is contained in:
hiyouga
2024-02-21 02:17:22 +08:00
parent a3f30038a0
commit bc16c9a54a
7 changed files with 119 additions and 28 deletions

View File

@@ -26,10 +26,6 @@ class FreezeArguments:
default=3,
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
)
use_llama_pro: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."},
)
@dataclass
@@ -170,6 +166,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
use_llama_pro: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
disable_version_checking: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable version checking."},
@@ -195,13 +195,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
raise ValueError("`reward_model` is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.use_llama_pro and self.finetuning_type != "freeze":
raise ValueError("`use_llama_pro` is only valid for the Freeze method.")
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""