fix eval resuming in webui

Former-commit-id: b28b53cd06777f213ef7b925a914ff5fd357ade1
This commit is contained in:
hiyouga
2023-10-15 15:45:38 +08:00
parent 7070f3969d
commit 68330eab2a
5 changed files with 27 additions and 22 deletions

View File

@@ -185,29 +185,16 @@ class Runner:
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, model_path, dataset, _, args = self._parse_train_args(data)
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
parse_func = self._parse_train_args if do_train else self._parse_eval_args
lang, model_name, model_path, dataset, _, args = parse_func(data)
error = self._initialize(lang, model_name, model_path, dataset)
if error:
yield error, gr.update(visible=False)
else:
yield gen_cmd(args), gr.update(visible=False)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, model_path, dataset, _, args = self._parse_eval_args(data)
error = self._initialize(lang, model_name, model_path, dataset)
if error:
yield error, gr.update(visible=False)
else:
yield gen_cmd(args), gr.update(visible=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self.prepare(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self.prepare(data, do_train=False)
def prepare(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
parse_func = self._parse_train_args if do_train else self._parse_eval_args
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
@@ -221,6 +208,18 @@ class Runner:
self.thread.start()
yield from self.monitor()
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self._preview(data, do_train=True)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self._preview(data, do_train=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self._launch(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
while self.thread.is_alive():