Former-commit-id: 2d8d47f6126d68db1701ed18fc31310c6f14dd49
hiyouga
2024-06-20 22:56:05 +08:00
parent f16a4a8264
commit af2cb33bb2
4 changed files with 15 additions and 13 deletions

@@ -289,16 +289,15 @@ def init_adapter(
         raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
 
     # cast trainable parameters to float32 if:
-    # 1. is_trainable and quantization_bit is not None (qlora)
-    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
-    # 3. is_trainable and not pure_bf16 and not badam
+    # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    cast_trainable_params_to_fp32 = False
     if not is_trainable:
-        cast_trainable_params_to_fp32 = False
-    elif model_args.quantization_bit is None and (
-        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
-    ):
-        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
-        cast_trainable_params_to_fp32 = False
+        pass
+    elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
         logger.info("Upcasting trainable params to float32.")
        cast_trainable_params_to_fp32 = True
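
The practical effect of reordering the branches is that pure_bf16 / use_badam now takes precedence over the quantization check: a quantized (QLoRA) model trained with pure bf16 or BAdam keeps its trainable params in half precision, where the old logic would have upcast them to float32. Below is a minimal standalone sketch of the new decision; should_cast_to_fp32 and its flat boolean flags are hypothetical stand-ins for the model_args / finetuning_args fields used in the diff, not the project's API.

    from typing import Optional

    def should_cast_to_fp32(
        is_trainable: bool,
        pure_bf16: bool,
        use_badam: bool,
        quantization_bit: Optional[int],
        zero3_enabled: bool,
        fsdp_enabled: bool,
    ) -> bool:
        # Mirrors the post-change branch order in init_adapter.
        if not is_trainable:
            return False  # inference: keep the model's original precision
        if pure_bf16 or use_badam:
            return False  # pure bf16 / BAdam train in half precision by design
        if quantization_bit is None and (zero3_enabled or fsdp_enabled):
            return False  # ZeRO-3 / FSDP already keep trainable params in float32
        return True  # everything else, including QLoRA, upcasts to float32

    # Behavioral change illustrated: QLoRA + BAdam now stays in half precision,
    # while plain QLoRA still upcasts to float32.
    assert should_cast_to_fp32(True, False, True, 4, False, False) is False
    assert should_cast_to_fp32(True, False, False, 4, False, False) is True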