Former-commit-id: 2d8d47f6126d68db1701ed18fc31310c6f14dd49
This commit is contained in:
hiyouga
2024-06-20 22:56:05 +08:00
parent f16a4a8264
commit af2cb33bb2
4 changed files with 15 additions and 13 deletions

View File

@@ -91,8 +91,8 @@ def patch_config(
# cast data type of the model if:
# 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
# 2. fsdp + qlora
if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
# 2. quantization_bit is not None (qlora)
if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True