format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
@@ -1,24 +1,22 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import GeneratingArguments
|
||||
from .common import get_save_dir
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .manager import Manager
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: "Manager",
|
||||
demo_mode: Optional[bool] = False,
|
||||
lazy_init: Optional[bool] = True
|
||||
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
@@ -26,11 +24,12 @@ class WebChatModel(ChatModel):
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
|
||||
if not lazy_init: # read arguments from command line
|
||||
if not lazy_init: # read arguments from command line
|
||||
super().__init__()
|
||||
|
||||
if demo_mode: # load demo_config.json if exists
|
||||
if demo_mode: # load demo_config.json if exists
|
||||
import json
|
||||
|
||||
try:
|
||||
with open("demo_config.json", "r", encoding="utf-8") as f:
|
||||
args = json.load(f)
|
||||
@@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
|
||||
super().__init__(args)
|
||||
except AssertionError:
|
||||
print("Please provided model name and template in `demo_config.json`.")
|
||||
except:
|
||||
except Exception:
|
||||
print("Cannot find `demo_config.json` at current directory.")
|
||||
|
||||
@property
|
||||
@@ -64,9 +63,12 @@ class WebChatModel(ChatModel):
|
||||
return
|
||||
|
||||
if get("top.adapter_path"):
|
||||
adapter_name_or_path = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||
for adapter in get("top.adapter_path")])
|
||||
adapter_name_or_path = ",".join(
|
||||
[
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||
for adapter in get("top.adapter_path")
|
||||
]
|
||||
)
|
||||
else:
|
||||
adapter_name_or_path = None
|
||||
|
||||
@@ -79,7 +81,7 @@ class WebChatModel(ChatModel):
|
||||
template=get("top.template"),
|
||||
flash_attn=(get("top.booster") == "flash_attn"),
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
)
|
||||
super().__init__(args)
|
||||
|
||||
@@ -108,7 +110,7 @@ class WebChatModel(ChatModel):
|
||||
tools: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
|
||||
Reference in New Issue
Block a user