support GaLore
Former-commit-id: b67a4a46a88d83bb2a3459b3317b66cda15e0171
@@ -157,7 +157,39 @@ class RLHFArguments:
 
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
+class GaloreArguments:
+    r"""
+    Arguments pertaining to the GaLore optimization.
+    """
+
+    use_galore: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the GaLore optimizer."},
+    )
+    galore_target: str = field(
+        default="mlp,attn",
+        metadata={"help": "Name(s) of modules to apply GaLore."},
+    )
+    galore_rank: int = field(
+        default=16,
+        metadata={"help": "GaLore rank."},
+    )
+    galore_update_interval: int = field(
+        default=200,
+        metadata={"help": "Number of steps to update the GaLore projection."},
+    )
+    galore_scale: float = field(
+        default=0.25,
+        metadata={"help": "GaLore scale."},
+    )
+    galore_proj_type: Literal["std"] = field(
+        default="std",
+        metadata={"help": "Type of GaLore projection."},
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
     r"""
     Arguments pertaining to which techniques we are going to fine-tune with.
     """
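For orientation, here is a minimal sketch of where these fields typically land: the galore-torch package (the reference GaLore implementation) exposes GaLoreAdamW, which takes the rank, update interval, scale, and projection type as per-parameter-group options. This is an illustrative sketch under that assumption, not the wiring added by this commit; the module matching against galore_target is simplified to a substring check.

# Illustrative sketch only -- assumes the galore-torch package's documented
# interface (GaLoreAdamW with per-group "rank", "update_proj_gap", "scale",
# and "proj_type" keys); this is NOT the code this commit adds.
import torch.nn as nn
from galore_torch import GaLoreAdamW

model = nn.TransformerEncoderLayer(d_model=64, nhead=4)
galore_target = "linear"  # stand-in for the "mlp,attn" module-name filter

galore_params, regular_params = [], []
for name, param in model.named_parameters():
    # simplified: substring match on the name, 2D weights only (GaLore
    # projects gradient matrices, so biases stay in the regular group)
    if galore_target in name and param.ndim == 2:
        galore_params.append(param)
    else:
        regular_params.append(param)

optimizer = GaLoreAdamW(
    [
        {"params": regular_params},
        {
            "params": galore_params,
            "rank": 16,              # galore_rank
            "update_proj_gap": 200,  # galore_update_interval
            "scale": 0.25,           # galore_scale
            "proj_type": "std",      # galore_proj_type
        },
    ],
    lr=1e-4,
)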
@@ -203,6 +235,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
 
+        if self.use_galore and self.finetuning_type == "lora":
+            raise ValueError("Cannot use LoRA and GaLore together.")
+
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
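The new check follows the same __post_init__ validation pattern as the use_llama_pro guard above it: incompatible flag combinations fail at argument-parsing time instead of partway into a training run. A stripped-down sketch of the pattern, with Args as an illustrative stand-in for the real FinetuningArguments:

# Minimal sketch of the __post_init__ validation pattern; "Args" is an
# illustrative stand-in, not the real FinetuningArguments class.
from dataclasses import dataclass


@dataclass
class Args:
    finetuning_type: str = "lora"
    use_galore: bool = False

    def __post_init__(self):
        # mirrors the check added above: GaLore and LoRA are mutually exclusive
        if self.use_galore and self.finetuning_type == "lora":
            raise ValueError("Cannot use LoRA and GaLore together.")


Args(finetuning_type="full", use_galore=True)    # accepted
# Args(finetuning_type="lora", use_galore=True)  # raises ValueError at construction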
@@ -180,7 +180,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     # Post-process training arguments
     if (
-        training_args.local_rank != -1
+        training_args.parallel_mode.value == "distributed"
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
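The replaced condition detects the same situation more explicitly: training_args.parallel_mode resolves to transformers' ParallelMode enum, and its DISTRIBUTED member carries the string value "distributed", which covers the multi-process DDP case that local_rank != -1 used to signal. A small sketch of the two equivalent spellings, assuming that enum:

# Sketch of the condition change; assumes transformers' ParallelMode enum,
# whose DISTRIBUTED member has the string value "distributed".
from transformers import TrainingArguments
from transformers.training_args import ParallelMode

training_args = TrainingArguments(output_dir="out")

# enum comparison and string-value comparison check the same thing
if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
    ...
if training_args.parallel_mode.value == "distributed":
    # here the real code decides whether to set ddp_find_unused_parameters
    ...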