add test cases
Former-commit-id: 731176ff34cdf0cbf6b41c40c69f4ceb54c2daf6
This commit is contained in:
@@ -44,7 +44,10 @@ def patch_config(
|
||||
is_trainable: bool,
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
if model_args.infer_dtype == "auto":
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
else:
|
||||
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
|
||||
|
||||
if is_torch_npu_available():
|
||||
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
|
||||
|
||||
Reference in New Issue
Block a user