add datasets

Former-commit-id: 02e4b47dea1b25905c61f2ace88bab112610f021
This commit is contained in:
hiyouga
2023-07-19 20:59:15 +08:00
parent a02b3e6192
commit 6d881f161b
9 changed files with 39 additions and 82 deletions

View File

@@ -26,7 +26,7 @@ class ChatModel:
def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix if prefix else self.source_prefix
prefix = prefix or self.source_prefix
inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(self.model.device)
@@ -81,5 +81,4 @@ class ChatModel:
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()
for new_text in streamer:
yield new_text
yield from streamer

View File

@@ -46,7 +46,7 @@ class Template:
def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
prefix = prefix if prefix else self.prefix # use prefix if provided
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
history = history if (history and self.use_history) else []
history = history + [(query, "<dummy>")]

View File

@@ -24,17 +24,9 @@ def main():
manager = Manager([{"lang": lang}, chat_elems])
demo.load(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.change(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)