add max_memory for gptq #1923

Former-commit-id: 9afc42c8b999fbbc206d9a467ca5795b27a10096
This commit is contained in:
hiyouga
2023-12-20 18:15:17 +08:00
parent 2b1e52dcc9
commit dba1af4841
4 changed files with 26 additions and 24 deletions

View File

@@ -3,25 +3,22 @@ import os
import torch
from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.utils import (
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available,
is_torch_xpu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
from transformers.utils import (
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
try:
_is_bf16_available = torch.cuda.is_bf16_supported()
except:
_is_bf16_available = False
except:
_is_bf16_available = False
if TYPE_CHECKING:
from transformers import HfArgumentParser
from llmtuner.hparams import ModelArguments
@@ -68,12 +65,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def get_current_device() -> torch.device:
import accelerate
if accelerate.utils.is_xpu_available():
r"""
Gets the current available device.
"""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif accelerate.utils.is_npu_available():
elif is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif torch.cuda.is_available():
elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
device = "cpu"
@@ -117,7 +116,7 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
return
try:
from modelscope import snapshot_download # type: ignore
from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path,