add configurer

Former-commit-id: c40c9889615ffb49c7ce24c69c0d3d20d841c800
hiyouga
2023-12-15 21:46:40 +08:00
parent f902b0d420
commit 0409428d87
3 changed files with 83 additions and 77 deletions


@@ -9,8 +9,7 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 if TYPE_CHECKING:
-    from transformers.modeling_utils import PreTrainedModel
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 
     from llmtuner.hparams import DataArguments
 
@@ -183,3 +182,12 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
     model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
     new_embedding_size = model.get_input_embeddings().weight.size(0)
     logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
+
+
+def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
+    if "AutoConfig" in getattr(config, "auto_map", {}):
+        config.__class__.register_for_auto_class()
+    if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
+        model.__class__.register_for_auto_class()
+    if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()
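
For context: register_autoclass ties a checkpoint's custom (trust_remote_code) classes to their Auto* entry points, so that a later save_pretrained() call also copies the defining *.py files next to the saved weights. Each branch checks the checkpoint's auto_map and registers only the classes that are genuinely custom, leaving built-in architectures untouched. Below is a minimal usage sketch, assuming the helper is importable from llmtuner.model.utils (module path assumed) and using a placeholder checkpoint name; neither is part of the diff.

# Hypothetical usage of the helper added in this commit; the checkpoint
# name and import path are assumptions, not taken from the diff.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmtuner.model.utils import register_autoclass  # assumed module path

path = "some-org/model-with-custom-code"  # placeholder checkpoint

config = AutoConfig.from_pretrained(path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, config=config, trust_remote_code=True)

# Registers the config/model/tokenizer classes for their auto classes,
# but only where the checkpoint's auto_map marks them as custom.
register_autoclass(config, model, tokenizer)

# After registration, saving also exports the custom modeling code.
model.save_pretrained("./exported")
tokenizer.save_pretrained("./exported")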