Former-commit-id: 49975755d47344e362145c52548fdda8783f2c0c
This commit is contained in:
hiyouga
2023-10-20 23:28:52 +08:00
parent 1cb9a38ac2
commit d602f06882
5 changed files with 44 additions and 48 deletions

View File

@@ -84,10 +84,12 @@ def batch_inference(
prefix_char: str
) -> List[str]:
logits = chat_model.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
probs = torch.nn.functional.softmax(
torch.stack(
[
logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
for choice in choices
],
dim=-1
@@ -120,8 +122,8 @@ def evaluate(
checkpoint_dir=checkpoint_dir,
template=template
))
chat_model.tokenizer.padding_side = "left" # avoid overflow issue in batched inference for llama2
eval_template = eval_templates[lang]
assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
category_corrects: Dict[str, np.ndarray] = {
subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

View File

@@ -289,8 +289,8 @@ register_template(
r"""
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
"""
register_template(
name="llama2_zh",
@@ -307,7 +307,6 @@ register_template(
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
register_template(
name="alpaca",
@@ -328,8 +327,8 @@ register_template(
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
https://huggingface.co/lmsys/vicuna-13b-v1.5
"""
register_template(
name="vicuna",
@@ -365,44 +364,9 @@ register_template(
)
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nBot: "
],
system="",
sep=[
"\n"
]
)
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system="",
sep=[
"\n"
]
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
"""
register_template(
name="ziya",
@@ -424,6 +388,8 @@ register_template(
r"""
Supports: https://huggingface.co/BAAI/AquilaChat-7B
https://huggingface.co/BAAI/AquilaChat2-7B
https://huggingface.co/BAAI/AquilaChat2-34B
"""
register_template(
name="aquila",
@@ -449,6 +415,7 @@ register_template(
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
https://huggingface.co/internlm/internlm-chat-20b
"""
register_template(
name="intern",
@@ -542,6 +509,7 @@ register_template(
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
https://huggingface.co/Qwen/Qwen-14B-Chat
"""
register_template(
name="chatml",
@@ -591,7 +559,29 @@ register_template(
r"""
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
Supports: https://huggingface.co/openchat/openchat_v3.2_super
"""
register_template(
name="openchat",
prefix=[
"{{system}}"
],
prompt=[
"GPT4 User: {{query}}",
{"token": "<|end_of_turn|>"},
"GPT4 Assistant: "
],
system="",
sep=[
{"token": "<|end_of_turn|>"}
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_template(
name="xverse",

View File

@@ -113,6 +113,8 @@ class DataArguments:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
except Exception:
if self.dataset is not None:
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
dataset_info = None
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]