add test cases

Former-commit-id: 731176ff34cdf0cbf6b41c40c69f4ceb54c2daf6
This commit is contained in:
hiyouga
2024-06-15 04:05:54 +08:00
parent f4f315fd11
commit 3ff9b87012
9 changed files with 184 additions and 34 deletions

View File

@@ -44,7 +44,10 @@ def patch_config(
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if model_args.infer_dtype == "auto":
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
else:
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]