[model] support audio (#6701)
* support qwen2_audio * improve code * lint * fix * fix * fix --------- Co-authored-by: hiyouga <hiyouga@buaa.edu.cn> Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
@@ -172,6 +172,7 @@ class WebChatModel(ChatModel):
|
||||
tools: str,
|
||||
image: Optional[Any],
|
||||
video: Optional[Any],
|
||||
audio: Optional[Any],
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
@@ -190,6 +191,7 @@ class WebChatModel(ChatModel):
|
||||
tools,
|
||||
images=[image] if image else None,
|
||||
videos=[video] if video else None,
|
||||
audios=[audio] if audio else None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
|
||||
@@ -26,9 +26,9 @@ from ..extras import logging
|
||||
from ..extras.constants import (
|
||||
DATA_CONFIG,
|
||||
DEFAULT_TEMPLATE,
|
||||
MULTIMODAL_SUPPORTED_MODELS,
|
||||
SUPPORTED_MODELS,
|
||||
TRAINING_ARGS,
|
||||
VISION_MODELS,
|
||||
DownloadSource,
|
||||
)
|
||||
from ..extras.misc import use_modelscope, use_openmind
|
||||
@@ -136,13 +136,6 @@ def get_template(model_name: str) -> str:
|
||||
return DEFAULT_TEMPLATE.get(model_name, "default")
|
||||
|
||||
|
||||
def get_visual(model_name: str) -> bool:
|
||||
r"""
|
||||
Judges if the model is a vision language model.
|
||||
"""
|
||||
return model_name in VISION_MODELS
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
r"""
|
||||
Gets current date and time.
|
||||
@@ -150,6 +143,13 @@ def get_time() -> str:
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
|
||||
def is_multimodal(model_name: str) -> bool:
|
||||
r"""
|
||||
Judges if the model is a multimodal model (i.e., listed in MULTIMODAL_SUPPORTED_MODELS).
|
||||
"""
|
||||
return model_name in MULTIMODAL_SUPPORTED_MODELS
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
r"""
|
||||
Loads dataset_info.json.
|
||||
|
||||
@@ -64,10 +64,13 @@ def create_chat_box(
|
||||
|
||||
with gr.Column() as mm_box:
|
||||
with gr.Tab("Image"):
|
||||
image = gr.Image(sources=["upload"], type="pil")
|
||||
image = gr.Image(type="pil")
|
||||
|
||||
with gr.Tab("Video"):
|
||||
video = gr.Video(sources=["upload"])
|
||||
video = gr.Video()
|
||||
|
||||
with gr.Tab("Audio"):
|
||||
audio = gr.Audio(type="filepath")
|
||||
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
@@ -86,7 +89,7 @@ def create_chat_box(
|
||||
[chatbot, messages, query],
|
||||
).then(
|
||||
engine.chatter.stream,
|
||||
[chatbot, messages, lang, system, tools, image, video, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages, lang, system, tools, image, video, audio, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
)
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
|
||||
@@ -102,6 +105,7 @@ def create_chat_box(
|
||||
mm_box=mm_box,
|
||||
image=image,
|
||||
video=video,
|
||||
audio=audio,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
max_new_tokens=max_new_tokens,
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import get_visual
|
||||
from ..common import is_multimodal
|
||||
from .chatbot import create_chat_box
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
|
||||
|
||||
engine.manager.get_elem_by_id("top.model_name").change(
|
||||
lambda model_name: gr.Column(visible=get_visual(model_name)),
|
||||
lambda model_name: gr.Column(visible=is_multimodal(model_name)),
|
||||
[engine.manager.get_elem_by_id("top.model_name")],
|
||||
[chat_elems["mm_box"]],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user