support PiSSA

Former-commit-id: ef8e45f2eaf466c54e9a671512a2974575677b08
Author: hiyouga
Date: 2024-06-16 01:08:12 +08:00
Parent: 05f3a3c944
Commit: 32f45c9e91

19 changed files with 406 additions and 76 deletions


@@ -108,6 +108,18 @@ class LoraArguments:
         default=False,
         metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
     )
+    pissa_init: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to initialize a PiSSA adapter."},
+    )
+    pissa_iter: int = field(
+        default=4,
+        metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
+    )
+    pissa_convert: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
+    )
     create_new_adapter: bool = field(
         default=False,
         metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
@@ -340,7 +352,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
-        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
+        self.use_ref_model = (self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"])

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
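
The rewritten assignment narrows when a reference model is requested: the flag is now true only in the DPO stage with a loss that actually needs a reference (ORPO and SimPO are reference-free preference losses). A quick illustration of the resulting truth table (hypothetical values, not this commit's code):

# reference model only for DPO with a reference-based preference loss
for stage, pref_loss in [("dpo", "sigmoid"), ("dpo", "orpo"), ("ppo", "sigmoid")]:
    use_ref_model = stage == "dpo" and pref_loss not in ["orpo", "simpo"]
    print(stage, pref_loss, use_ref_model)  # -> True, False, False
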
@@ -367,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
             raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+        if self.pissa_convert and self.finetuning_type != "lora":
+            raise ValueError("`pissa_convert` is only valid for LoRA training.")
+        if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model):
+            raise ValueError("Cannot use PiSSA for current training stage.")

         if self.train_mm_proj_only and self.finetuning_type != "full":
             raise ValueError("`train_mm_proj_only` is only valid for full training.")
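
For reference, the identity that `pissa_convert` names: a trained PiSSA adapter (B1, A1) sits on top of the residual base weight, so re-expressing it as a normal LoRA adapter over the original weight means subtracting the initial factors (B0, A0), at the cost of doubling the adapter rank. A minimal sketch of that conversion (illustrative, not this commit's code):

# delta_W = (W_res + B1 @ A1) - W = B1 @ A1 - B0 @ A0
#         = [B1 | -B0] @ [A1 ; A0]   -> a rank-2r LoRA over the original W
import torch

def convert_pissa_to_lora(a0, b0, a1, b1):
    """Fold initial factors (a0, b0) into trained ones (a1, b1)."""
    lora_A = torch.cat([a1, a0], dim=0)   # (2r, in_dim)
    lora_B = torch.cat([b1, -b0], dim=1)  # (out_dim, 2r)
    return lora_A, lora_B  # lora_B @ lora_A == b1 @ a1 - b0 @ a0
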