Former-commit-id: 3c2c45812a720d92f7f5b15b9f03370fe6bf069e
This commit is contained in:
hiyouga
2024-06-17 18:17:48 +08:00
parent 485a80d294
commit 60d9896a70
3 changed files with 25 additions and 14 deletions

View File

@@ -281,12 +281,22 @@ def init_adapter(
Note that the trainable parameters must be cast to float32.
"""
if is_trainable and getattr(model, "quantization_method", None) and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")
if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")
if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
# cast trainable parameters to float32 if is_trainable and either:
# 1. quantization_bit is not None (qlora), or
# 2. none of deepspeed zero3, fsdp, pure_bf16 or badam is in use
#    (zero3 or fsdp already in float32; pure_bf16 and badam keep the original precision)
if not is_trainable:
cast_trainable_params_to_fp32 = False
elif is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
elif model_args.quantization_bit is None and (
is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
):
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
cast_trainable_params_to_fp32 = False
else: