update loader
Former-commit-id: 080d8eab858217ca58bffe719d5ffde7579c5bda
This commit is contained in:
@@ -8,7 +8,7 @@ from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
||||
from llmtuner.model.adapter import init_adapter
|
||||
from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
|
||||
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, register_autoclass
|
||||
from llmtuner.model.utils import load_valuehead_params, register_autoclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -92,10 +92,9 @@ def load_model_and_tokenizer(
|
||||
)
|
||||
|
||||
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
|
||||
patch_model(model, tokenizer, model_args)
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
|
||||
if add_valuehead:
|
||||
|
||||
Reference in New Issue
Block a user