fix qwen inference

Former-commit-id: 823f0de0ca0a92b6f48a90e5ffe57a48dc018f1d
This commit is contained in:
hiyouga
2023-08-03 16:31:55 +08:00
parent e434348216
commit 27f4317ec6
2 changed files with 7 additions and 7 deletions

View File

@@ -17,9 +17,7 @@ class ChatModel:
self.model = dispatch_model(self.model)
self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix
self.stop_ids = [
self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
]
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model