support Qwen-7B, fix InternLM-7B inference

Former-commit-id: 25d2ca29ecb70cbfd5206333c667042a0c4d2e5a
hiyouga
2023-08-03 15:53:32 +08:00
parent da08fa7c63
commit 2e19afedb8
8 changed files with 89 additions and 25 deletions

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
from llmtuner.extras.template import get_template
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
@@ -16,6 +16,10 @@ class ChatModel:
self.model = dispatch_model(self.model)
self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix
self.stop_ids = [
self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
]
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
def process_args(
self,
@@ -47,7 +51,8 @@ class ChatModel:
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
logits_processor=get_logits_processor(),
stopping_criteria=get_stopwords_criteria(self.stop_ids)
))
if max_length:
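
Illustrative sketch, not part of the diff above: deriving stop_ids from a template's stop words the same way ChatModel now does, using InternLM's "<eoa>" as the example; whether "<eoa>" encodes to a single token is an assumption about that tokenizer.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
stop_words = ["<eoa>"]  # stop word of the "intern" template registered later in this commit
stop_ids = [tokenizer.encode(word, add_special_tokens=False)[0] for word in stop_words]
tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))
# generate() is interrupted once any of these ids appears as the newest token.
print(stop_ids)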

View File

@@ -1,8 +1,7 @@
import torch
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
from llmtuner.extras.constants import LAYERNORM_NAMES
@@ -46,6 +45,22 @@ def get_logits_processor() -> LogitsProcessorList:
return logits_processor
class StopWordsCriteria(StoppingCriteria):
def __init__(self, stop_ids: List[int]) -> None:
super().__init__()
self.stop_ids = stop_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
def get_stopwords_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
stopwords_criteria = StoppingCriteriaList()
stopwords_criteria.append(StopWordsCriteria(stop_ids))
return stopwords_criteria
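
A minimal usage sketch, not part of the diff: wiring the new helper into generate(); the gpt2 checkpoint and the newline stop word are arbitrary placeholders.

from transformers import AutoModelForCausalLM, AutoTokenizer
from llmtuner.extras.misc import get_stopwords_criteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_ids = [tokenizer.encode("\n", add_special_tokens=False)[0]]
inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
# Decoding halts as soon as the newest token matches one of the stop ids.
outputs = model.generate(**inputs, max_new_tokens=64, stopping_criteria=get_stopwords_criteria(stop_ids))
print(tokenizer.decode(outputs[0], skip_special_tokens=True))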
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.

View File

@@ -9,6 +9,7 @@ class Template:
prompt: str
sep: str
use_history: bool
stop_words: List[str]
def get_prompt(
self,
@@ -74,13 +75,16 @@ class Llama2Template(Template):
templates: Dict[str, Template] = {}
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
def register_template(
name: str, prefix: str, prompt: str, sep: str, use_history: bool, stop_words: List[str]
) -> None:
template_class = Llama2Template if name == "llama2" else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
sep=sep,
use_history=use_history
use_history=use_history,
stop_words=stop_words
)
@@ -98,7 +102,8 @@ register_template(
prefix="",
prompt="{query}",
sep="",
use_history=False
use_history=False,
stop_words=[]
)
@@ -111,7 +116,8 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
use_history=True,
stop_words=[]
)
@@ -132,7 +138,8 @@ register_template(
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt="[INST] {query} [/INST] ",
sep="<s>",
use_history=True
use_history=True,
stop_words=[]
)
@@ -146,7 +153,8 @@ register_template(
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
use_history=True,
stop_words=[]
)
@@ -160,7 +168,8 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="",
use_history=True
use_history=True,
stop_words=[]
)
@@ -172,7 +181,8 @@ register_template(
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
use_history=True,
stop_words=[]
)
@@ -184,7 +194,8 @@ register_template(
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
use_history=True,
stop_words=[]
)
@@ -196,7 +207,8 @@ register_template(
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
use_history=True,
stop_words=[]
)
@@ -208,7 +220,8 @@ register_template(
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
use_history=True,
stop_words=[]
)
@@ -221,7 +234,8 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
use_history=True,
stop_words=[]
)
@@ -233,7 +247,8 @@ register_template(
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
use_history=True,
stop_words=["<eoa>"]
)
@@ -245,7 +260,8 @@ register_template(
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
use_history=True
use_history=True,
stop_words=[]
)
@@ -258,5 +274,19 @@ register_template(
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True
use_history=True,
stop_words=["<|end|>"]
)
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
"""
register_template(
name="chatml",
prefix="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
sep="<|im_end|>\n",
use_history=True,
stop_words=["<|im_end|>"]
)
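
For reference, not part of the diff: a single-turn Qwen prompt assembled from the chatml fields registered above (prefix followed by prompt, no history). The real assembly goes through Template.get_prompt, so treat this as an approximation; generation then ends at the "<|im_end|>" stop word.

prefix = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
prompt = "<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
print(prefix + prompt.format(query="What is the capital of France?"))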

View File

@@ -19,7 +19,8 @@ class FinetuningArguments:
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \
Falcon choices: [\"32\", \"60\"], \
Baichuan choices: [\"32\", \"40\"]"}
Baichuan choices: [\"32\", \"40\"] \
Qwen choices: [\"32\"]"}
)
num_layer_trainable: Optional[int] = field(
default=3,
@@ -30,7 +31,8 @@ class FinetuningArguments:
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
Baichuan choices: [\"mlp\", \"self_attn\"], \
Qwen choices: [\"attn\", \"mlp\"]"}
)
lora_rank: Optional[int] = field(
default=8,
@@ -47,9 +49,10 @@ class FinetuningArguments:
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
LLaMA & LLaMA-2 & InternLM choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"c_proj\", \"w1\", \"w2\"]"}
)
def __post_init__(self):
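
Sketch, not part of the diff: selecting Qwen's LoRA modules via the fields documented above; the import path is an assumption.

from llmtuner.hparams import FinetuningArguments  # module path assumed

finetuning_args = FinetuningArguments(
    lora_rank=8,                       # default shown above
    lora_target="c_attn,c_proj,w1,w2"  # Qwen choices from the updated help text
)
print(finetuning_args.lora_target)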

View File

@@ -67,7 +67,7 @@ def load_model_and_tokenizer(
**config_kwargs
)
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
tokenizer.pad_token_id = 0 # set as the <unk> token
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
is_mergeable = True
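
Illustrative sketch, not part of the diff: the effect of the pad_token change for a tokenizer without a usable pad token; the checkpoint name is a placeholder and whether it lacks a pad token is an assumption.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000:  # 64000: older Baichuan
    tokenizer.pad_token = tokenizer.eos_token  # reuse the eos token instead of hard-coding id 0 (<unk>)
print(tokenizer.pad_token, tokenizer.pad_token_id)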