format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
hiyouga
2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,10 +1,11 @@
import os
import json
import os
import time
from typing import TYPE_CHECKING
from datetime import timedelta
from typing import TYPE_CHECKING
from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME
from .logging import get_logger
@@ -12,14 +13,13 @@ from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
@@ -28,12 +28,11 @@ class FixValueHeadModelCallback(TrainerCallback):
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
@@ -99,7 +98,9 @@ class LogCallback(TrainerCallback):
self.cur_steps = 0
self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
):
r"""
Event called after a successful prediction.
"""
@@ -125,18 +126,22 @@ class LogCallback(TrainerCallback):
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
remaining_time=self.remaining_time,
)
if self.runner is not None:
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
))
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
)
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict, defaultdict
from enum import Enum
from collections import defaultdict, OrderedDict
from typing import Dict, Optional
@@ -11,14 +11,7 @@ DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
FILEEXT2TYPE = {
"arrow": "arrow",
"csv": "csv",
"json": "json",
"jsonl": "json",
"parquet": "parquet",
"txt": "text"
}
FILEEXT2TYPE = {"arrow": "arrow", "csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}
IGNORE_INDEX = -100
@@ -39,22 +32,21 @@ TRAINING_STAGES = {
"Reward Modeling": "rm",
"PPO": "ppo",
"DPO": "dpo",
"Pre-Training": "pt"
"Pre-Training": "pt",
}
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
class DownloadSource(str, Enum):
DEFAULT = "hf"
MODELSCOPE = "ms"
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
module: Optional[str] = None,
template: Optional[str] = None
models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None
) -> None:
prefix = None
for name, path in models.items():
@@ -73,19 +65,19 @@ register_model_group(
models={
"Baichuan-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
},
"Baichuan-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
},
"Baichuan-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
}
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
module="W_pack",
template="baichuan"
template="baichuan",
)
@@ -93,23 +85,23 @@ register_model_group(
models={
"Baichuan2-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
},
"Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
},
"Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
},
"Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
}
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
},
},
module="W_pack",
template="baichuan2"
template="baichuan2",
)
@@ -117,18 +109,18 @@ register_model_group(
models={
"BLOOM-560M": {
DownloadSource.DEFAULT: "bigscience/bloom-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
},
"BLOOM-3B": {
DownloadSource.DEFAULT: "bigscience/bloom-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
},
"BLOOM-7B1": {
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
},
},
module="query_key_value"
module="query_key_value",
)
@@ -136,18 +128,18 @@ register_model_group(
models={
"BLOOMZ-560M": {
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
},
"BLOOMZ-3B": {
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
},
"BLOOMZ-7B1-mt": {
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
},
},
module="query_key_value"
module="query_key_value",
)
@@ -155,14 +147,14 @@ register_model_group(
models={
"BlueLM-7B-Base": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
},
"BlueLM-7B-Chat": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
}
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
},
},
template="bluelm"
template="bluelm",
)
@@ -170,11 +162,11 @@ register_model_group(
models={
"ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
module="query_key_value",
template="chatglm2"
template="chatglm2",
)
@@ -182,15 +174,15 @@ register_model_group(
models={
"ChatGLM3-6B-Base": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
},
"ChatGLM3-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
}
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
},
},
module="query_key_value",
template="chatglm3"
template="chatglm3",
)
@@ -198,30 +190,30 @@ register_model_group(
models={
"ChineseLLaMA2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
},
"ChineseLLaMA2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
},
"ChineseLLaMA2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
},
"ChineseLLaMA2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
},
"ChineseLLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
},
"ChineseLLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
},
},
template="llama2_zh"
template="llama2_zh",
)
@@ -229,22 +221,22 @@ register_model_group(
models={
"DeepSeekLLM-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
},
"DeepSeekLLM-67B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
},
"DeepSeekLLM-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
},
"DeepSeekLLM-67B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
}
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
},
},
template="deepseek"
template="deepseek",
)
@@ -252,22 +244,22 @@ register_model_group(
models={
"DeepSeekCoder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
},
"DeepSeekCoder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
},
"DeepSeekCoder-6.7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
},
"DeepSeekCoder-33B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
}
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
},
},
template="deepseekcoder"
template="deepseekcoder",
)
@@ -275,14 +267,14 @@ register_model_group(
models={
"DeepSeekMoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeekMoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat"
}
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
},
template="deepseek"
template="deepseek",
)
@@ -290,31 +282,31 @@ register_model_group(
models={
"Falcon-7B": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
},
"Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
},
"Falcon-180B": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
},
"Falcon-7B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
},
"Falcon-40B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
},
"Falcon-180B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
}
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
},
},
module="query_key_value",
template="falcon"
template="falcon",
)
@@ -322,22 +314,22 @@ register_model_group(
models={
"InternLM-7B": {
DownloadSource.DEFAULT: "internlm/internlm-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
},
"InternLM-20B": {
DownloadSource.DEFAULT: "internlm/internlm-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
},
"InternLM-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
},
"InternLM-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
}
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
},
},
template="intern"
template="intern",
)
@@ -345,23 +337,23 @@ register_model_group(
models={
"InternLM2-7B": {
DownloadSource.DEFAULT: "internlm/internlm2-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
},
"InternLM2-20B": {
DownloadSource.DEFAULT: "internlm/internlm2-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
},
"InternLM2-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
},
"InternLM2-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b"
}
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
},
},
module="wqkv",
template="intern2"
template="intern2",
)
@@ -369,31 +361,28 @@ register_model_group(
models={
"LingoWhale-8B": {
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
},
module="qkv_proj"
module="qkv_proj",
)
register_model_group(
models={
"LLaMA-7B": {
DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
},
"LLaMA-7B": {DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b"},
"LLaMA-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
},
"LLaMA-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
},
"LLaMA-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
}
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
},
}
)
@@ -402,30 +391,30 @@ register_model_group(
models={
"LLaMA2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
},
"LLaMA2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
},
"LLaMA2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
},
"LLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
},
"LLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
},
"LLaMA2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
}
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
},
},
template="llama2"
template="llama2",
)
@@ -433,18 +422,18 @@ register_model_group(
models={
"Mistral-7B": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
},
template="mistral"
template="mistral",
)
@@ -452,14 +441,14 @@ register_model_group(
models={
"Mixtral-8x7B": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1"
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
},
template="mistral"
template="mistral",
)
@@ -467,110 +456,87 @@ register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat_3.5",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
}
},
template="openchat"
template="openchat",
)
register_model_group(
models={
"Phi-1.5-1.3B": {
DownloadSource.DEFAULT: "microsoft/phi-1_5",
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"
}
"Phi-1.5-1.3B": {DownloadSource.DEFAULT: "microsoft/phi-1_5", DownloadSource.MODELSCOPE: "allspace/PHI_1-5"},
"Phi-2-2.7B": {DownloadSource.DEFAULT: "microsoft/phi-2", DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"},
}
)
register_model_group(
models={
"Qwen-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
},
"Qwen-1.8B": {DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"},
"Qwen-7B": {DownloadSource.DEFAULT: "Qwen/Qwen-7B", DownloadSource.MODELSCOPE: "qwen/Qwen-7B"},
"Qwen-14B": {DownloadSource.DEFAULT: "Qwen/Qwen-14B", DownloadSource.MODELSCOPE: "qwen/Qwen-14B"},
"Qwen-72B": {DownloadSource.DEFAULT: "Qwen/Qwen-72B", DownloadSource.MODELSCOPE: "qwen/Qwen-72B"},
"Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
},
"Qwen-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
"Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
},
"Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
}
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
},
},
module="c_attn",
template="qwen"
template="qwen",
)
register_model_group(
models={
"SOLAR-10.7B": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"
},
"SOLAR-10.7B": {DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"},
"SOLAR-10.7B-Chat": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar"
template="solar",
)
@@ -578,7 +544,7 @@ register_model_group(
models={
"Skywork-13B-Base": {
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
}
}
)
@@ -588,68 +554,51 @@ register_model_group(
models={
"Vicuna1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
},
"Vicuna1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
}
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
},
},
template="vicuna"
template="vicuna",
)
register_model_group(
models={
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
}
"XuanYuan-70B": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"},
"XuanYuan-70B-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"},
"XuanYuan-70B-int8-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"},
"XuanYuan-70B-int4-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"},
},
template="xuanyuan"
template="xuanyuan",
)
register_model_group(
models={
"XVERSE-7B": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
},
"XVERSE-7B": {DownloadSource.DEFAULT: "xverse/XVERSE-7B", DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"},
"XVERSE-13B": {DownloadSource.DEFAULT: "xverse/XVERSE-13B", DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"},
"XVERSE-65B": {DownloadSource.DEFAULT: "xverse/XVERSE-65B", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"},
"XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2"
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
},
"XVERSE-7B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
},
"XVERSE-13B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
},
"XVERSE-65B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat"
}
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
},
template="xverse"
template="xverse",
)
@@ -657,45 +606,33 @@ register_model_group(
models={
"Yayi-7B": {
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
},
"Yayi-13B": {
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
},
},
template="yayi"
template="yayi",
)
register_model_group(
models={
"Yi-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-6B"
},
"Yi-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-34B"
},
"Yi-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"
},
"Yi-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
},
"Yi-6B": {DownloadSource.DEFAULT: "01-ai/Yi-6B", DownloadSource.MODELSCOPE: "01ai/Yi-6B"},
"Yi-34B": {DownloadSource.DEFAULT: "01-ai/Yi-34B", DownloadSource.MODELSCOPE: "01ai/Yi-34B"},
"Yi-6B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"},
"Yi-34B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"},
"Yi-6B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits"
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
},
"Yi-34B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
}
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
},
},
template="yi"
template="yi",
)
@@ -703,18 +640,18 @@ register_model_group(
models={
"Yuan2-2B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf"
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
},
"Yuan2-51B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf"
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
},
"Yuan2-102B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf"
}
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
},
},
template="yuan"
template="yuan",
)
@@ -722,12 +659,12 @@ register_model_group(
models={
"Zephyr-7B-Alpha-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
},
"Zephyr-7B-Beta-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
}
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
},
template="zephyr"
template="zephyr",
)

