support mllm hf inference

Former-commit-id: 2c7c01282acd7ddabbb17ce3246b8dae4bc4b8cf
This commit is contained in:
hiyouga
2024-04-26 05:34:58 +08:00
parent 10a6c395bb
commit 23b881bff1
23 changed files with 128 additions and 49 deletions

View File

@@ -2,6 +2,8 @@ import json
import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
from numpy.typing import NDArray
from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
@@ -112,6 +114,7 @@ class WebChatModel(ChatModel):
messages: Sequence[Dict[str, str]],
system: str,
tools: str,
image: Optional[NDArray],
max_new_tokens: int,
top_p: float,
temperature: float,
@@ -119,7 +122,7 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = ""
response = ""
for new_text in self.stream_chat(
messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
if tools:

View File

@@ -23,9 +23,15 @@ def create_chat_box(
messages = gr.State([])
with gr.Row():
with gr.Column(scale=4):
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=2)
with gr.Row():
with gr.Column():
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=4)
with gr.Column():
image = gr.Image(type="numpy")
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
@@ -43,7 +49,7 @@ def create_chat_box(
[chatbot, messages, query],
).then(
engine.chatter.stream,
[chatbot, messages, system, tools, max_new_tokens, top_p, temperature],
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
@@ -56,6 +62,7 @@ def create_chat_box(
role=role,
system=system,
tools=tools,
image=image,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,

View File

@@ -1073,6 +1073,17 @@ LOCALES = {
"placeholder": "工具列表(非必填)",
},
},
"image": {
"en": {
"label": "Image (optional)",
},
"ru": {
"label": "Изображение (по желанию)",
},
"zh": {
"label": "图像(非必填)",
},
},
"query": {
"en": {
"placeholder": "Input...",