@@ -281,12 +281,22 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
-    if is_trainable and getattr(model, "quantization_method", None) and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantized models can only be used for the LoRA tuning.")
+    if is_trainable and getattr(model, "quantization_method", None) is not None:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+
+        if finetuning_args.pissa_init:
+            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
 
+    # cast trainable parameters to float32 if:
+    # 1. is_trainable and quantization_bit is not None (qlora)
+    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
+    # 3. is_trainable and not pure_bf16 and not badam
     if not is_trainable:
         cast_trainable_params_to_fp32 = False
-    elif is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+    elif model_args.quantization_bit is None and (
+        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
+    ):
         logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
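
Note on the second change: the casting decision above can be restated as a self-contained predicate. The sketch below is illustrative, not part of the patch; the function name and boolean arguments are hypothetical stand-ins for the values the diff reads from model_args and finetuning_args. It makes the behavior change visible: with the new `model_args.quantization_bit is None` guard, a quantized (QLoRA) model is upcast to float32 even when ZeRO3/FSDP, pure bf16, or BAdam is active, matching item 1 of the added comment.

# Hypothetical restatement of the new decision logic; not part of the patch.
def should_cast_to_fp32(
    is_trainable: bool,
    quantization_bit,          # e.g. 4 for QLoRA, or None
    zero3_or_fsdp: bool,       # is_deepspeed_zero3_enabled() or is_fsdp_enabled()
    pure_bf16_or_badam: bool,  # finetuning_args.pure_bf16 or finetuning_args.use_badam
) -> bool:
    if not is_trainable:
        return False
    # New guard: the upcast is only skipped when the model is NOT quantized.
    if quantization_bit is None and (zero3_or_fsdp or pure_bf16_or_badam):
        return False
    return True

# Before the patch the first case returned False; QLoRA under ZeRO3 now casts.
assert should_cast_to_fp32(True, 4, True, False) is True
assert should_cast_to_fp32(True, None, True, False) is False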
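The hunk sets cast_trainable_params_to_fp32 but does not show where the flag is consumed. As a rough sketch of what such a cast typically looks like downstream (the helper below is assumed, not shown in this diff): only parameters with requires_grad=True are touched, so frozen and quantized weights keep their storage dtype, which is what the docstring's float32 requirement refers to.

import torch
import torch.nn as nn

# Assumed helper, not part of this hunk: upcast only the trainable parameters.
def cast_trainable_params(model: nn.Module, cast_to_fp32: bool) -> None:
    if not cast_to_fp32:
        return
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)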