View File

@@ -1,5 +1,5 @@
import sys
import logging
import sys
class LoggerHandler(logging.Handler):
@@ -27,8 +27,7 @@ def get_logger(name: str) -> logging.Logger:
Gets a standard logger with a stream hander to stdout.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)

View File

@@ -1,31 +1,33 @@
import gc
import os
import torch
from typing import TYPE_CHECKING, Dict, Tuple
import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import (
WEIGHTS_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available,
is_torch_xpu_available
is_torch_xpu_available,
)
from peft import PeftModel
from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available()
except:
except Exception:
_is_bf16_available = False
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments
@@ -36,6 +38,7 @@ class AverageMeter:
r"""
Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
@@ -75,9 +78,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead",
output_dir: str,
safe_serialization: bool
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
The model is already unwrapped.
@@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
if safe_serialization:
from safetensors import safe_open
from safetensors.torch import save_file
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
@@ -112,9 +114,7 @@ def fix_valuehead_checkpoint(
os.remove(path_to_checkpoint)
model.pretrained_model.save_pretrained(
output_dir,
state_dict=decoder_state_dict or None,
safe_serialization=safe_serialization
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
)
if safe_serialization:
@@ -182,11 +182,10 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
try:
from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path,
revision=revision,
cache_dir=model_args.cache_dir
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
)
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")

View File

@@ -9,7 +9,7 @@ def is_package_available(name: str) -> bool:
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
except Exception:
return "0.0.0"

View File

@@ -1,11 +1,16 @@
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import (
Cache, LlamaAttention, LlamaFlashAttention2, apply_rotary_pos_emb, repeat_kv
Cache,
LlamaAttention,
LlamaFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)
@@ -19,7 +24,7 @@ def llama_torch_attn_forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
**kwargs
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -45,15 +50,17 @@ def llama_torch_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@@ -68,14 +75,17 @@ def llama_torch_attn_forward(
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
@@ -94,7 +104,7 @@ def llama_flash_attn_forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
**kwargs
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
@@ -124,9 +134,9 @@ def llama_flash_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
dropout_rate = self.attention_dropout if self.training else 0.0
@@ -144,14 +154,16 @@ def llama_flash_attn_forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@@ -162,11 +174,14 @@ def llama_flash_attn_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)

View File

@@ -1,12 +1,14 @@
import os
import math
import json
import math
import os
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from .logging import get_logger
from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
@@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
"""
last = scalars[0]
smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)
@@ -29,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)