[model] fix kv cache (#7564)
@@ -27,6 +27,7 @@ from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -102,23 +103,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
-
-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
-
-    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
-        text_config = config.text_config
-        setattr(text_config, "use_cache", False)
+    configure_kv_cache(config, model_args, is_trainable)
 
     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
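The deleted branches suggest what the new configure_kv_cache helper consolidates: enabling the cache only at inference time, propagating the flag into nested text configs such as Gemma3's, and keeping the cache off during training (which also covers the old qwen2 + FlashAttention-2 special case, since use_cache is then always false when is_trainable is true). Below is a minimal sketch of model_utils/kv_cache.py reconstructed from that deleted logic; the function name and signature come from this diff, while the import paths, logging.get_logger, and ModelArguments are assumptions modeled on the surrounding imports, and the commit's actual implementation may differ.

    # Hypothetical reconstruction of model_utils/kv_cache.py, inferred from the
    # branches removed in this diff; not necessarily the commit's exact code.
    from transformers import PretrainedConfig

    from ...extras import logging  # assumed path to the project's rank-aware logger
    from ...hparams import ModelArguments  # assumed path, mirroring sibling modules

    logger = logging.get_logger(__name__)


    def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
        if not is_trainable:
            # Inference: honor the user's use_cache flag for faster generation.
            setattr(config, "use_cache", model_args.use_cache)
            # Multimodal models such as Gemma3ForConditionalGeneration nest the
            # language model settings inside config.text_config.
            if hasattr(config, "text_config"):
                setattr(config.text_config, "use_cache", model_args.use_cache)

            if model_args.use_cache:
                logger.info_rank0("Using KV cache for faster generation.")
        else:
            # Training: the cache wastes memory and conflicts with some attention
            # backends (e.g. qwen2 with FlashAttention-2), so turn it off everywhere.
            setattr(config, "use_cache", False)
            if hasattr(config, "text_config"):
                setattr(config.text_config, "use_cache", False)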