finish agent
Former-commit-id: d8d9d3afe32725fe79120fcd1a0970fdcdc45625
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import GeneratingArguments
|
||||
from .common import get_save_dir
|
||||
@@ -105,22 +107,37 @@ class WebChatModel(ChatModel):
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
messages: Sequence[Tuple[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
query_messages = messages + [{"role": Role.USER, "content": query}]
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, self.postprocess(response)]
|
||||
yield chatbot, new_history
|
||||
if tools:
|
||||
result = self.template.format_tools.extract(response)
|
||||
else:
|
||||
result = response
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
arguments = json.loads(arguments)
|
||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
|
||||
bot_text = "```json\n" + tool_call + "\n```"
|
||||
else:
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||
yield chatbot, output_messages
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
|
||||
@@ -17,7 +17,7 @@ def create_chat_box(
|
||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
history = gr.State([])
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
system = gr.Textbox(show_label=False)
|
||||
@@ -32,21 +32,21 @@ def create_chat_box(
|
||||
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
||||
|
||||
tools.input(check_json_schema, [tools])
|
||||
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
|
||||
|
||||
submit_btn.click(
|
||||
engine.chatter.predict,
|
||||
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
[chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
show_progress=True,
|
||||
).then(lambda: gr.update(value=""), outputs=[query])
|
||||
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
|
||||
|
||||
return (
|
||||
chat_box,
|
||||
chatbot,
|
||||
history,
|
||||
messages,
|
||||
dict(
|
||||
system=system,
|
||||
tools=tools,
|
||||
|
||||
@@ -208,6 +208,8 @@ ALERTS = {
|
||||
"zh": "展示模式不支持训练,请先复制到私人空间。",
|
||||
},
|
||||
"err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
|
||||
"err_tool_name": {"en": "Tool name not found.", "zh": "工具名称未找到。"},
|
||||
"err_json_schema": {"en": "Invalid JSON schema.", "zh": "Json 格式错误。"},
|
||||
"info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
|
||||
"info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
|
||||
"info_finished": {"en": "Finished.", "zh": "训练完毕。"},
|
||||
|
||||
@@ -8,6 +8,7 @@ import gradio as gr
|
||||
from ..extras.packages import is_matplotlib_available
|
||||
from ..extras.ploting import smooth
|
||||
from .common import get_save_dir
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -40,11 +41,15 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
return gr.update(interactive=True)
|
||||
|
||||
|
||||
def check_json_schema(text: str) -> None:
|
||||
def check_json_schema(text: str, lang: str) -> None:
|
||||
try:
|
||||
json.loads(text)
|
||||
tools = json.loads(text)
|
||||
for tool in tools:
|
||||
assert "name" in tool
|
||||
except AssertionError:
|
||||
gr.Warning(ALERTS["err_tool_name"][lang])
|
||||
except json.JSONDecodeError:
|
||||
gr.Warning("Invalid JSON schema")
|
||||
gr.Warning(ALERTS["err_json_schema"][lang])
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
|
||||
Reference in New Issue
Block a user