Support the GaLore (gradient low-rank projection) optimizer

Former-commit-id: b67a4a46a88d83bb2a3459b3317b66cda15e0171
This commit is contained in:
hiyouga
2024-03-07 22:41:36 +08:00
parent 5d0c95bd02
commit 1e6fb6c8aa
12 changed files with 115 additions and 16 deletions

View File

@@ -157,7 +157,39 @@ class RLHFArguments:
@dataclass
class GaloreArguments:
    r"""
    Arguments pertaining to the GaLore optimization.

    These options configure gradient low-rank projection; they only take
    effect when ``use_galore`` is enabled.
    """

    # Master switch — when False the remaining galore_* options are unused.
    use_galore: bool = field(
        default=False,
        metadata={"help": "Whether or not to use galore optimizer."},
    )
    # Comma-separated module-name pattern(s) selecting which modules GaLore
    # is applied to (default targets MLP and attention layers).
    galore_target: str = field(
        default="mlp,attn",
        metadata={"help": "Name(s) of modules to apply GaLore."},
    )
    # Rank of the low-rank gradient projection.
    galore_rank: int = field(
        default=16,
        metadata={"help": "GaLore rank."},
    )
    # How many optimizer steps between refreshes of the projection matrix.
    galore_update_interval: int = field(
        default=200,
        metadata={"help": "Number of steps to update the GaLore projection."},
    )
    # Scaling factor applied to the projected gradient.
    galore_scale: float = field(
        default=0.25,
        metadata={"help": "GaLore scale."},
    )
    # Projection type; only "std" is supported at this point.
    galore_proj_type: Literal["std"] = field(
        default="std",
        metadata={"help": "Type of GaLore projection."},
    )
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
@@ -203,6 +235,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"