[model] fix kv cache (#7564)
src/llamafactory/model/adapter.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
@@ -277,14 +276,14 @@ def init_adapter(
 
     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
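Reviewer note: the flag computed above is consumed further down in init_adapter when the adapters are set up. A minimal sketch of the upcast it gates, with a hypothetical helper name (the loop follows the usual PyTorch pattern, not this diff's exact code):

import torch
from torch import nn


def upcast_trainable_params(model: nn.Module) -> None:
    # Cast only trainable parameters to float32; frozen (possibly
    # quantized or half-precision) weights keep their original dtype.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)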
44 src/llamafactory/model/model_utils/kv_cache.py Normal file
@@ -0,0 +1,44 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
+
+def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable:
+        setattr(config, "use_cache", model_args.use_cache)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", model_args.use_cache)
+
+        if model_args.use_cache:
+            logger.info_rank0("KV cache is enabled for faster generation.")
+        else:
+            logger.info_rank0("KV cache is disabled.")
+    else:
+        setattr(config, "use_cache", False)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", False)
+
+        logger.info_rank0("KV cache is disabled during training.")
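Reviewer note: a quick usage sketch of the new helper, assuming LLaMA-Factory is importable; the SimpleNamespace objects are hypothetical stand-ins for a real multimodal PretrainedConfig and ModelArguments:

from types import SimpleNamespace

from llamafactory.model.model_utils.kv_cache import configure_kv_cache

# Stand-in for a composite config with a nested text_config (e.g. a VLM).
config = SimpleNamespace(use_cache=True, text_config=SimpleNamespace(use_cache=True))
model_args = SimpleNamespace(use_cache=True)

# During training the KV cache is forced off on both levels of the config.
configure_kv_cache(config, model_args, is_trainable=True)
assert config.use_cache is False
assert config.text_config.use_cache is False

Writing through to config.text_config is what generalizes the old one-off branch for Gemma3ForConditionalGeneration to any wrapper config that nests a text_config.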
src/llamafactory/model/patcher.py
@@ -27,6 +27,7 @@ from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -102,23 +103,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
-
-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
-
-    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
-        text_config = config.text_config
-        setattr(text_config, "use_cache", False)
+    configure_kv_cache(config, model_args, is_trainable)
 
     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
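Reviewer note: one behavioral difference is worth spelling out. The deleted branch only ever set use_cache to True, so passing use_cache=False at inference time silently fell back to the checkpoint's default; the new helper writes the user's preference through explicitly, and disabling the cache during training also makes the qwen2 flash-attention special case redundant. A sketch with the same hypothetical stubs as above:

from types import SimpleNamespace

from llamafactory.model.model_utils.kv_cache import configure_kv_cache

config = SimpleNamespace(use_cache=True)       # checkpoint default: cache on
model_args = SimpleNamespace(use_cache=False)  # user disables the KV cache

configure_kv_cache(config, model_args, is_trainable=False)
assert config.use_cache is False  # the old code path left this True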