Former-commit-id: f49adc4ab5eade21d7a9e029212f17688ee9b0cf
This commit is contained in:
hiyouga
2024-06-24 22:34:31 +08:00
parent abcb94a738
commit a79e93f335
6 changed files with 32 additions and 9 deletions

View File

@@ -58,10 +58,10 @@ def patch_config(
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
if model_args.infer_dtype == "auto":
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
else:
if model_args.infer_dtype != "auto" and not is_trainable:
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
else:
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]