Former-commit-id: 1a7ddd8c1d20dc251f53923bd0ab9f3f1031dd21
hiyouga
2023-09-21 15:25:29 +08:00
parent 46a718f339
commit 95c0d9ab24
4 changed files with 30 additions and 14 deletions


@@ -8,6 +8,14 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

+try:
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    is_bf16_available = is_torch_bf16_gpu_available()
+    is_npu_available = is_torch_npu_available()
+except ImportError:
+    is_bf16_available = torch.cuda.is_bf16_supported()
+    is_npu_available = False
+
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import (
     ModelArguments,
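For context, the new probe can be exercised on its own. Below is a minimal standalone sketch: the two helpers exist only in sufficiently recent transformers releases, hence the ImportError fallback. The extra torch.cuda.is_available() guard is an addition of this sketch, not part of the patch, to keep it safe on CPU-only machines.

# Standalone sketch of the capability probe added above. The try/except
# guards against older transformers releases that lack these helpers.
import torch

try:
    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
    is_bf16_available = is_torch_bf16_gpu_available()
    is_npu_available = is_torch_npu_available()
except ImportError:
    # Fallback for older transformers: probe CUDA directly and assume no NPU.
    # (The is_available() check is defensive and not in the original patch.)
    is_bf16_available = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    is_npu_available = False

print(f"bf16 available: {is_bf16_available}, NPU available: {is_npu_available}")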
@@ -197,7 +205,7 @@ def get_train_args(
     # postprocess model_args
     if training_args.bf16:
-        if not torch.cuda.is_bf16_supported():
+        if not is_bf16_available:
             raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
     elif training_args.fp16:
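The train-time branch now consults the module-level flag instead of calling torch.cuda.is_bf16_supported() inline, so NPU and older-GPU setups share one code path. A hypothetical helper mirroring that branch is sketched below; the fp16/fp32 fallthrough is an assumption about the surrounding code not shown in this hunk.

# Hypothetical helper mirroring the train-time dtype branch above.
import torch

def resolve_train_dtype(bf16: bool, fp16: bool, bf16_ok: bool) -> torch.dtype:
    if bf16:
        if not bf16_ok:
            raise ValueError("Current device does not support bf16 training.")
        return torch.bfloat16
    if fp16:
        return torch.float16
    return torch.float32  # assumed default when neither flag is set

# e.g. resolve_train_dtype(bf16=True, fp16=False, bf16_ok=is_bf16_available)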
@@ -243,4 +251,12 @@ def get_infer_args(
     if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")

+    # auto-detect cuda capability
+    if is_npu_available:
+        model_args.compute_dtype = torch.float16
+    elif is_bf16_available:
+        model_args.compute_dtype = torch.bfloat16
+    else:
+        model_args.compute_dtype = torch.float16
+
     return model_args, data_args, finetuning_args, generating_args
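The inference hunk establishes a clear precedence: an NPU forces fp16 (it is checked before the bf16 flag), a bf16-capable GPU gets bf16, and everything else falls back to fp16. A hypothetical helper restating that precedence, for illustration only:

# Hypothetical helper with the same precedence as the inference hunk:
# NPU first (fp16), then bf16-capable GPUs (bf16), else fp16.
import torch

def autodetect_infer_dtype(npu: bool, bf16_ok: bool) -> torch.dtype:
    if npu:
        return torch.float16
    if bf16_ok:
        return torch.bfloat16
    return torch.float16

# e.g. autodetect_infer_dtype(is_npu_available, is_bf16_available)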