lora modules: all by default
Former-commit-id: 52c4ae87c7f4312704c31ef26b079b2c5b95ea5f
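Note: this commit drops the per-model-family DEFAULT_MODULE table and makes `all` the default lora_target everywhere, so every model targets all of its linear modules unless the user lists specific ones. Below is a minimal sketch of how `all` can be expanded at load time, assuming a PyTorch model; the function name and the lm_head filter are illustrative, not taken from this diff:

    # Sketch: resolve lora_target == "all" by scanning the loaded model for
    # linear layers instead of consulting a hand-maintained per-family table.
    import torch

    def find_all_linear_modules(model: torch.nn.Module) -> list:
        names = set()
        for full_name, module in model.named_modules():
            # skip the output head so LoRA targets the backbone only
            if isinstance(module, torch.nn.Linear) and "lm_head" not in full_name:
                names.add(full_name.split(".")[-1])  # e.g. "q_proj", "W_pack"
        return sorted(names)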
@@ -20,8 +20,6 @@ CHOICES = ["A", "B", "C", "D"]
 
 DATA_CONFIG = "dataset_info.json"
 
-DEFAULT_MODULE = defaultdict(str)
-
 DEFAULT_TEMPLATE = defaultdict(str)
 
 FILEEXT2TYPE = {
@@ -80,7 +78,6 @@ class DownloadSource(str, Enum):
 
 def register_model_group(
     models: Dict[str, Dict[DownloadSource, str]],
-    module: Optional[str] = None,
     template: Optional[str] = None,
     vision: bool = False,
 ) -> None:
@@ -91,8 +88,6 @@ def register_model_group(
         else:
             assert prefix == name.split("-")[0], "prefix should be identical."
         SUPPORTED_MODELS[name] = path
-    if module is not None:
-        DEFAULT_MODULE[prefix] = module
     if template is not None:
         DEFAULT_TEMPLATE[prefix] = template
     if vision:
@@ -127,7 +122,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
         },
     },
-    module="W_pack",
     template="baichuan",
 )
 
@@ -151,7 +145,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
         },
     },
-    module="W_pack",
     template="baichuan2",
 )
 
@@ -171,7 +164,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
         },
     },
-    module="query_key_value",
 )
 
 
@@ -190,7 +182,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
         },
     },
-    module="query_key_value",
 )
 
 
@@ -229,7 +220,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
         }
     },
-    module="query_key_value",
     template="chatglm2",
 )
 
@@ -245,7 +235,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
         },
     },
-    module="query_key_value",
     template="chatglm3",
 )
 
@@ -344,7 +333,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
         },
     },
-    module="Wqkv",
     template="dbrx",
 )
 
@@ -463,7 +451,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
         },
     },
-    module="query_key_value",
     template="falcon",
 )
 
@@ -512,7 +499,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
         },
     },
-    module="query_key_value",
     template="glm4",
 )
 
@@ -559,7 +545,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
         },
     },
-    module="wqkv",
     template="intern2",
 )
 
@@ -581,7 +566,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
         }
     },
-    module="qkv_proj",
 )
 
 
@@ -868,7 +852,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
         },
     },
-    module="qkv_proj",
     template="phi",
 )
 
@@ -940,7 +923,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
         },
     },
-    module="c_attn",
     template="qwen",
 )
 
@@ -1153,7 +1135,6 @@ register_model_group(
             DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
         },
     },
-    module="query,key_value",
     template="telechat",
 )
 
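Note: every register_model_group(...) call above loses its module= argument in the same way; templates and download sources are untouched. Reassembled from the Baichuan2 hunk above, a post-commit call reads:

    register_model_group(
        models={
            "Baichuan2-13B-Chat": {
                DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
                DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
            },
        },
        template="baichuan2",  # module="W_pack" is no longer passed
    )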
@@ -24,12 +24,7 @@ class FreezeArguments:
             "help": (
                 "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
                 "Use commas to separate multiple modules. "
-                "Use `all` to specify all the available modules. "
-                "LLaMA choices: [`mlp`, `self_attn`], "
-                "BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], "
-                "Qwen choices: [`mlp`, `attn`], "
-                "InternLM2 choices: [`feed_forward`, `attention`], "
-                "Others choices: the same as LLaMA."
+                "Use `all` to specify all the available modules."
             )
         },
     )
@@ -79,13 +74,7 @@ class LoraArguments:
             "help": (
                 "Name(s) of target modules to apply LoRA. "
                 "Use commas to separate multiple modules. "
-                "Use `all` to specify all the linear modules. "
-                "LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
-                "BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], "
-                "Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
-                "Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], "
-                "InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], "
-                "Others choices: the same as LLaMA."
+                "Use `all` to specify all the linear modules."
             )
         },
     )
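Note: both help strings shed their per-architecture module lists, since `all` now covers every architecture uniformly. Explicit targets are still accepted as a comma-separated string; a sketch of the usual split, not the repository's exact parsing code:

    def split_arg(arg: str) -> list:
        # "q_proj,v_proj" -> ["q_proj", "v_proj"]; "all" stays a single token
        return [item.strip() for item in arg.split(",")]

    assert split_arg("q_proj, v_proj") == ["q_proj", "v_proj"]
    assert split_arg("all") == ["all"]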
@@ -8,7 +8,6 @@ from yaml import safe_dump, safe_load
 from ..extras.constants import (
     CHECKPOINT_NAMES,
     DATA_CONFIG,
-    DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
     PEFT_METHODS,
     STAGES_USE_PAIR_DATA,
@@ -118,13 +117,6 @@ def get_model_info(model_name: str) -> Tuple[str, str, bool]:
     return get_model_path(model_name), get_template(model_name), get_visual(model_name)
 
 
-def get_module(model_name: str) -> str:
-    r"""
-    Gets the LoRA modules of this model.
-    """
-    return DEFAULT_MODULE.get(get_prefix(model_name), "all")
-
-
 def get_template(model_name: str) -> str:
     r"""
     Gets the template name if the model is a chat model.
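Note: get_module already fell back to "all" for prefixes missing from DEFAULT_MODULE, so deleting it only changes behavior for families that had an explicit entry. A self-contained sketch of the equivalence; the table excerpt is reconstructed from the hunks above:

    DEFAULT_MODULE = {"Baichuan": "W_pack", "Qwen": "c_attn"}  # pre-commit excerpt

    def get_module_old(model_name: str) -> str:
        prefix = model_name.split("-")[0]
        return DEFAULT_MODULE.get(prefix, "all")

    assert get_module_old("Llama-3-8B-Instruct") == "all"  # unchanged by this commit
    assert get_module_old("Qwen-7B-Chat") == "c_attn"      # now resolves to "all" instead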
@@ -8,7 +8,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
 from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc
 from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, get_save_dir, load_config
 from .locales import ALERTS
 from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
@@ -159,7 +159,7 @@ class Runner:
         args["create_new_adapter"] = get("train.create_new_adapter")
         args["use_rslora"] = get("train.use_rslora")
         args["use_dora"] = get("train.use_dora")
-        args["lora_target"] = get("train.lora_target") or get_module(model_name)
+        args["lora_target"] = get("train.lora_target") or "all"
         args["additional_target"] = get("train.additional_target") or None
 
         if args["use_llama_pro"]:
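Note: in the web UI runner the empty-field fallback becomes the literal string rather than a per-model lookup. Since get(...) returns the textbox value and "" is falsy, Python's `or` does the work; a sketch of the semantics:

    def resolve_lora_target(ui_value: str) -> str:
        return ui_value or "all"  # an untouched textbox falls back to "all"

    assert resolve_lora_target("") == "all"
    assert resolve_lora_target("q_proj,v_proj") == "q_proj,v_proj"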