[model] fix kv cache (#7564)
@@ -101,12 +101,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=data_args.streaming and not data_args.dataset_shards,  # only set to True when user specified streaming but do not want dataset to be sharded
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
-        if data_args.streaming and data_args.dataset_shards:
-            dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)

     elif dataset_attr.load_from == "om_hub":
         check_version("openmind>=0.8.0", mandatory=True)
@@ -135,10 +133,10 @@ def _load_single_dataset(
             token=model_args.hf_hub_token,
             num_proc=data_args.preprocessing_num_workers,
             trust_remote_code=model_args.trust_remote_code,
-            streaming=data_args.streaming and not data_args.dataset_shards,
+            streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
-        if data_args.streaming and data_args.dataset_shards:
-            dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
+        if data_args.streaming and dataset_attr.load_from == "file":
+            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
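For context on the branch above: when streaming is requested for a local-file dataset, the loader now keeps the map-style dataset and re-shards it into an iterable one so each dataloader worker gets its own shard. A minimal sketch of that behaviour with the `datasets` library (the toy data and worker count are illustrative, not taken from this change):

from datasets import Dataset

# Toy map-style dataset standing in for one loaded from local files.
ds = Dataset.from_dict({"text": [f"example {i}" for i in range(16)]})

# Re-shard into an iterable dataset so several dataloader workers can each
# consume a distinct shard instead of sharing a single stream.
num_workers = 4  # would mirror training_args.dataloader_num_workers
iterable_ds = ds.to_iterable_dataset(num_shards=num_workers)

print(iterable_ds.n_shards)             # 4
print(next(iter(iterable_ds))["text"])  # "example 0"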
@@ -1186,6 +1186,9 @@ class Qwen2OmniPlugin(BasePlugin):
         messages = deepcopy(messages)
+        if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        else:
+            mm_inputs = {}

         num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
         use_audio_in_video = getattr(processor, "use_audio_in_video", False)
@@ -1193,18 +1196,22 @@ class Qwen2OmniPlugin(BasePlugin):
         if "feature_attention_mask" in mm_inputs:
             input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
             audio_lengths = (input_lengths - 2) // 2 + 1

         if mm_inputs.get("image_grid_thw", None) is not None:
             image_grid_thw = mm_inputs["image_grid_thw"]
             merge_length = processor.omni_processor.merge_size**2

         if mm_inputs.get("video_grid_thw", None) is not None:
             video_grid_thw = mm_inputs["video_grid_thw"]
             merge_length = processor.omni_processor.merge_size**2

         if use_audio_in_video:
-            assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
-            assert mm_inputs.get("video_grid_thw", None) is not None, (
-                "video_grid_thw should be exist when use_audio_in_video is `True`"
-            )
+            if audio_lengths is None:
+                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
+
+            if not mm_inputs.get("video_grid_thw", None):
+                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")

             positions_list = []
             for i, message in enumerate(messages):  # get multimodal index when use_audio
                 positions = []
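A rough illustration of the bookkeeping these branches prepare (the input values are made up; the formulas follow the lines above): each image or video contributes grid_thw.prod() // merge_size**2 placeholder tokens, and audio token counts are derived from the feature attention mask in two halving steps.

import torch

# Illustrative inputs only, not taken from the diff.
image_grid_thw = torch.tensor([[1, 28, 28]])  # one image: 1 frame, 28x28 patches
merge_size = 2
merge_length = merge_size**2

# Placeholder tokens this image expands to: 1 * 28 * 28 // 4 = 196.
image_seqlen = int(image_grid_thw[0].prod()) // merge_length
print(image_seqlen)  # 196

# Audio length bookkeeping, mirroring the formulas above.
feature_attention_mask = torch.ones(1, 3000, dtype=torch.long)
input_lengths = (feature_attention_mask.sum(-1).numpy() - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1
print(audio_lengths[0])  # 750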
@@ -1216,6 +1223,7 @@ class Qwen2OmniPlugin(BasePlugin):
                         break
                     positions.append((pos, special_token))
                     start = pos + len(special_token)

+                positions_list.append(positions.sort(key=lambda x: x[0]))

         for message in messages:
@@ -1278,6 +1286,7 @@ class Qwen2OmniPlugin(BasePlugin):
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                     num_audio_tokens += 1
+                    num_video_tokens += 1

             message["content"] = content

         if len(audios) != num_audio_tokens:
@@ -17,6 +17,7 @@

 import gc
 import os
+import socket
 from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
@@ -278,10 +279,16 @@ def use_ray() -> bool:

 def find_available_port() -> int:
     """Find an available port on the local machine."""
-    import socket
-
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]
     sock.close()
     return port
+
+
+def fix_proxy(ipv6_enabled: bool) -> None:
+    """Fix proxy settings for gradio ui."""
+    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
+    if ipv6_enabled:
+        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
+            os.environ.pop(name, None)
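A minimal, self-contained sketch of how these two helpers could be used when spinning up a local demo; the call site is hypothetical, while the helpers themselves are reproduced from the diff above:

import os
import socket


def find_available_port() -> int:
    """Ask the OS for a free TCP port by binding to port 0."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


def fix_proxy(ipv6_enabled: bool) -> None:
    """Keep local traffic off any configured HTTP proxy."""
    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
    if ipv6_enabled:
        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
            os.environ.pop(name, None)


if __name__ == "__main__":
    fix_proxy(ipv6_enabled=False)
    print(f"launching local demo on port {find_available_port()}")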
@@ -83,10 +83,6 @@ class DataArguments:
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
-    dataset_shards: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of shards to split the dataset into. Only used in streaming mode. This should be set to the same as dataloader_num_workers. Not setting this while streaming data will cause the dataset to be non-sharded and thus only can be processed using one worker."},
-    )
     max_samples: Optional[int] = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
@@ -135,7 +135,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.1")
+        check_version("vllm>=0.4.3,<=0.8.2")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.4")
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
@@ -277,14 +276,14 @@ def init_adapter(

     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
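As a side note, a minimal sketch of what the cast_trainable_params_to_fp32 flag amounts to downstream; the helper below is an illustrative stand-in, not the project's actual implementation:

import torch
import torch.nn as nn


def upcast_trainable_params(model: nn.Module) -> None:
    # Only parameters that receive gradients are upcast; frozen (e.g. quantized)
    # weights keep their original dtype.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.float()


model = nn.Linear(4, 4).to(torch.bfloat16)
upcast_trainable_params(model)
print(model.weight.dtype)  # torch.float32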
src/llamafactory/model/model_utils/kv_cache.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
+
+def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable:
+        setattr(config, "use_cache", model_args.use_cache)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", model_args.use_cache)
+
+        if model_args.use_cache:
+            logger.info_rank0("KV cache is enabled for faster generation.")
+        else:
+            logger.info_rank0("KV cache is disabled.")
+    else:
+        setattr(config, "use_cache", False)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", False)
+
+        logger.info_rank0("KV cache is disabled during training.")
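A small usage sketch of the new helper, assuming LLaMA-Factory is installed; SimpleNamespace stands in for the real ModelArguments object and the bare PretrainedConfig is illustrative:

from types import SimpleNamespace

from transformers import PretrainedConfig

from llamafactory.model.model_utils.kv_cache import configure_kv_cache

model_args = SimpleNamespace(use_cache=True)  # stand-in for ModelArguments

config = PretrainedConfig()
configure_kv_cache(config, model_args, is_trainable=False)
print(config.use_cache)  # True: the user's flag is honored at inference time

config = PretrainedConfig()
configure_kv_cache(config, model_args, is_trainable=True)
print(config.use_cache)  # False: the cache is always disabled for training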
@@ -27,6 +27,7 @@ from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -102,23 +103,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
-
-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
-
-    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
-        text_config = config.text_config
-        setattr(text_config, "use_cache", False)
+    configure_kv_cache(config, model_args, is_trainable)

     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)

-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
@@ -14,10 +14,8 @@

 import os
+import platform

-import httpx
-
-from ..extras.misc import is_env_enabled
+from ..extras.misc import fix_proxy, is_env_enabled
 from ..extras.packages import is_gradio_available
 from .common import save_config
 from .components import (
@@ -74,8 +72,9 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":

 def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title="Web Demo", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory Web Demo ({hostname})", css=CSS) as demo:
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], scale=1)
         engine.manager.add_elems("top", dict(lang=lang))
@@ -90,30 +89,18 @@ def create_web_demo() -> "gr.Blocks":

 def run_web_ui() -> None:
-    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
-    httpx.HTTPCORE_OPTS = {"trust_env": False}
-
-    try:
-        demo = create_ui().queue()
-        demo.launch(
-            share=gradio_share,
-            server_name=server_name,
-            inbrowser=True,
-            prevent_thread_lock=False,
-            show_error=True,
-            quiet=True,
-            favicon_path=None
-        )
-    except Exception as e:
-        print(f"Error launching web UI: {str(e)}")
-        raise
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
+    create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)


 def run_web_demo() -> None:
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
     create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)