fix tokenizer #417

Former-commit-id: 01aa678311bfd213a4b410a4e0ff09f48a0d40a1
Author: hiyouga
Date: 2023-08-08 23:59:41 +08:00
parent 805478c911
commit 6e27a9e39a
4 changed files with 24 additions and 17 deletions


@@ -5,7 +5,7 @@ from threading import Thread
 from transformers import PreTrainedModel, TextIteratorStreamer
 from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
-from llmtuner.extras.template import get_template
+from llmtuner.extras.template import get_template_and_fix_tokenizer
 from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
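
The import swap points at the substance of the fix: template lookup now also adjusts the tokenizer in the same step, so every caller that resolves a template gets a tokenizer that can represent its stop words. A minimal sketch of what such a helper could do, assuming a hypothetical Template dataclass and template registry (only the function name and its two arguments come from this diff):

from dataclasses import dataclass, field
from typing import List

from transformers import PreTrainedTokenizer


@dataclass
class Template:
    prompt: str
    stop_words: List[str] = field(default_factory=list)


# Assumed registry; the real project keeps its own template table.
templates = {
    "default": Template(prompt="Human: {query}\nAssistant: ", stop_words=["Human:"]),
}


def get_template_and_fix_tokenizer(name: str, tokenizer: PreTrainedTokenizer) -> Template:
    template = templates[name]
    # Register the template's stop words as special tokens so they are
    # tokenized atomically instead of being split into subwords.
    if template.stop_words:
        tokenizer.add_special_tokens(
            dict(additional_special_tokens=template.stop_words)
        )
    # Ensure a pad token exists for batched generation (a common fix-up).
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    return template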
@@ -16,7 +16,7 @@ class ChatModel:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
         self.model = self.model.eval() # change to eval mode
-        self.template = get_template(data_args.template)
+        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.source_prefix = data_args.source_prefix
-        self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
+        self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
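
The second hunk drops the pre-computed stop_ids and instead registers the template's stop words as additional special tokens, so they survive round-trips through the tokenizer and can be matched reliably during generation. A hedged usage sketch of the underlying Hugging Face API (the checkpoint name and stop word below are placeholders, not values from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Registering a stop word as a special token keeps it as a single,
# atomic token id rather than a sequence of subword pieces.
num_added = tokenizer.add_special_tokens(
    dict(additional_special_tokens=["<|observation|>"])
)
if num_added > 0:
    # Newly added tokens need embedding rows, otherwise generation
    # would index past the end of the embedding matrix.
    model.resize_token_embeddings(len(tokenizer))

Note that add_special_tokens returns the number of tokens it actually added, which is why callers usually guard the embedding resize behind that count; stop words that already exist in the vocabulary require no resize.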