support resize embeddings #1786

Former-commit-id: 368a41bd3c6a04f869083058d9165954fbdad105
This commit is contained in:
hiyouga
2023-12-11 17:50:02 +08:00
parent 7a03c8dab5
commit 95c561983c
2 changed files with 16 additions and 1 deletions

View File

@@ -28,7 +28,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -185,6 +185,9 @@ def load_model_and_tokenizer(
**config_kwargs
)
# Resize token embeddings
resize_embedding_layer(model, tokenizer)
# Disable custom generate method (for Qwen and Baichuan2)
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)