format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
hiyouga
2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,24 +1,22 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS
if TYPE_CHECKING:
from .manager import Manager
class WebChatModel(ChatModel):
def __init__(
self,
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
) -> None:
self.manager = manager
self.demo_mode = demo_mode
@@ -26,11 +24,12 @@ class WebChatModel(ChatModel):
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init: # read arguments from command line
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode: # load demo_config.json if exists
if demo_mode: # load demo_config.json if exists
import json
try:
with open("demo_config.json", "r", encoding="utf-8") as f:
args = json.load(f)
@@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
super().__init__(args)
except AssertionError:
print("Please provided model name and template in `demo_config.json`.")
except:
except Exception:
print("Cannot find `demo_config.json` at current directory.")
@property
@@ -64,9 +63,12 @@ class WebChatModel(ChatModel):
return
if get("top.adapter_path"):
adapter_name_or_path = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")])
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
@@ -79,7 +81,7 @@ class WebChatModel(ChatModel):
template=get("top.template"),
flash_attn=(get("top.booster") == "flash_attn"),
use_unsloth=(get("top.booster") == "unsloth"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
)
super().__init__(args)
@@ -108,7 +110,7 @@ class WebChatModel(ChatModel):
tools: str,
max_new_tokens: int,
top_p: float,
temperature: float
temperature: float,
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
chatbot.append([query, ""])
response = ""