fix dispatch
Former-commit-id: deda82638716506dc690902c51276bb1eb0ddd5e
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
@@ -8,7 +9,7 @@ from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import get_current_device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
@@ -23,7 +24,11 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
if getattr(model, "quantization_method", None): # already set on current device
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
and isinstance(model, PreTrainedModel)
|
||||
and getattr(model.config, "model_type", None) != "chatglm"
|
||||
):
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
|
||||
Reference in New Issue
Block a user