add max_memory for gptq #1923
Former-commit-id: 9afc42c8b999fbbc206d9a467ca5795b27a10096
This commit is contained in:
@@ -3,25 +3,22 @@ import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
from transformers.utils import (
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_xpu_available
|
||||
)
|
||||
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
try:
|
||||
from transformers.utils import (
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_npu_available
|
||||
)
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
|
||||
except ImportError:
|
||||
_is_fp16_available = torch.cuda.is_available()
|
||||
try:
|
||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
||||
except:
|
||||
_is_bf16_available = False
|
||||
except:
|
||||
_is_bf16_available = False
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from llmtuner.hparams import ModelArguments
|
||||
|
||||
|
||||
@@ -68,12 +65,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
|
||||
|
||||
def get_current_device() -> torch.device:
|
||||
import accelerate
|
||||
if accelerate.utils.is_xpu_available():
|
||||
r"""
|
||||
Gets the current available device.
|
||||
"""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif accelerate.utils.is_npu_available():
|
||||
elif is_torch_npu_available():
|
||||
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif torch.cuda.is_available():
|
||||
elif is_torch_cuda_available():
|
||||
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
else:
|
||||
device = "cpu"
|
||||
@@ -117,7 +116,7 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
|
||||
return
|
||||
|
||||
try:
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
from modelscope import snapshot_download
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
model_args.model_name_or_path = snapshot_download(
|
||||
model_args.model_name_or_path,
|
||||
|
||||
Reference in New Issue
Block a user