fix dispatch

Former-commit-id: deda82638716506dc690902c51276bb1eb0ddd5e
This commit is contained in:
hiyouga
2024-01-03 16:33:16 +08:00
parent 7168392a51
commit 8c74851b70
2 changed files with 7 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List
from transformers import PreTrainedModel
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
@@ -8,7 +9,7 @@ from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from transformers import PretrainedConfig, PreTrainedTokenizer
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
@@ -23,7 +24,11 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
if getattr(model, "quantization_method", None): # already set on current device
return model
if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
if (
torch.cuda.device_count() > 1
and isinstance(model, PreTrainedModel)
and getattr(model.config, "model_type", None) != "chatglm"
):
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory