finish agent
Former-commit-id: d8d9d3afe32725fe79120fcd1a0970fdcdc45625
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import GeneratingArguments
|
||||
from .common import get_save_dir
|
||||
@@ -105,22 +107,37 @@ class WebChatModel(ChatModel):
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
messages: Sequence[Tuple[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
query_messages = messages + [{"role": Role.USER, "content": query}]
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, self.postprocess(response)]
|
||||
yield chatbot, new_history
|
||||
if tools:
|
||||
result = self.template.format_tools.extract(response)
|
||||
else:
|
||||
result = response
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
arguments = json.loads(arguments)
|
||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
|
||||
bot_text = "```json\n" + tool_call + "\n```"
|
||||
else:
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||
yield chatbot, output_messages
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
|
||||
Reference in New Issue
Block a user