support new special token #3420

Former-commit-id: f5c6a47f5193ab3a6c137580992bdcce0b31fdd5
This commit is contained in:
hiyouga
2024-04-24 23:39:31 +08:00
parent 12f852b8d4
commit 83404c4fa9
8 changed files with 47 additions and 7 deletions

View File

@@ -39,6 +39,8 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
r"""
Loads pretrained tokenizer.
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
try:
@@ -57,6 +59,16 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
return tokenizer