[model] fix kv cache (#7564)

hoshi-hiyouga
2025-04-01 23:07:46 +08:00
committed by GitHub
parent a13b1bb49a
commit 2bfcad2394
16 changed files with 122 additions and 64 deletions

src/llamafactory/model/adapter.py

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
@@ -277,14 +276,14 @@ def init_adapter(
     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

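Side note on the adapter hunk: with the is_fsdp_enabled branch removed, only DeepSpeed ZeRO-3 now skips the float32 upcast, and FSDP runs fall through to the default upcasting path. For reference, a minimal sketch of the pattern that the cast_trainable_params_to_fp32 flag drives later in init_adapter (the helper name below is made up for illustration, not the repo's exact code):

import torch
from torch import nn


def cast_trainable_params(model: nn.Module, cast_to_fp32: bool) -> None:
    # Illustrative only: upcast just the trainable parameters, leave frozen weights in half precision.
    if not cast_to_fp32:
        return

    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)

Upcasting only the trainable weights keeps the frozen base model in half precision for memory while giving the optimizer full-precision parameters to update.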
src/llamafactory/model/model_utils/kv_cache.py

@@ -0,0 +1,44 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras import logging


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable:
        setattr(config, "use_cache", model_args.use_cache)
        if hasattr(config, "text_config"):
            setattr(config.text_config, "use_cache", model_args.use_cache)

        if model_args.use_cache:
            logger.info_rank0("KV cache is enabled for faster generation.")
        else:
            logger.info_rank0("KV cache is disabled.")
    else:
        setattr(config, "use_cache", False)
        if hasattr(config, "text_config"):
            setattr(config.text_config, "use_cache", False)

        logger.info_rank0("KV cache is disabled during training.")

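A rough usage sketch of the new configure_kv_cache helper. LlamaConfig and the FakeModelArgs dataclass below are stand-ins chosen to keep the example self-contained; the real call site in patch_config passes the repo's ModelArguments and the model's loaded PretrainedConfig:

from dataclasses import dataclass

from transformers import LlamaConfig

from llamafactory.model.model_utils.kv_cache import configure_kv_cache


@dataclass
class FakeModelArgs:
    # stand-in for llamafactory.hparams.ModelArguments (only the field used here)
    use_cache: bool = True


config = LlamaConfig()  # any PretrainedConfig with a use_cache attribute works

configure_kv_cache(config, FakeModelArgs(use_cache=True), is_trainable=False)
assert config.use_cache is True  # inference: the user's setting is honored

configure_kv_cache(config, FakeModelArgs(), is_trainable=True)
assert config.use_cache is False  # training: the cache is always disabled

For multimodal configs that carry a nested text_config (e.g. Gemma3ForConditionalGeneration), the helper mirrors the same value onto config.text_config.use_cache.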
src/llamafactory/model/patcher.py

@@ -27,6 +27,7 @@ from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -102,23 +103,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
-
-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
-
-    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
-        text_config = config.text_config
-        setattr(text_config, "use_cache", False)
+    configure_kv_cache(config, model_args, is_trainable)
 
     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)