Mirror of https://github.com/hiyouga/LlamaFactory.git, synced 2026-02-02 08:33:38 +00:00.
[model] support audio (#6701)
* support qwen2_audio
* improve code
* lint
* fix
* fix
* fix

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
@@ -24,7 +24,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from trl import PreTrainedModelWrapper

     from ..data import Template
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -81,9 +81,10 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
-        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
         if images is not None:
             mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
@@ -94,14 +95,25 @@ class HuggingfaceEngine(BaseEngine):
             if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
         messages = template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
         prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
         prompt_ids, _ = template.mm_plugin.process_token_ids(
-            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
+            prompt_ids,
+            None,
+            mm_input_dict["images"],
+            mm_input_dict["videos"],
+            mm_input_dict["audios"],
+            tokenizer,
+            processor,
         )
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
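In prose: when audio clips are supplied but no message mentions the audio placeholder, the engine prepends one placeholder per clip to the first message, exactly as it already did for images and videos. A minimal standalone sketch of that pattern; the "<audio>" literal is an assumption here, since the engine reads it from ..extras.constants.AUDIO_PLACEHOLDER:

from typing import Any, Dict, List, Sequence

AUDIO_PLACEHOLDER = "<audio>"  # assumed literal; the real value comes from extras.constants


def prepend_audio_placeholders(
    messages: List[Dict[str, str]], audios: Sequence[Any]
) -> List[Dict[str, str]]:
    # Mirrors the hunk above: if no message already contains the audio
    # placeholder, prepend one per clip so the mm_plugin can later expand
    # them into the model's audio tokens.
    if audios and not any(AUDIO_PLACEHOLDER in m["content"] for m in messages):
        messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
    return messages


msgs = [{"role": "user", "content": "Transcribe this clip."}]
prepend_audio_placeholders(msgs, audios=["clip.wav"])
print(msgs[0]["content"])  # -> "<audio>Transcribe this clip."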
@@ -184,6 +196,9 @@ class HuggingfaceEngine(BaseEngine):
         if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
             gen_kwargs["input_ids"] = inputs
             gen_kwargs["tokenizer"] = tokenizer
+            if "audio_feature_lens" in mm_inputs:
+                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
+
             gen_kwargs.pop("image_sizes", None)

         return gen_kwargs, prompt_length
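This branch reads as a per-model-family adapter: MiniCPM-V/-o models take input_ids and the tokenizer directly in generate(), receive audio_feature_lens when the processor produced it, and do not accept image_sizes. A sketch of the same logic as a standalone helper; the helper name is mine, and mm_inputs stands in for whatever the template's mm_plugin returned:

from typing import Any, Dict


def adapt_gen_kwargs_for_minicpm(
    gen_kwargs: Dict[str, Any],
    mm_inputs: Dict[str, Any],
    inputs: Any,
    tokenizer: Any,
) -> Dict[str, Any]:
    gen_kwargs["input_ids"] = inputs
    gen_kwargs["tokenizer"] = tokenizer
    if "audio_feature_lens" in mm_inputs:
        # Per-clip feature lengths, presumably used by the model to unpad
        # batched audio features before attending over them.
        gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
    gen_kwargs.pop("image_sizes", None)  # not understood by these models
    return gen_kwargs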
@@ -201,6 +216,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -214,6 +230,7 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
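For orientation, everything _process_args assembles ends in this single model.generate(**gen_kwargs) call. The commit's headline target is qwen2_audio, and the same prompt-plus-audio generation can be reproduced directly with transformers. A rough sketch, assuming the Qwen/Qwen2-Audio-7B-Instruct checkpoint and a local sample.wav, both illustrative:

import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

model_id = "Qwen/Qwen2-Audio-7B-Instruct"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_id, device_map="auto")

# The chat template expands the audio entry into the model's audio tokens,
# much as the engine's mm_plugin does with AUDIO_PLACEHOLDER above.
conversation = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": "sample.wav"},  # illustrative path
        {"type": "text", "text": "Transcribe this clip."},
    ]},
]
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

# Load the waveform at the rate the feature extractor expects.
audio, _ = librosa.load("sample.wav", sr=processor.feature_extractor.sampling_rate)

inputs = processor(text=text, audios=[audio], return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])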
@@ -252,6 +269,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -265,6 +283,7 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         streamer = TextIteratorStreamer(
@@ -312,6 +331,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -329,6 +349,7 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         async with self.semaphore:
@@ -343,6 +364,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -360,6 +382,7 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         async with self.semaphore:
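Taken together, the hunks thread audios through every public entry point (chat, stream_chat, and their async counterparts). A sketch of how a caller might exercise the new parameter through the high-level ChatModel wrapper; the model path, the template name, and the assumption that an AudioInput can be a file path are all illustrative:

from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "Qwen/Qwen2-Audio-7B-Instruct",  # illustrative
    "template": "qwen2_audio",                             # assumed template name
})
messages = [{"role": "user", "content": "<audio>What is said in this clip?"}]
# audios is the parameter added by this commit; passing a path assumes the
# mm_plugin accepts path-like AudioInput values.
responses = chat_model.chat(messages, audios=["sample.wav"])
print(responses[0].response_text)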