@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
@@ -9,22 +8,40 @@ class FreezeArguments:
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
|
||||
name_module_trainable: str = field(
|
||||
default="all",
|
||||
freeze_trainable_layers: int = field(
|
||||
default=2,
|
||||
metadata={
|
||||
"help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
Use commas to separate multiple modules. \
|
||||
Use "all" to specify all the available modules. \
|
||||
LLaMA choices: ["mlp", "self_attn"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
|
||||
Qwen choices: ["mlp", "attn"], \
|
||||
InternLM2 choices: ["feed_forward", "attention"], \
|
||||
Others choices: the same as LLaMA."""
|
||||
"help": (
|
||||
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
|
||||
"Positive numbers mean the last n layers are set as trainable, "
|
||||
"negative numbers mean the first n layers are set as trainable."
|
||||
)
|
||||
},
|
||||
)
|
||||
num_layer_trainable: int = field(
|
||||
default=2,
|
||||
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
|
||||
freeze_trainable_modules: str = field(
|
||||
default="all",
|
||||
metadata={
|
||||
"help": (
|
||||
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
|
||||
"Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the available modules. "
|
||||
"LLaMA choices: [`mlp`, `self_attn`], "
|
||||
"BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], "
|
||||
"Qwen choices: [`mlp`, `attn`], "
|
||||
"InternLM2 choices: [`feed_forward`, `attention`], "
|
||||
"Others choices: the same as LLaMA."
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_extra_modules: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Name(s) of modules apart from hidden layers to be set as trainable "
|
||||
"for freeze (partial-parameter) fine-tuning. "
|
||||
"Use commas to separate multiple modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -37,7 +54,11 @@ class LoraArguments:
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
|
||||
"help": (
|
||||
"Name(s) of modules apart from LoRA layers to be set as trainable "
|
||||
"and saved in the final checkpoint. "
|
||||
"Use commas to separate multiple modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
@@ -55,15 +76,17 @@ class LoraArguments:
|
||||
lora_target: str = field(
|
||||
default="all",
|
||||
metadata={
|
||||
"help": """Name(s) of target modules to apply LoRA. \
|
||||
Use commas to separate multiple modules. \
|
||||
Use "all" to specify all the linear modules. \
|
||||
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
|
||||
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
|
||||
Others choices: the same as LLaMA."""
|
||||
"help": (
|
||||
"Name(s) of target modules to apply LoRA. "
|
||||
"Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the linear modules. "
|
||||
"LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
|
||||
"BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], "
|
||||
"Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
|
||||
"Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], "
|
||||
"InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], "
|
||||
"Others choices: the same as LLaMA."
|
||||
)
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
@@ -177,8 +200,10 @@ class GaloreArguments:
|
||||
galore_target: str = field(
|
||||
default="all",
|
||||
metadata={
|
||||
"help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
|
||||
Use "all" to specify all the linear modules."""
|
||||
"help": (
|
||||
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the linear modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
galore_rank: int = field(
|
||||
@@ -238,16 +263,20 @@ class BAdamArgument:
|
||||
badam_mask_mode: Literal["adjacent", "scatter"] = field(
|
||||
default="adjacent",
|
||||
metadata={
|
||||
"help": """The mode of the mask for BAdam optimizer. \
|
||||
`adjacent` means that the trainable parameters are adjacent to each other, \
|
||||
`scatter` means that trainable parameters are randomly choosed from the weight."""
|
||||
"help": (
|
||||
"The mode of the mask for BAdam optimizer. "
|
||||
"`adjacent` means that the trainable parameters are adjacent to each other, "
|
||||
"`scatter` means that trainable parameters are randomly choosed from the weight."
|
||||
)
|
||||
},
|
||||
)
|
||||
badam_verbose: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": """The verbosity level of BAdam optimizer. \
|
||||
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
|
||||
"help": (
|
||||
"The verbosity level of BAdam optimizer. "
|
||||
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -285,7 +314,8 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.name_module_trainable = split_arg(self.name_module_trainable)
|
||||
self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target = split_arg(self.lora_target)
|
||||
self.additional_target = split_arg(self.additional_target)
|
||||
@@ -315,17 +345,3 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
r"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
return cls(**json.loads(text))
|
||||
|
||||
Reference in New Issue
Block a user