[model] support yarn (#6693)

Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250
This commit is contained in:
hoshi-hiyouga
2025-01-18 13:56:09 +08:00
committed by GitHub
parent e4046bdd1f
commit 87d685b59f
11 changed files with 84 additions and 64 deletions

View File

@@ -86,20 +86,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
tokenizer.model_max_length = model_args.model_max_length
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
patch_tokenizer(tokenizer, model_args)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)