[model] fix kv cache (#7564)

This commit is contained in:
hoshi-hiyouga
2025-04-01 23:07:46 +08:00
committed by GitHub
parent a13b1bb49a
commit 2bfcad2394
16 changed files with 122 additions and 64 deletions

View File

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
@@ -277,14 +276,14 @@ def init_adapter(
# cast trainable parameters to float32 if:
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
cast_trainable_params_to_fp32 = False
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
else:
logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True