Former-commit-id: 3c2c45812a720d92f7f5b15b9f03370fe6bf069e
This commit is contained in:
hiyouga
2024-06-17 18:17:48 +08:00
parent 485a80d294
commit 60d9896a70
3 changed files with 25 additions and 14 deletions

View File

@@ -89,7 +89,10 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp
# cast data type of the model if:
# 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
# 2. fsdp + qlora
if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True