tiny fix
Former-commit-id: 81ef7017a4c96441951adeff0276cc5ab76a3544
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import PreTrainedModel, TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||
|
||||
@@ -19,7 +19,7 @@ class ChatModel:
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
||||
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model
|
||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
@@ -52,7 +52,7 @@ class ChatModel:
|
||||
top_k=top_k or gen_kwargs["top_k"],
|
||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||
logits_processor=get_logits_processor(),
|
||||
stopping_criteria=get_stopwords_criteria(self.stop_ids)
|
||||
stopping_criteria=get_stopping_criteria(self.stop_ids)
|
||||
))
|
||||
|
||||
if max_length:
|
||||
|
||||
Reference in New Issue
Block a user