tiny fix
Former-commit-id: 1a7ddd8c1d20dc251f53923bd0ab9f3f1031dd21
@@ -8,6 +8,14 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
+try:
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    is_bf16_available = is_torch_bf16_gpu_available()
+    is_npu_available = is_torch_npu_available()
+except ImportError:
+    is_bf16_available = torch.cuda.is_bf16_supported()
+    is_npu_available = False
+
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import (
     ModelArguments,
@@ -197,7 +205,7 @@ def get_train_args(
 
     # postprocess model_args
     if training_args.bf16:
-        if not torch.cuda.is_bf16_supported():
+        if not is_bf16_available:
             raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
     elif training_args.fp16:
@@ -243,4 +251,12 @@ def get_infer_args(
     if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
+    # auto-detect cuda capability
+    if is_npu_available:
+        model_args.compute_dtype = torch.float16
+    elif is_bf16_available:
+        model_args.compute_dtype = torch.bfloat16
+    else:
+        model_args.compute_dtype = torch.float16
+
     return model_args, data_args, finetuning_args, generating_args
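For context, the probe this commit introduces can be exercised on its own. The standalone sketch below is illustrative only and not part of the commit: it replays the try/except detection and prints the compute dtype that get_infer_args now auto-selects. The extra torch.cuda.is_available() guard in the fallback branch is an added assumption (so the script also runs on CPU-only hosts); the committed code calls torch.cuda.is_bf16_supported() directly.

import torch

# Prefer the transformers helpers when the installed release exposes them;
# fall back to torch's own check on older transformers versions, as the commit does.
try:
    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
    is_bf16_available = is_torch_bf16_gpu_available()
    is_npu_available = is_torch_npu_available()
except ImportError:
    # Assumption: guard with is_available() so the probe works without a GPU.
    is_bf16_available = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    is_npu_available = False

# Same precedence as the new block in get_infer_args:
# NPU -> float16, bf16-capable GPU -> bfloat16, everything else -> float16.
if is_npu_available:
    compute_dtype = torch.float16
elif is_bf16_available:
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16

print(f"auto-detected compute dtype: {compute_dtype}")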