[misc] fix adding new tokens (#7253)

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
flashJd
2025-04-21 23:19:02 +08:00
committed by GitHub
parent c5ba9106ec
commit 0ac641326b
2 changed files with 21 additions and 4 deletions

View File

@@ -55,14 +55,24 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
tokenizer.model_max_length = model_args.model_max_length # enlarge the tokenizer max length
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
num_added_special_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
logger.info_rank0("Add special tokens {} to vocab.".format(",".join(model_args.new_special_tokens)))
if num_added_special_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
logger.warning_rank0("New special tokens have been added, changed `resize_vocab` to True.")
if model_args.new_normal_tokens is not None:
num_added_normal_tokens = tokenizer.add_tokens(
new_tokens=model_args.new_normal_tokens,
special_tokens=False,
)
logger.info_rank0("Add normal tokens {} to vocab.".format(",".join(model_args.new_normal_tokens)))
if num_added_normal_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning_rank0("New normal tokens have been added, changed `resize_vocab` to True.")
def patch_processor(