add configurer
Former-commit-id: c40c9889615ffb49c7ce24c69c0d3d20d841c800
This commit is contained in:
@@ -9,8 +9,7 @@ from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
@@ -183,3 +182,12 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
|
||||
|
||||
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
|
||||
if "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
||||
model.__class__.register_for_auto_class()
|
||||
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
|
||||
Reference in New Issue
Block a user