[model] support audio (#6701)

* support qwen2_audio

* improve code

* lint

* fix

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
Zhangchi Feng
2025-02-05 04:59:09 +08:00
committed by GitHub
parent 9feb78e7b4
commit 8f401e37f8
35 changed files with 675 additions and 213 deletions

View File

@@ -24,7 +24,7 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from trl import PreTrainedModelWrapper
from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -81,9 +81,10 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
@@ -94,14 +95,25 @@ class HuggingfaceEngine(BaseEngine):
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None:
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
prompt_ids, _ = template.mm_plugin.process_token_ids(
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
prompt_ids,
None,
mm_input_dict["images"],
mm_input_dict["videos"],
mm_input_dict["audios"],
tokenizer,
processor,
)
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
@@ -184,6 +196,9 @@ class HuggingfaceEngine(BaseEngine):
if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
gen_kwargs["input_ids"] = inputs
gen_kwargs["tokenizer"] = tokenizer
if "audio_feature_lens" in mm_inputs:
gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
gen_kwargs.pop("image_sizes", None)
return gen_kwargs, prompt_length
@@ -201,6 +216,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -214,6 +230,7 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
@@ -252,6 +269,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -265,6 +283,7 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
streamer = TextIteratorStreamer(
@@ -312,6 +331,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -329,6 +349,7 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
async with self.semaphore:
@@ -343,6 +364,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -360,6 +382,7 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
async with self.semaphore: