[model] support yarn (#6693)
Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250
@@ -53,10 +53,23 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
+    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length
+
+    if model_args.new_special_tokens is not None:
+        num_added_tokens = tokenizer.add_special_tokens(
+            dict(additional_special_tokens=model_args.new_special_tokens),
+            replace_additional_special_tokens=False,
+        )
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        if num_added_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
 
 
 def patch_processor(
     processor: "ProcessorMixin",
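Note: a minimal sketch of how the new two-argument `patch_tokenizer` might be exercised after this change. The import path `llamafactory.model.patcher` and the use of a `SimpleNamespace` in place of the project's `ModelArguments` are assumptions; only the three fields read in the diff above are provided.

    # Sketch only (not part of the commit); assumes patch_tokenizer lives in
    # llamafactory.model.patcher and that ModelArguments can be stubbed with
    # just the fields the patched function reads.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    from llamafactory.model.patcher import patch_tokenizer  # import path assumed

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    model_args = SimpleNamespace(
        model_max_length=8192,          # overrides tokenizer.model_max_length when set
        new_special_tokens=["<tool>"],  # appended via tokenizer.add_special_tokens(...)
        resize_vocab=False,             # flipped to True if any tokens were added
    )

    patch_tokenizer(tokenizer, model_args)
    print(tokenizer.model_max_length, model_args.resize_vocab)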