implement webui resuming training
Former-commit-id: 2d41672ef52414c56c50c8b4fdc442797ba682e9
This commit is contained in:
@@ -27,6 +27,9 @@ class Runner:
|
||||
def __init__(self, manager: "Manager") -> None:
|
||||
self.manager = manager
|
||||
self.thread: "Thread" = None
|
||||
self.data: Dict["Component", Any] = None
|
||||
self.do_train = True
|
||||
self.monitor_inputs: Dict[str, str] = None
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
self.logger_handler = LoggerHandler()
|
||||
@@ -199,14 +202,15 @@ class Runner:
|
||||
yield gen_cmd(args), gr.update(visible=False)
|
||||
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self.prepare(data, self._parse_train_args)
|
||||
yield from self.prepare(data, do_train=True)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self.prepare(data, self._parse_eval_args)
|
||||
yield from self.prepare(data, do_train=False)
|
||||
|
||||
def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
parse_func = self._parse_train_args if is_training else self._parse_eval_args
|
||||
def prepare(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
parse_func = self._parse_train_args if do_train else self._parse_eval_args
|
||||
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
|
||||
self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
|
||||
error = self._initialize(lang, model_name, model_path, dataset)
|
||||
if error:
|
||||
yield error, gr.update(visible=False)
|
||||
@@ -215,9 +219,10 @@ class Runner:
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
yield from self.monitor(lang, output_dir, is_training)
|
||||
yield from self.monitor()
|
||||
|
||||
def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
@@ -225,7 +230,7 @@ class Runner:
|
||||
else:
|
||||
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
|
||||
|
||||
if is_training:
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user