@@ -9,10 +9,12 @@ from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

 try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
+    is_fp16_available = is_torch_cuda_available()
     is_bf16_available = is_torch_bf16_gpu_available()
     is_npu_available = is_torch_npu_available()
 except ImportError:
+    is_fp16_available = torch.cuda.is_available()
     is_bf16_available = torch.cuda.is_bf16_supported()
     is_npu_available = False

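Note on the guarded import: older transformers releases do not export is_torch_cuda_available, so the probe degrades to plain torch checks when the import fails. Below is a minimal standalone sketch of the same feature-detection pattern; the _probe helper is hypothetical and not part of the patch:

    import torch
    import transformers.utils as hf_utils

    def _probe(name, fallback):
        # Prefer the transformers helper when this version exports it,
        # otherwise fall back to the plain torch check.
        fn = getattr(hf_utils, name, None)
        return fn() if fn is not None else fallback()

    is_fp16_available = _probe("is_torch_cuda_available", torch.cuda.is_available)
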
@@ -29,6 +31,17 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)


+def _infer_dtype() -> torch.dtype:
+    if is_npu_available:
+        return torch.float16
+    elif is_bf16_available:
+        return torch.bfloat16
+    elif is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
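_infer_dtype() encodes a simple precedence: Ascend NPUs get float16, bf16-capable CUDA GPUs (compute capability 8.0 and up) get bfloat16, older CUDA GPUs get float16, and everything else falls back to float32. A small self-contained check of that branch order, with the module-level flags passed in explicitly so it runs without real hardware:

    import torch

    def infer_dtype(npu, bf16, fp16):
        # Mirrors _infer_dtype() above, parameterized for testing.
        if npu:
            return torch.float16
        elif bf16:
            return torch.bfloat16
        elif fp16:
            return torch.float16
        else:
            return torch.float32

    assert infer_dtype(True, True, True) is torch.float16     # NPU wins
    assert infer_dtype(False, True, True) is torch.bfloat16   # Ampere+ GPU
    assert infer_dtype(False, False, True) is torch.float16   # pre-Ampere GPU
    assert infer_dtype(False, False, False) is torch.float32  # CPU only
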
@@ -211,7 +224,7 @@ def get_train_args(
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
     else:
-        model_args.compute_dtype = torch.float32
+        model_args.compute_dtype = _infer_dtype()

     model_args.model_max_length = data_args.cutoff_len

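In get_train_args the explicit precision flags still take precedence; _infer_dtype() only replaces the hard-coded float32 default. The full conditional presumably reads as below (a sketch reconstructed from the visible context, only the else branch is certain from the hunk):

    # Assumes _infer_dtype() from the hunk above is in scope.
    if training_args.bf16:
        model_args.compute_dtype = torch.bfloat16
    elif training_args.fp16:
        model_args.compute_dtype = torch.float16
    else:
        model_args.compute_dtype = _infer_dtype()  # was torch.float32
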
@@ -252,11 +265,6 @@ def get_infer_args(
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")

     # auto-detect cuda capability
-    if is_npu_available:
-        model_args.compute_dtype = torch.float16
-    elif is_bf16_available:
-        model_args.compute_dtype = torch.bfloat16
-    else:
-        model_args.compute_dtype = torch.float16
+    model_args.compute_dtype = _infer_dtype()

     return model_args, data_args, finetuning_args, generating_args
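The inference path drops its inline detection block in favor of the shared helper. One behavioral change falls out of this: on a host with neither an NPU nor a CUDA device, the old else branch forced float16, while _infer_dtype() now returns float32, a safer default for CPU inference. A hypothetical call, where the import path and argument values are assumptions not shown in the patch:

    from llmtuner.tuner.core.parser import get_infer_args  # assumed module path

    model_args, data_args, finetuning_args, generating_args = get_infer_args(
        dict(model_name_or_path="meta-llama/Llama-2-7b-hf")
    )
    print(model_args.compute_dtype)  # torch.float32 on CPU-only hosts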