fix add tokens

Former-commit-id: ff5353681a87d033903bf8cf6133c6bdb3fa9e5a
This commit is contained in:
hiyouga
2024-03-06 15:04:02 +08:00
parent 73d9dfc7ab
commit 67f02f75d0
3 changed files with 7 additions and 6 deletions

View File

@@ -264,15 +264,14 @@ def _register_template(
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
is_added = tokenizer.eos_token_id is None
is_oov = eos_token not in tokenizer.get_vocab()
tokenizer.add_special_tokens({"eos_token": eos_token})
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token))
else:
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if is_oov:
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
@@ -368,10 +367,12 @@ def get_template_and_fix_tokenizer(
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if stop_words:
tokenizer.add_special_tokens(
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)