modify code structure
Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
@@ -2,83 +2,26 @@
 # Implements user interface in browser for fine-tuned models.
 # Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

-import mdtex2html
 import gradio as gr

 from threading import Thread
-from utils import (
-    Template,
-    load_pretrained,
-    prepare_infer_args,
-    get_logits_processor
-)

 from transformers import TextIteratorStreamer
 from transformers.utils.versions import require_version

+from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor


 require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")


-model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
-model, tokenizer = load_pretrained(model_args, finetuning_args)
+model_args, data_args, finetuning_args, generating_args = get_infer_args()
+model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

 prompt_template = Template(data_args.prompt_template)
 source_prefix = data_args.source_prefix if data_args.source_prefix else ""


-def postprocess(self, y):
-    r"""
-    Overrides Chatbot.postprocess
-    """
-    if y is None:
-        return []
-    for i, (message, response) in enumerate(y):
-        y[i] = (
-            None if message is None else mdtex2html.convert((message)),
-            None if response is None else mdtex2html.convert(response),
-        )
-    return y
-
-
-gr.Chatbot.postprocess = postprocess
-
-
-def parse_text(text):  # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split("`")
-            if count % 2 == 1:
-                lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
-            else:
-                lines[i] = "<br /></code></pre>"
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", "\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br />" + line
-    text = "".join(lines)
-    return text
-
-
 def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
-    chatbot.append((parse_text(query), ""))
+    chatbot.append((query, ""))

     input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
@@ -102,7 +45,7 @@ def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
     for new_text in streamer:
         response += new_text
         new_history = history + [(query, response)]
-        chatbot[-1] = (parse_text(query), parse_text(response))
+        chatbot[-1] = (query, response)
         yield chatbot, new_history
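
For context on the hunk above: the elided body of predict presumably follows the standard transformers streaming pattern, with model.generate running on a background thread while TextIteratorStreamer yields decoded text to the loop shown. A minimal sketch under that assumption; the model name and gen_kwargs below are illustrative, not taken from this diff:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer(["Hello, my name is"], return_tensors="pt")["input_ids"]

# skip_prompt=True yields only newly generated text, not the echoed prompt;
# timeout guards the consuming loop against a stalled generate() call.
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread while
# the caller iterates over the streamer and receives partial text.
gen_kwargs = dict(input_ids=input_ids, max_new_tokens=32, streamer=streamer)
Thread(target=model.generate, kwargs=gen_kwargs).start()

response = ""
for new_text in streamer:  # iteration ends when generation completes
    response += new_text
print(response)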
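
One note on the removed helpers: parse_text entity-escapes the lines inside a ``` fence and turns the fence itself into a <pre><code> block, so the mdtex2html conversion hooked in through the (also removed) Chatbot.postprocess override renders code literally instead of re-interpreting it as Markdown. A quick check against the old revision, assuming parse_text is in scope:

html = parse_text("hi\n```python\nprint(1)\n```")
# -> 'hi<pre><code class="language-python"><br />print&#40;1&#41;<br /></code></pre>'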