modity code structure

Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
This commit is contained in:
hiyouga
2023-07-15 16:54:28 +08:00
parent fa06b168ab
commit 6261fb362a
57 changed files with 1999 additions and 1816 deletions

View File

@@ -2,83 +2,26 @@
# Implements user interface in browser for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
import mdtex2html
import gradio as gr
from threading import Thread
from utils import (
Template,
load_pretrained,
prepare_infer_args,
get_logits_processor
)
from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version
from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)
model_args, data_args, finetuning_args, generating_args = get_infer_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
prompt_template = Template(data_args.prompt_template)
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
def postprocess(self, y):
r"""
Overrides Chatbot.postprocess
"""
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
else:
lines[i] = "<br /></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br />" + line
text = "".join(lines)
return text
def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
chatbot.append((parse_text(query), ""))
chatbot.append((query, ""))
input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
@@ -102,7 +45,7 @@ def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
for new_text in streamer:
response += new_text
new_history = history + [(query, response)]
chatbot[-1] = (parse_text(query), parse_text(response))
chatbot[-1] = (query, response)
yield chatbot, new_history