refactor constants

Former-commit-id: a4d4c3fd35276f20e3b354e9d13ea971029c8775
hiyouga
2023-11-10 14:16:10 +08:00
parent 68dd1ef121
commit 178b85ff9a
6 changed files with 243 additions and 83 deletions


@@ -1,9 +1,11 @@
from collections import defaultdict, OrderedDict
from typing import Dict, Optional
IGNORE_INDEX = -100
LOG_FILE_NAME = "trainer_log.jsonl"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"]
METHODS = ["full", "freeze", "lora"]
TRAINING_STAGES = {
@@ -14,79 +16,214 @@ TRAINING_STAGES = {
"Pre-Training": "pt"
}
SUPPORTED_MODELS = {
    "LLaMA-7B": "huggyllama/llama-7b",
    "LLaMA-13B": "huggyllama/llama-13b",
    "LLaMA-30B": "huggyllama/llama-30b",
    "LLaMA-65B": "huggyllama/llama-65b",
    "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
    "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
    "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
    "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
    "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
    "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
    "BLOOM-560M": "bigscience/bloom-560m",
    "BLOOM-3B": "bigscience/bloom-3b",
    "BLOOM-7B1": "bigscience/bloom-7b1",
    "BLOOMZ-560M": "bigscience/bloomz-560m",
    "BLOOMZ-3B": "bigscience/bloomz-3b",
    "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
    "Falcon-7B": "tiiuae/falcon-7b",
    "Falcon-40B": "tiiuae/falcon-40b",
    "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
    "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
    "Baichuan-7B": "baichuan-inc/Baichuan-7B",
    "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
    "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
    "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
    "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
    "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
    "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
    "InternLM-7B": "internlm/internlm-7b",
    "InternLM-20B": "internlm/internlm-20b",
    "InternLM-7B-Chat": "internlm/internlm-chat-7b",
    "InternLM-20B-Chat": "internlm/internlm-chat-20b",
    "Qwen-7B": "Qwen/Qwen-7B",
    "Qwen-14B": "Qwen/Qwen-14B",
    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
    "XVERSE-13B": "xverse/XVERSE-13B",
    "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
    "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
    "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b",
    "Phi1.5-1.3B": "microsoft/phi-1_5"
}

LAYERNORM_NAMES = {"norm", "ln"}
DEFAULT_MODULE = {
    "LLaMA": "q_proj,v_proj",
    "LLaMA2": "q_proj,v_proj",
    "ChineseLLaMA2": "q_proj,v_proj",
    "BLOOM": "query_key_value",
    "BLOOMZ": "query_key_value",
    "Falcon": "query_key_value",
    "Baichuan": "W_pack",
    "Baichuan2": "W_pack",
    "InternLM": "q_proj,v_proj",
    "Qwen": "c_attn",
    "XVERSE": "q_proj,v_proj",
    "ChatGLM2": "query_key_value",
    "ChatGLM3": "query_key_value",
    "Phi1.5": "Wqkv"
}

SUPPORTED_MODELS = OrderedDict()
DEFAULT_TEMPLATE = {
    "LLaMA2": "llama2",
    "ChineseLLaMA2": "llama2_zh",
    "Baichuan": "baichuan",
    "Baichuan2": "baichuan2",
    "InternLM": "intern",
    "Qwen": "chatml",
    "XVERSE": "xverse",
    "ChatGLM2": "chatglm2",
    "ChatGLM3": "chatglm3"
}

DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)

def register_model_group(
    models: Dict[str, str],
    module: Optional[str] = None,
    template: Optional[str] = None
) -> None:
    prefix = None
    for name, path in models.items():
        if prefix is None:
            prefix = name.split("-")[0]
        else:
            assert prefix == name.split("-")[0], "prefix should be identical."
        SUPPORTED_MODELS[name] = path
    if module is not None:
        DEFAULT_MODULE[prefix] = module
    if template is not None:
        DEFAULT_TEMPLATE[prefix] = template

register_model_group(
    models={
        "Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
        "Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
        "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
    },
    module="W_pack",
    template="baichuan"
)

register_model_group(
    models={
        "Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
        "Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
        "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
        "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
    },
    module="W_pack",
    template="baichuan2"
)

register_model_group(
    models={
        "BLOOM-560M": "bigscience/bloom-560m",
        "BLOOM-3B": "bigscience/bloom-3b",
        "BLOOM-7B1": "bigscience/bloom-7b1"
    },
    module="query_key_value"
)

register_model_group(
    models={
        "BLOOMZ-560M": "bigscience/bloomz-560m",
        "BLOOMZ-3B": "bigscience/bloomz-3b",
        "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
    },
    module="query_key_value"
)

register_model_group(
    models={
        "BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
        "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
    },
    template="bluelm"
)

register_model_group(
    models={
        "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
    },
    module="query_key_value",
    template="chatglm2"
)

register_model_group(
    models={
        "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
        "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
    },
    module="query_key_value",
    template="chatglm3"
)

register_model_group(
    models={
        "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
        "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
        "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
        "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
    },
    template="llama2_zh"
)

register_model_group(
    models={
        "Falcon-7B": "tiiuae/falcon-7b",
        "Falcon-40B": "tiiuae/falcon-40b",
        "Falcon-180B": "tiiuae/falcon-180B",
        "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
        "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
        "Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
    },
    module="query_key_value",
    template="falcon"
)

register_model_group(
    models={
        "InternLM-7B": "internlm/internlm-7b",
        "InternLM-20B": "internlm/internlm-20b",
        "InternLM-7B-Chat": "internlm/internlm-chat-7b",
        "InternLM-20B-Chat": "internlm/internlm-chat-20b"
    },
    template="intern"
)

register_model_group(
    models={
        "LLaMA-7B": "huggyllama/llama-7b",
        "LLaMA-13B": "huggyllama/llama-13b",
        "LLaMA-30B": "huggyllama/llama-30b",
        "LLaMA-65B": "huggyllama/llama-65b"
    }
)

register_model_group(
    models={
        "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
        "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
        "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
        "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
        "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
        "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
    },
    template="llama2"
)

register_model_group(
    models={
        "Mistral-7B": "mistralai/Mistral-7B-v0.1",
        "Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
    },
    template="mistral"
)

register_model_group(
    models={
        "Phi1.5-1.3B": "microsoft/phi-1_5"
    },
    module="Wqkv"
)

register_model_group(
    models={
        "Qwen-7B": "Qwen/Qwen-7B",
        "Qwen-14B": "Qwen/Qwen-14B",
        "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
        "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
    },
    module="c_attn",
    template="qwen"
)

register_model_group(
    models={
        "Skywork-13B-Base": "Skywork/Skywork-13B-base"
    }
)

register_model_group(
    models={
        "XVERSE-7B": "xverse/XVERSE-7B",
        "XVERSE-13B": "xverse/XVERSE-13B",
        "XVERSE-65B": "xverse/XVERSE-65B",
        "XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
        "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
    },
    template="xverse"
)

register_model_group(
    models={
        "Yi-6B": "01-ai/Yi-6B",
        "Yi-34B": "01-ai/Yi-34B"
    }
)
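
Below is a minimal consumer-side sketch (not part of the commit); get_model_info is a hypothetical helper, and it uses only the structures populated above: SUPPORTED_MODELS maps a display name to a hub path, while DEFAULT_MODULE and DEFAULT_TEMPLATE are keyed by the name prefix that register_model_group derives via name.split("-")[0].

def get_model_info(name: str):
    # Hypothetical helper for illustration: resolve the hub path plus the
    # default LoRA target module and chat template for a registered model.
    path = SUPPORTED_MODELS[name]            # e.g. "Qwen/Qwen-7B-Chat"
    prefix = name.split("-")[0]              # same prefix rule as register_model_group
    module = DEFAULT_MODULE[prefix]          # defaultdict(str) -> "" if unregistered
    template = DEFAULT_TEMPLATE[prefix]      # defaultdict(str) -> "" if unregistered
    return path, module, template

print(get_model_info("Qwen-7B-Chat"))        # ("Qwen/Qwen-7B-Chat", "c_attn", "qwen")

Because DEFAULT_MODULE and DEFAULT_TEMPLATE are defaultdict(str), prefixes registered without a module or template (for example the LLaMA and Skywork groups) resolve to an empty string instead of raising KeyError.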