tiny fix
Former-commit-id: 2d8d47f6126d68db1701ed18fc31310c6f14dd49
@@ -289,16 +289,15 @@ def init_adapter(
         raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
 
     # cast trainable parameters to float32 if:
-    # 1. is_trainable and quantization_bit is not None (qlora)
-    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
-    # 3. is_trainable and not pure_bf16 and not badam
+    # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
-        cast_trainable_params_to_fp32 = False
-    elif model_args.quantization_bit is None and (
-        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
-    ):
-        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
-        cast_trainable_params_to_fp32 = False
+        pass
+    elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
         logger.info("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
 
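For context, the behavioral change can be isolated in a small predicate. The sketch below is illustrative only: ModelArgs, FinetuningArgs, should_cast_to_fp32, and the zero3_or_fsdp flag are hypothetical stand-ins for the repository's actual argument classes and for is_deepspeed_zero3_enabled() / is_fsdp_enabled(); the real logic lives inline in init_adapter.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:  # hypothetical stand-in for the real ModelArguments
    quantization_bit: Optional[int] = None


@dataclass
class FinetuningArgs:  # hypothetical stand-in for the real FinetuningArguments
    pure_bf16: bool = False
    use_badam: bool = False


def should_cast_to_fp32(
    is_trainable: bool,
    model_args: ModelArgs,
    finetuning_args: FinetuningArgs,
    zero3_or_fsdp: bool,  # stand-in for is_deepspeed_zero3_enabled() or is_fsdp_enabled()
) -> bool:
    if not is_trainable:
        return False
    if finetuning_args.pure_bf16 or finetuning_args.use_badam:
        return False  # keep trainable params in half precision
    if model_args.quantization_bit is None and zero3_or_fsdp:
        return False  # ZeRO3/FSDP already hold params in float32
    return True  # e.g. QLoRA, or plain fp16/bf16 LoRA training


# QLoRA (quantization_bit set) combined with pure_bf16: the old combined elif
# required quantization_bit to be None, so this case fell through to the
# upcast branch; checking pure_bf16/use_badam first leaves precision alone.
assert not should_cast_to_fp32(True, ModelArgs(quantization_bit=4), FinetuningArgs(pure_bf16=True), False)
assert should_cast_to_fp32(True, ModelArgs(quantization_bit=4), FinetuningArgs(), False)

When the flag does end up true, the trainable parameters are then upcast elsewhere, in the style of param.data = param.data.to(torch.float32) over filter(lambda p: p.requires_grad, model.parameters()); that application site is outside the hunk shown here and unchanged by this commit.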