support Qwen-7B, fix InternLM-7B inference

Former-commit-id: 25d2ca29ecb70cbfd5206333c667042a0c4d2e5a
hiyouga
2023-08-03 15:53:32 +08:00
parent da08fa7c63
commit 2e19afedb8
8 changed files with 89 additions and 25 deletions

@@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
-from llmtuner.extras.misc import dispatch_model, get_logits_processor
+from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
 from llmtuner.extras.template import get_template
 from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
@@ -16,6 +16,10 @@ class ChatModel:
         self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
+        self.stop_ids = [
+            self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
+        ]
+        self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
 
     def process_args(
         self,
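
The `get_stopwords_criteria` helper imported above and used in the next hunk is defined in `llmtuner/extras/misc.py`, whose diff is not shown in this excerpt. A minimal sketch of what such a helper can look like, built on `transformers.StoppingCriteria` and assuming the flat list of stop token ids collected in `__init__` (the class name `StopWordsCriteria` is illustrative, not taken from the repository):

```python
# Illustrative sketch only: the real implementation lives in llmtuner/extras/misc.py,
# which is one of the changed files but is not part of this hunk.
from typing import List

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopWordsCriteria(StoppingCriteria):  # hypothetical name
    r"""Stops generation as soon as the last generated token is one of the stop ids."""

    def __init__(self, stop_ids: List[int]) -> None:
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # check only the most recently generated token of the (batch size 1) sequence
        return input_ids[0][-1].item() in self.stop_ids


def get_stopwords_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
    # generate() accepts a StoppingCriteriaList via its stopping_criteria argument
    return StoppingCriteriaList([StopWordsCriteria(stop_ids)])
```

Registering the stop words as additional special tokens (previous hunk) keeps the tokenizer from splitting markers such as Qwen's `<|im_end|>` into several ids, so a single-id check like the one above is enough.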
@@ -47,7 +51,8 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            logits_processor=get_logits_processor()
+            logits_processor=get_logits_processor(),
+            stopping_criteria=get_stopwords_criteria(self.stop_ids)
         ))
 
         if max_length:
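
For context, the imports at the top of the file (`Thread`, `TextIteratorStreamer`) point at the usual Hugging Face streaming setup in which these `gen_kwargs` end up. A generic sketch of that pattern, assumed here and not quoted from this repository's stream method:

```python
# Generic transformers streaming pattern: generate() runs in a background thread
# and the streamer yields decoded text chunks until the stopping criteria
# built from self.stop_ids fires.
from threading import Thread

from transformers import TextIteratorStreamer


def stream_generate(model, tokenizer, gen_kwargs):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for new_text in streamer:
        yield new_text
```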