update patcher

Former-commit-id: d6d7b6670847ce4ea10353c5b126214542b45c2b
hiyouga committed 2023-12-23 15:24:27 +08:00
parent f869e44fe5
commit 940403720a
6 changed files with 135 additions and 130 deletions


@@ -8,9 +8,7 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from llmtuner.model.adapter import init_adapter
 from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
-from llmtuner.model.utils import (
-    load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
-)
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, register_autoclass
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -94,10 +92,8 @@ def load_model_and_tokenizer(
     )
 
     model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
-    patch_model(model)
+    patch_model(model, tokenizer, model_args)
     register_autoclass(config, model, tokenizer)
-    if not is_deepspeed_zero3_enabled():
-        resize_embedding_layer(model, tokenizer)
 
     model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
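
For readers tracking the change: dropping the resize_embedding_layer call from the loader while patch_model gains tokenizer and model_args suggests the embedding-resize step moved into the patcher, consistent with the commit title "update patcher". Below is a minimal sketch of that refactor; the helper name _resize_embedding_layer and the resize_vocab flag on ModelArguments are assumptions for illustration, not confirmed by this diff.

    from typing import TYPE_CHECKING

    from llmtuner.extras.logging import get_logger

    if TYPE_CHECKING:
        from transformers import PreTrainedModel, PreTrainedTokenizer
        from llmtuner.hparams import ModelArguments

    logger = get_logger(__name__)


    def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
        # Hypothetical helper: grow the embedding matrix when the tokenizer
        # defines more tokens than the model's current vocabulary.
        old_vocab_size = model.get_input_embeddings().weight.size(0)
        if len(tokenizer) > old_vocab_size:
            model.resize_token_embeddings(len(tokenizer))
            logger.info("Resized token embeddings from {} to {}.".format(old_vocab_size, len(tokenizer)))


    def patch_model(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
        # Sketch only: the (model, tokenizer, model_args) signature matches the
        # diff above; the body and the resize_vocab flag are assumptions.
        if getattr(model_args, "resize_vocab", False):
            _resize_embedding_layer(model, tokenizer)

One design note grounded in the removed code: the loader previously guarded the resize with is_deepspeed_zero3_enabled(), since resizing embeddings that are sharded under DeepSpeed ZeRO-3 is unsafe; how the patcher handles that case is not visible in these hunks.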