implement webui resuming training
Former-commit-id: 2d41672ef52414c56c50c8b4fdc442797ba682e9
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from llmtuner.webui.chatter import WebChatModel
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
|
||||
@@ -18,6 +18,9 @@ class Engine:
|
||||
self.runner: "Runner" = Runner(self.manager)
|
||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||
|
||||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
|
||||
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
lang = config.get("lang", None) or "en"
|
||||
|
||||
@@ -34,12 +37,14 @@ class Engine:
|
||||
resume_dict["top.model_name"] = {"value": config["last_model"]}
|
||||
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
|
||||
|
||||
yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
yield self._form_dict(resume_dict)
|
||||
|
||||
if self.runner.alive:
|
||||
pass # TODO: restore training
|
||||
if self.runner.alive: # TODO: resume eval
|
||||
yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
|
||||
resume_dict = {"train.resume_btn": {"value": True}}
|
||||
else:
|
||||
resume_dict = {"train.output_dir": {"value": get_time()}} # TODO: xxx
|
||||
resume_dict = {"train.output_dir": {"value": get_time()}}
|
||||
yield self._form_dict(resume_dict)
|
||||
|
||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user