Compare commits

10 commits

| Author | SHA1 | Date |
|---|---|---|
| | ec334f5891 | |
| | 885efe772e | |
| | 64fc9ba678 | |
| | 989eccd286 | |
| | f0766a2ab0 | |
| | 178b85ff9a | |
| | 68dd1ef121 | |
| | b222cffe98 | |
| | b4f1ab93d1 | |
| | f2e139f5cd | |
README.md
@@ -57,7 +57,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
@@ -158,7 +158,7 @@ huggingface-cli login
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 - sentencepiece, protobuf and tiktoken
-- fire, jieba, rouge-chinese and nltk (used at evaluation and predict)
+- jieba, rouge-chinese and nltk (used at evaluation and predict)
 - gradio and matplotlib (used in web UI)
 - uvicorn, fastapi and sse-starlette (used in API)
README_zh.md
@@ -57,7 +57,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
@@ -158,7 +158,7 @@ huggingface-cli login
 - Python 3.8+ 和 PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
 - sentencepiece, protobuf 和 tiktoken
-- fire, jieba, rouge-chinese 和 nltk (用于评估及预测)
+- jieba, rouge-chinese 和 nltk (用于评估及预测)
 - gradio 和 matplotlib (用于网页端交互)
 - uvicorn, fastapi 和 sse-starlette (用于 API)
requirements.txt
@@ -3,13 +3,12 @@ transformers>=4.31.0,<4.35.0
 datasets>=2.14.0
 accelerate>=0.21.0
 peft>=0.6.0
-trl==0.7.2
+trl>=0.7.4
 gradio>=3.38.0,<4.0.0
 scipy
 sentencepiece
 protobuf
 tiktoken
-fire
 jieba
 rouge-chinese
 nltk
src/evaluate.py
@@ -1,190 +1,10 @@
-# coding=utf-8
-# Evaluates the performance of pre-trained models.
-# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
-#     --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
-# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
-
-import os
-import fire
-import json
-import torch
-import numpy as np
-import transformers
-from collections import Counter
-from datasets import load_dataset
-from dataclasses import dataclass
-from tqdm import tqdm, trange
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
-
-from llmtuner import ChatModel
-
-if TYPE_CHECKING:
-    from datasets import Dataset
-
-
-choices = ["A", "B", "C", "D"]
-
-
-@dataclass
-class EvalTemplate:
-
-    system: str
-    choice: str
-    answer: str
-    prefix: str
-
-    def parse_example(
-        self,
-        example: Dict[str, str]
-    ) -> Tuple[str, str]:
-        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example]
-        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
-
-    def format_example(
-        self,
-        target_data: Dict[str, str],
-        support_set: "Dataset",
-        subject_name: str,
-        use_history: bool
-    ) -> Tuple[str, str, List[Tuple[str, str]]]:
-        query, resp = self.parse_example(target_data)
-        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
-
-        if len(history):
-            temp = history.pop(0)
-            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
-        else:
-            query = self.system.format(subject=subject_name) + query
-
-        if not use_history:
-            query = "\n\n".join(["".join(item) for item in history] + [query])
-            history = []
-        return query.strip(), resp, history
-
-
-eval_templates = {
-    "en": EvalTemplate(
-        system="The following are multiple choice questions (with answers) about {subject}.\n\n",
-        choice="\n{choice}. {content}",
-        answer="\nAnswer: ",
-        prefix=" "
-    ),
-    "zh": EvalTemplate(
-        system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
-        choice="\n{choice}. {content}",
-        answer="\n答案:",
-        prefix="\n"
-    )
-}
-
-
-@torch.inference_mode()
-def batch_inference(
-    chat_model: ChatModel,
-    batch_input: Dict[str, torch.Tensor],
-    prefix_char: str
-) -> List[str]:
-    logits = chat_model.model(**batch_input).logits
-    lengths = torch.sum(batch_input["attention_mask"], dim=-1)
-    nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
-    probs = torch.nn.functional.softmax(
-        torch.stack(
-            [
-                nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
-                for choice in choices
-            ],
-            dim=-1
-        ),
-        dim=-1
-    ).detach()
-    return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]
-
-
-def evaluate(
-    model_name_or_path: str,
-    finetuning_type: Optional[str] = "lora",
-    checkpoint_dir: Optional[str] = None,
-    template: Optional[str] = "vanilla",
-    task: Optional[str] = "ceval",
-    dataset_dir: Optional[str] = "evaluation",
-    split: Optional[Literal["validation", "test"]] = "validation",
-    lang: Optional[Literal["zh", "en"]] = "zh",
-    n_shot: Optional[int] = 5,
-    n_avg: Optional[int] = 1,
-    batch_size: Optional[int] = 4,
-    save_name: Optional[str] = None,
-    seed: Optional[int] = 42
-):
-    with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
-        categorys: Dict[str, Dict[str, str]] = json.load(f)
-
-    transformers.set_seed(seed)
-    chat_model = ChatModel(dict(
-        model_name_or_path=model_name_or_path,
-        finetuning_type=finetuning_type,
-        checkpoint_dir=checkpoint_dir,
-        template=template
-    ))
-    chat_model.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
-    eval_template = eval_templates[lang]
-
-    category_corrects: Dict[str, np.ndarray] = {
-        subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
-    }
-    pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
-    results = {}
-    for subject in pbar:
-        dataset = load_dataset(os.path.join(dataset_dir, task), subject)
-        labels, answers, all_outputs = [], [], []
-        for epoch in range(n_avg):
-            pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
-            inputs, outputs = [], []
-            for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
-                support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
-                query, resp, history = eval_template.format_example(
-                    target_data=dataset[split][i],
-                    support_set=support_set,
-                    subject_name=categorys[subject]["name"],
-                    use_history=chat_model.template.use_history
-                )
-                input_ids, _ = chat_model.template.encode_oneturn(
-                    tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
-                )
-                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
-                if epoch == 0:
-                    labels.append(resp)
-
-            for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
-                batch_input = chat_model.tokenizer.pad(
-                    inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
-                ).to(chat_model.model.device)
-                preds = batch_inference(chat_model, batch_input, eval_template.prefix)
-                outputs += preds
-            all_outputs.append(outputs)
-
-        for i in range(len(all_outputs[0])):
-            count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
-            answers.append(count.most_common(1)[0][0])
-
-        corrects = (np.array(answers) == np.array(labels))
-        category_name = categorys[subject]["category"]
-        category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
-        category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
-        results[subject] = {str(i): answers[i] for i in range(len(answers))}
-
-    score_info = "\n".join([
-        "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
-        for category_name, category_correct in category_corrects.items() if len(category_correct)
-    ])
-
-    print(score_info)
-    if save_name is not None:
-        with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
-            json.dump(results, f, indent=2)
-
-        with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
-            f.write(score_info)
-
-
-if __name__ == "__main__":
-    fire.Fire(evaluate)
+from llmtuner import Evaluator
+
+
+def main():
+    evaluator = Evaluator()
+    evaluator.eval()
+
+
+if __name__ == "__main__":
+    main()
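Not part of the commits above — a minimal sketch of how the rewritten entry point can also be driven programmatically. The model path and task values are placeholders; passing a dict reaches HfArgumentParser.parse_dict() through the new get_eval_args() shown later in this compare.

from llmtuner import Evaluator  # exported by src/llmtuner/__init__.py in this compare

# Keys map onto ModelArguments / DataArguments / EvaluationArguments fields.
evaluator = Evaluator(args={
    "model_name_or_path": "path_to_model",  # placeholder
    "template": "vanilla",
    "task": "ceval",          # must exist as a folder under task_dir ("evaluation" by default)
    "split": "validation",
    "lang": "zh",
    "n_shot": 5,
    "batch_size": 4,
})
evaluator.eval()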
src/llmtuner/__init__.py
@@ -1,9 +1,10 @@
-# Level: api, webui > chat > tuner > dsets > extras, hparams
+# Level: api, webui > chat, eval > tuner > dsets > extras, hparams

 from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
+from llmtuner.eval import Evaluator
 from llmtuner.tuner import export_model, run_exp
 from llmtuner.webui import create_ui, create_web_demo


-__version__ = "0.2.1"
+__version__ = "0.2.2"
src/llmtuner/eval/__init__.py (new file, 1 line)
from llmtuner.eval.engine import Evaluator
src/llmtuner/eval/constants.py (new file, 3 lines)
CHOICES = ["A", "B", "C", "D"]

SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
src/llmtuner/eval/engine.py (new file, 110 lines)
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py

import os
import json
import torch
import tiktoken
import numpy as np
from tqdm import tqdm, trange
from datasets import load_dataset
from typing import Any, Dict, List, Optional

from llmtuner.eval.constants import CHOICES, SUBJECTS
from llmtuner.eval.parser import get_eval_args
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.misc import dispatch_model
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import load_model_and_tokenizer


class Evaluator:

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
        self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
        self.model = dispatch_model(self.model)
        self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
        self.eval_template = get_eval_template(self.eval_args.lang)
        self.choice_inputs = self._encode_choices()

    def _encode_choices(self) -> List[int]:
        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=False)

        return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]

    @torch.inference_mode()
    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
        logits = self.model(**batch_input).logits
        lengths = torch.sum(batch_input["attention_mask"], dim=-1)
        word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
        choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
        return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]

    def eval(self) -> None:
        mapping = os.path.join(self.eval_args.task_dir, self.eval_args.task, "mapping.json")
        with open(mapping, "r", encoding="utf-8") as f:
            categorys: Dict[str, Dict[str, str]] = json.load(f)

        category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
        pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
        results = {}
        for subject in pbar:
            dataset = load_dataset(
                path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                name=subject,
                download_mode="force_redownload"
            )
            pbar.set_postfix_str(categorys[subject]["name"])
            inputs, outputs, labels = [], [], []
            for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
                support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
                query, resp, history = self.eval_template.format_example(
                    target_data=dataset[self.data_args.split][i],
                    support_set=support_set,
                    subject_name=categorys[subject]["name"],
                    use_history=self.template.use_history
                )
                input_ids, _ = self.template.encode_oneturn(
                    tokenizer=self.tokenizer, query=query, resp=resp, history=history
                )
                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
                labels.append(resp)

            for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
                batch_input = self.tokenizer.pad(
                    inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
                ).to(self.model.device)
                preds = self.batch_inference(batch_input)
                outputs += preds

            corrects = (np.array(outputs) == np.array(labels))
            category_name = categorys[subject]["category"]
            category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
            category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
            results[subject] = {str(i): outputs[i] for i in range(len(outputs))}

        pbar.close()
        self._save_results(category_corrects, results)

    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
        score_info = "\n".join([
            "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
            for category_name, category_correct in category_corrects.items() if len(category_correct)
        ])
        print(score_info)
        if self.eval_args.save_dir is not None:
            os.makedirs(self.eval_args.save_dir, exist_ok=False)
            with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
                json.dump(results, f, indent=2)

            with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
                f.write(score_info)


if __name__ == "__main__":
    evaluator = Evaluator()
    evaluator.eval()
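A standalone illustration (not from the commits; the token ids and logits are made up) of the scoring trick used by Evaluator.batch_inference — take the next-token logits, keep only the positions of the candidate letter tokens, softmax, then argmax:

import torch

choice_inputs = [4, 5, 6, 7]  # pretend ids of " A", " B", " C", " D"
nextword_logits = torch.tensor([
    [0.1, 0.0, 0.2, 0.0, 2.0, 0.5, 0.1, 0.3],  # first prompt favors "A"
    [0.0, 0.1, 0.0, 0.2, 0.3, 0.1, 1.5, 0.2],  # second prompt favors "C"
])
choice_probs = torch.nn.functional.softmax(nextword_logits[:, choice_inputs], dim=-1)
preds = [chr(ord("A") + idx.item()) for idx in torch.argmax(choice_probs, dim=-1)]
print(preds)  # ['A', 'C']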
src/llmtuner/eval/parser.py (new file, 49 lines)
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser

from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
    ModelArguments,
    DataArguments,
    EvaluationArguments,
    FinetuningArguments
)


def parse_eval_args(
    args: Optional[Dict[str, Any]] = None
) -> Tuple[
    ModelArguments,
    DataArguments,
    EvaluationArguments,
    FinetuningArguments
]:
    parser = HfArgumentParser((
        ModelArguments,
        DataArguments,
        EvaluationArguments,
        FinetuningArguments
    ))
    return parse_args(parser, args)


def get_eval_args(
    args: Optional[Dict[str, Any]] = None
) -> Tuple[
    ModelArguments,
    DataArguments,
    EvaluationArguments,
    FinetuningArguments
]:
    model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    transformers.set_seed(eval_args.seed)

    return model_args, data_args, eval_args, finetuning_args
src/llmtuner/eval/template.py (new file, 86 lines)
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple

from llmtuner.eval.constants import CHOICES

if TYPE_CHECKING:
    from datasets import Dataset


@dataclass
class EvalTemplate:

    system: str
    choice: str
    answer: str
    prefix: str

    def parse_example(
        self,
        example: Dict[str, str]
    ) -> Tuple[str, str]:
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

    def format_example(
        self,
        target_data: Dict[str, str],
        support_set: "Dataset",
        subject_name: str,
        use_history: bool
    ) -> Tuple[str, str, List[Tuple[str, str]]]:
        query, resp = self.parse_example(target_data)
        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]

        if len(history):
            temp = history.pop(0)
            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
        else:
            query = self.system.format(subject=subject_name) + query

        if not use_history:
            query = "\n\n".join(["".join(item) for item in history] + [query])
            history = []
        return query.strip(), resp, history


eval_templates: Dict[str, EvalTemplate] = {}


def register_eval_template(
    name: str,
    system: str,
    choice: str,
    answer: str,
    prefix: str
) -> None:
    eval_templates[name] = EvalTemplate(
        system=system,
        choice=choice,
        answer=answer,
        prefix=prefix
    )


def get_eval_template(name: str) -> EvalTemplate:
    eval_template = eval_templates.get(name, None)
    assert eval_template is not None, "Template {} does not exist.".format(name)
    return eval_template


register_eval_template(
    name="en",
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer: ",
    prefix=" "
)


register_eval_template(
    name="zh",
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
    choice="\n{choice}. {content}",
    answer="\n答案:",
    prefix="\n"
)
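To see what the registered templates produce, here is a small sketch (the question below is invented; real items come from the task datasets) of the zero-shot case of parse_example/format_example using the "en" template values:

CHOICES = ["A", "B", "C", "D"]
system = "The following are multiple choice questions (with answers) about {subject}.\n\n"
choice = "\n{choice}. {content}"
answer = "\nAnswer: "

example = {"question": "Which planet is known as the Red Planet?",
           "A": "Venus", "B": "Mars", "C": "Jupiter", "D": "Mercury", "answer": "B"}

# parse_example: question + formatted choices + answer marker
candidates = [choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
query = "".join([example["question"]] + candidates + [answer])

# zero-shot (empty support set): the system line is prepended directly to the query
print(system.format(subject="astronomy") + query)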
src/llmtuner/extras/constants.py
@@ -1,9 +1,11 @@
+from collections import defaultdict, OrderedDict
+from typing import Dict, Optional
+
 IGNORE_INDEX = -100

 LOG_FILE_NAME = "trainer_log.jsonl"

-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"]
-
 METHODS = ["full", "freeze", "lora"]

 TRAINING_STAGES = {
@@ -14,79 +16,222 @@ TRAINING_STAGES = {
     "Pre-Training": "pt"
 }

-SUPPORTED_MODELS = {
-    "LLaMA-7B": "huggyllama/llama-7b",
-    "LLaMA-13B": "huggyllama/llama-13b",
-    "LLaMA-30B": "huggyllama/llama-30b",
-    "LLaMA-65B": "huggyllama/llama-65b",
-    "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
-    "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
-    "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
-    "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
-    "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
-    "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
-    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
-    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
-    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
-    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
-    "BLOOM-560M": "bigscience/bloom-560m",
-    "BLOOM-3B": "bigscience/bloom-3b",
-    "BLOOM-7B1": "bigscience/bloom-7b1",
-    "BLOOMZ-560M": "bigscience/bloomz-560m",
-    "BLOOMZ-3B": "bigscience/bloomz-3b",
-    "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
-    "Falcon-7B": "tiiuae/falcon-7b",
-    "Falcon-40B": "tiiuae/falcon-40b",
-    "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
-    "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
-    "Baichuan-7B": "baichuan-inc/Baichuan-7B",
-    "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
-    "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
-    "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
-    "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
-    "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
-    "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
-    "InternLM-7B": "internlm/internlm-7b",
-    "InternLM-20B": "internlm/internlm-20b",
-    "InternLM-7B-Chat": "internlm/internlm-chat-7b",
-    "InternLM-20B-Chat": "internlm/internlm-chat-20b",
-    "Qwen-7B": "Qwen/Qwen-7B",
-    "Qwen-14B": "Qwen/Qwen-14B",
-    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
-    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-    "XVERSE-13B": "xverse/XVERSE-13B",
-    "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
-    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
-    "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
-    "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b",
-    "Phi1.5-1.3B": "microsoft/phi-1_5"
-}
-
-DEFAULT_MODULE = {
-    "LLaMA": "q_proj,v_proj",
-    "LLaMA2": "q_proj,v_proj",
-    "ChineseLLaMA2": "q_proj,v_proj",
-    "BLOOM": "query_key_value",
-    "BLOOMZ": "query_key_value",
-    "Falcon": "query_key_value",
-    "Baichuan": "W_pack",
-    "Baichuan2": "W_pack",
-    "InternLM": "q_proj,v_proj",
-    "Qwen": "c_attn",
-    "XVERSE": "q_proj,v_proj",
-    "ChatGLM2": "query_key_value",
-    "ChatGLM3": "query_key_value",
-    "Phi1.5": "Wqkv"
-}
-
-DEFAULT_TEMPLATE = {
-    "LLaMA2": "llama2",
-    "ChineseLLaMA2": "llama2_zh",
-    "Baichuan": "baichuan",
-    "Baichuan2": "baichuan2",
-    "InternLM": "intern",
-    "Qwen": "chatml",
-    "XVERSE": "xverse",
-    "ChatGLM2": "chatglm2",
-    "ChatGLM3": "chatglm3"
-}
+LAYERNORM_NAMES = {"norm", "ln"}
+
+SUPPORTED_MODELS = OrderedDict()
+
+DEFAULT_MODULE = defaultdict(str)
+
+DEFAULT_TEMPLATE = defaultdict(str)
+
+
+def register_model_group(
+    models: Dict[str, str],
+    module: Optional[str] = None,
+    template: Optional[str] = None
+) -> None:
+    prefix = None
+    for name, path in models.items():
+        if prefix is None:
+            prefix = name.split("-")[0]
+        else:
+            assert prefix == name.split("-")[0], "prefix should be identical."
+        SUPPORTED_MODELS[name] = path
+    if module is not None:
+        DEFAULT_MODULE[prefix] = module
+    if template is not None:
+        DEFAULT_TEMPLATE[prefix] = template
+
+
+register_model_group(
+    models={
+        "Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
+        "Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
+        "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
+    },
+    module="W_pack",
+    template="baichuan"
+)
+
+
+register_model_group(
+    models={
+        "Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
+        "Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
+        "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
+        "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
+    },
+    module="W_pack",
+    template="baichuan2"
+)
+
+
+register_model_group(
+    models={
+        "BLOOM-560M": "bigscience/bloom-560m",
+        "BLOOM-3B": "bigscience/bloom-3b",
+        "BLOOM-7B1": "bigscience/bloom-7b1"
+    },
+    module="query_key_value"
+)
+
+
+register_model_group(
+    models={
+        "BLOOMZ-560M": "bigscience/bloomz-560m",
+        "BLOOMZ-3B": "bigscience/bloomz-3b",
+        "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
+    },
+    module="query_key_value"
+)
+
+
+register_model_group(
+    models={
+        "BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
+        "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
+    },
+    template="bluelm"
+)
+
+
+register_model_group(
+    models={
+        "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
+    },
+    module="query_key_value",
+    template="chatglm2"
+)
+
+
+register_model_group(
+    models={
+        "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
+        "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
+    },
+    module="query_key_value",
+    template="chatglm3"
+)
+
+
+register_model_group(
+    models={
+        "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+        "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+        "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+        "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
+    },
+    template="llama2_zh"
+)
+
+
+register_model_group(
+    models={
+        "Falcon-7B": "tiiuae/falcon-7b",
+        "Falcon-40B": "tiiuae/falcon-40b",
+        "Falcon-180B": "tiiuae/falcon-180B",
+        "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
+        "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
+        "Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
+    },
+    module="query_key_value",
+    template="falcon"
+)
+
+
+register_model_group(
+    models={
+        "InternLM-7B": "internlm/internlm-7b",
+        "InternLM-20B": "internlm/internlm-20b",
+        "InternLM-7B-Chat": "internlm/internlm-chat-7b",
+        "InternLM-20B-Chat": "internlm/internlm-chat-20b"
+    },
+    template="intern"
+)
+
+
+register_model_group(
+    models={
+        "LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
+    },
+    module="qkv_proj"
+)
+
+
+register_model_group(
+    models={
+        "LLaMA-7B": "huggyllama/llama-7b",
+        "LLaMA-13B": "huggyllama/llama-13b",
+        "LLaMA-30B": "huggyllama/llama-30b",
+        "LLaMA-65B": "huggyllama/llama-65b"
+    }
+)
+
+
+register_model_group(
+    models={
+        "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
+        "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
+        "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
+        "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
+        "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
+        "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
+    },
+    template="llama2"
+)
+
+
+register_model_group(
+    models={
+        "Mistral-7B": "mistralai/Mistral-7B-v0.1",
+        "Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
+    },
+    template="mistral"
+)
+
+
+register_model_group(
+    models={
+        "Phi1.5-1.3B": "microsoft/phi-1_5"
+    },
+    module="Wqkv"
+)
+
+
+register_model_group(
+    models={
+        "Qwen-7B": "Qwen/Qwen-7B",
+        "Qwen-14B": "Qwen/Qwen-14B",
+        "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+        "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
+    },
+    module="c_attn",
+    template="qwen"
+)
+
+
+register_model_group(
+    models={
+        "Skywork-13B-Base": "Skywork/Skywork-13B-base"
+    }
+)
+
+
+register_model_group(
+    models={
+        "XVERSE-7B": "xverse/XVERSE-7B",
+        "XVERSE-13B": "xverse/XVERSE-13B",
+        "XVERSE-65B": "xverse/XVERSE-65B",
+        "XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
+        "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
+    },
+    template="xverse"
+)
+
+
+register_model_group(
+    models={
+        "Yi-6B": "01-ai/Yi-6B",
+        "Yi-34B": "01-ai/Yi-34B"
+    }
+)
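A standalone sketch of the registry pattern that replaces the hand-maintained dicts above (one toy registration, taken from this diff), showing how the defaultdict fallbacks behave:

from collections import OrderedDict, defaultdict
from typing import Dict, Optional

SUPPORTED_MODELS = OrderedDict()
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)

def register_model_group(models: Dict[str, str], module: Optional[str] = None,
                         template: Optional[str] = None) -> None:
    prefix = None
    for name, path in models.items():
        if prefix is None:
            prefix = name.split("-")[0]
        else:
            assert prefix == name.split("-")[0], "prefix should be identical."
        SUPPORTED_MODELS[name] = path
    if module is not None:
        DEFAULT_MODULE[prefix] = module
    if template is not None:
        DEFAULT_TEMPLATE[prefix] = template

register_model_group(models={"LLaMA2-7B": "meta-llama/Llama-2-7b-hf"}, template="llama2")

print(SUPPORTED_MODELS["LLaMA2-7B"])  # meta-llama/Llama-2-7b-hf
print(DEFAULT_TEMPLATE["LLaMA2"])     # llama2
print(DEFAULT_MODULE["LLaMA2"])       # "" -- defaultdict(str) falls back to an empty string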
src/llmtuner/extras/misc.py
@@ -1,6 +1,8 @@
 import gc
+import os
+import sys
 import torch
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

 try:
@@ -17,6 +19,7 @@ except ImportError:
     _is_bf16_available = torch.cuda.is_bf16_supported()

 if TYPE_CHECKING:
+    from transformers import HfArgumentParser
     from transformers.modeling_utils import PreTrainedModel


@@ -74,7 +77,7 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
     return torch.float32


-def get_logits_processor() -> LogitsProcessorList:
+def get_logits_processor() -> "LogitsProcessorList":
     r"""
     Gets logits processor that removes NaN and Inf logits.
     """
@@ -93,6 +96,17 @@ def torch_gc() -> None:
     torch.cuda.ipc_collect()


+def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+    if args is not None:
+        return parser.parse_dict(args)
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+    else:
+        return parser.parse_args_into_dataclasses()
+
+
 def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     r"""
     Dispatches a pre-trained model to GPUs with balanced memory.
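A self-contained sketch of the dispatch order implemented by the new parse_args helper, using an invented dataclass for illustration: an explicit dict wins over a single .yaml/.json command-line argument, which wins over ordinary CLI flags.

import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser

@dataclass
class ToyArguments:  # invented for illustration
    learning_rate: float = field(default=5e-5)

def parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    if args is not None:
        return parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        return parser.parse_args_into_dataclasses()

parser = HfArgumentParser(ToyArguments)
(toy_args,) = parse_args(parser, {"learning_rate": 1e-4})
print(toy_args.learning_rate)  # 0.0001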
src/llmtuner/extras/patches/llama_patch.py
@@ -5,11 +5,14 @@ from typing import Optional, Tuple
 from transformers.utils import logging
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

+is_flash_attn_2_available = False
+
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
     from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
+    is_flash_attn_2_available = True
 except ImportError:
-    print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
+    is_flash_attn_2_available = False


 logger = logging.get_logger(__name__)
src/llmtuner/extras/template.py
@@ -447,6 +447,25 @@ register_template(
 )


+r"""
+Supports: https://huggingface.co/tiiuae/falcon-180B-chat
+"""
+register_template(
+    name="falcon",
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "User: {{query}}\nFalcon:"
+    ],
+    system="",
+    sep=[
+        "\n"
+    ],
+    efficient_eos=True
+)
+
+
 r"""
 Supports: https://huggingface.co/internlm/internlm-chat-7b
           https://huggingface.co/internlm/internlm-chat-20b
src/llmtuner/hparams/__init__.py
@@ -1,4 +1,5 @@
 from .data_args import DataArguments
+from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
src/llmtuner/hparams/data_args.py
@@ -42,7 +42,7 @@ class DataArguments:
     )
     dataset_dir: Optional[str] = field(
         default="data",
-        metadata={"help": "The name of the folder containing datasets."}
+        metadata={"help": "Path to the folder containing the datasets."}
     )
     split: Optional[str] = field(
         default="train",
src/llmtuner/hparams/evaluation_args.py (new file, 55 lines)
import os
from typing import Literal, Optional
from dataclasses import dataclass, field

from datasets import DownloadMode


@dataclass
class EvaluationArguments:
    r"""
    Arguments pertaining to specify the evaluation parameters.
    """
    task: str = field(
        metadata={"help": "Name of the evaluation task."}
    )
    task_dir: Optional[str] = field(
        default="evaluation",
        metadata={"help": "Path to the folder containing the evaluation datasets."}
    )
    batch_size: Optional[int] = field(
        default=4,
        metadata={"help": "The batch size per GPU for evaluation."}
    )
    seed: Optional[int] = field(
        default=42,
        metadata={"help": "Random seed to be used with data loaders."}
    )
    lang: Optional[Literal["en", "zh"]] = field(
        default="en",
        metadata={"help": "Language used at evaluation."}
    )
    n_shot: Optional[int] = field(
        default=5,
        metadata={"help": "Number of examplars for few-shot learning."}
    )
    save_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save the evaluation results."}
    )
    download_mode: Optional[DownloadMode] = field(
        default=DownloadMode.REUSE_DATASET_IF_EXISTS,
        metadata={"help": "Download mode used for the evaluation datasets."}
    )

    def __post_init__(self):
        task_available = []
        for folder in os.listdir(self.task_dir):
            if os.path.isdir(os.path.join(self.task_dir, folder)):
                task_available.append(folder)

        if self.task not in task_available:
            raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))

        if self.save_dir is not None and os.path.exists(self.save_dir):
            raise ValueError("`save_dir` already exists, use another one.")
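A small sketch of the validation done in __post_init__ (assumes the branch above is installed as llmtuner; the task folder is created only to satisfy the check):

import os
import tempfile
from llmtuner.hparams import EvaluationArguments  # exported by this compare

task_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(task_dir, "ceval"))  # placeholder task folder

# __post_init__ verifies that `task` is a folder under `task_dir` and that
# `save_dir` (when given) does not already exist.
eval_args = EvaluationArguments(task="ceval", task_dir=task_dir, n_shot=5)
print(eval_args.lang, eval_args.batch_size)  # en 4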
src/llmtuner/hparams/finetuning_args.py
@@ -12,7 +12,7 @@ class FinetuningArguments:
         default="sft",
         metadata={"help": "Which stage will be performed in training."}
     )
-    finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
+    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}
     )
@@ -45,7 +45,7 @@ class FinetuningArguments:
         default=None,
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+                  BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
                   Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
src/llmtuner/hparams/model_args.py
@@ -54,11 +54,11 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    reward_model: Optional[str] = field(
+    reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
-    plot_loss: Optional[bool] = field(
+    plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
     )
src/llmtuner/tuner/core/adapter.py
@@ -38,12 +38,13 @@ def init_adapter(
     if (not is_trainable) and model_args.checkpoint_dir is None:
         logger.info("Checkpoint is not found at evaluation, load the original model.")
+        return model

     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         model = model.float()

-    if finetuning_args.finetuning_type == "freeze":
+    if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = getattr(model.config, "num_layers")
         if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
src/llmtuner/tuner/core/loader.py
@@ -42,7 +42,7 @@ require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
 require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
-require_version("trl==0.7.2", "To fix: pip install trl==0.7.2")
+require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")


 def load_model_and_tokenizer(
@@ -123,9 +123,12 @@ def load_model_and_tokenizer(
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
-            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-            logger.info("Using FlashAttention-2 for faster training and inference.")
+            if LlamaPatches.is_flash_attn_2_available:
+                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
+                logger.info("Using FlashAttention-2 for faster training and inference.")
+            else:
+                logger.warning("FlashAttention-2 is not installed.")
         elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
src/llmtuner/tuner/core/parser.py
@@ -1,5 +1,4 @@
 import os
-import sys
 import torch
 import datasets
 import transformers
@@ -8,6 +7,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint

 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import parse_args
 from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
@@ -19,17 +19,6 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)


-def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
-    if args is not None:
-        return parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        return parser.parse_args_into_dataclasses()
-
-
 def parse_train_args(
     args: Optional[Dict[str, Any]] = None
 ) -> Tuple[
@@ -46,7 +35,7 @@ def parse_train_args(
         FinetuningArguments,
         GeneratingArguments
     ))
-    return _parse_args(parser, args)
+    return parse_args(parser, args)


 def parse_infer_args(
@@ -63,7 +52,7 @@ def parse_infer_args(
         FinetuningArguments,
         GeneratingArguments
     ))
-    return _parse_args(parser, args)
+    return parse_args(parser, args)


 def get_train_args(
src/llmtuner/tuner/core/utils.py
@@ -1,5 +1,5 @@
 import torch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

 from llmtuner.extras.constants import LAYERNORM_NAMES
 from llmtuner.extras.logging import get_logger
@@ -56,7 +56,7 @@ def prepare_model_for_training(
     finetuning_args: "FinetuningArguments",
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
+    layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
     r"""
     Includes:
src/llmtuner/tuner/dpo/trainer.py
@@ -1,6 +1,4 @@
 import torch
-import deepspeed # type: ignore
-from copy import deepcopy
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
@@ -11,7 +9,6 @@ from llmtuner.extras.constants import IGNORE_INDEX

 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from trl import PreTrainedModelWrapper


 class CustomDPOTrainer(DPOTrainer):
@@ -50,36 +47,6 @@ class CustomDPOTrainer(DPOTrainer):
         else:
             self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

-    def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
-        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-        if model is not None:
-            if hasattr(model, "config"):
-                hidden_size = (
-                    max(model.config.hidden_sizes)
-                    if getattr(model.config, "hidden_sizes", None)
-                    else getattr(model.config, "hidden_size", None)
-                )
-                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                    config_kwargs.update(
-                        {
-                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                        }
-                    )
-
-        # If ZeRO-3 is used, we shard both the active and reference model.
-        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-        if config_kwargs["zero_optimization"]["stage"] != 3:
-            config_kwargs["zero_optimization"]["stage"] = 0
-        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-        model.eval()
-        return model
-
     def concatenated_forward(
         self,
         model: Optional[torch.nn.Module] = None,
src/llmtuner/tuner/ppo/trainer.py
@@ -5,7 +5,7 @@ import torch
 from tqdm import tqdm
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

-from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

 from trl import PPOTrainer
@@ -108,9 +108,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()

             # Get inputs
-            queries, responses = self.get_inputs(batch)
             self.tokenizer.padding_side = "right" # change padding side
-            rewards = self.get_rewards(queries, responses, unwrapped_model)
+            queries, responses, rewards = [], [], []
+            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
+                mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
+                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
+                queries.extend(mini_batch_queries)
+                responses.extend(mini_batch_responses)
+                rewards.extend(mini_batch_rewards)

             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
@@ -165,7 +170,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )

     @torch.no_grad()
-    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
@@ -219,14 +224,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         rewards = []
         for i in range(values.size(0)):
-            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

         replace_model(unwrapped_model, target="default")
         return rewards

-    @PPODecorators.empty_cuda_cache()
+    @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
         model: "AutoModelForCausalLMWithValueHead",
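The rollout change above walks each PPO batch in mini_batch_size chunks instead of one pass; a toy standalone sketch of just that slicing pattern (numbers invented):

batch = list(range(8))                 # stands in for a BatchEncoding of 8 prompts
batch_size, mini_batch_size = 8, 4

queries, responses, rewards = [], [], []
for idx in range(0, batch_size, mini_batch_size):
    mini_batch = batch[idx: idx + mini_batch_size]
    # get_inputs() and get_rewards() would run here on the mini-batch
    queries.extend(mini_batch)
    responses.extend(q * 10 for q in mini_batch)  # placeholder responses
    rewards.extend([0.0] * len(mini_batch))       # placeholder rewards

print(len(queries), len(responses), len(rewards))  # 8 8 8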
src/llmtuner/tuner/ppo/workflow.py
@@ -42,7 +42,7 @@ def run_ppo(
         ppo_epochs=1,
         max_grad_norm=training_args.max_grad_norm,
         seed=training_args.seed,
-        optimize_cuda_cache=True,
+        optimize_device_cache=True,
         target=finetuning_args.ppo_target,
         log_with=finetuning_args.ppo_logger,
         use_score_scaling=finetuning_args.ppo_score_norm,
src/llmtuner/webui/common.py
@@ -61,13 +61,17 @@ def get_model_path(model_name: str) -> str:
     return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")


+def get_prefix(model_name: str) -> str:
+    return model_name.split("-")[0]
+
+
 def get_module(model_name: str) -> str:
-    return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
+    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")


 def get_template(model_name: str) -> str:
-    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
-        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[get_prefix(model_name)]
     return "default"
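A standalone sketch of the refactored web-UI lookups (toy tables; the real ones are the registries built in extras/constants.py above):

from collections import defaultdict

DEFAULT_MODULE = defaultdict(str, {"Qwen": "c_attn"})
DEFAULT_TEMPLATE = defaultdict(str, {"Qwen": "qwen"})

def get_prefix(model_name: str) -> str:
    return model_name.split("-")[0]

def get_module(model_name: str) -> str:
    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")

def get_template(model_name: str) -> str:
    if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
        return DEFAULT_TEMPLATE[get_prefix(model_name)]
    return "default"

print(get_module("Qwen-14B-Chat"), get_template("Qwen-14B-Chat"))  # c_attn qwen
print(get_module("LLaMA2-7B"), get_template("LLaMA2-7B"))          # q_proj,v_proj default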
src/llmtuner/webui/runner.py
@@ -136,7 +136,7 @@ class Runner:
             args["upcast_layernorm"] = True

         if args["stage"] == "ppo":
-            args["reward_model"] = get("train.reward_model")
+            args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))

         if args["stage"] == "dpo":
             args["dpo_beta"] = get("train.dpo_beta")