upgrade gradio to 4.21.0
Former-commit-id: 63eecbeb967d849e1d03d8d03fb6421c0ee89257
@@ -48,8 +48,8 @@ class Runner:
     def set_abort(self) -> None:
         self.aborted = True
 
-    def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
-        get = lambda name: data[self.manager.get_elem_by_name(name)]
+    def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
+        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         dataset = get("train.dataset") if do_train else get("eval.dataset")
 
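The new signatures also quote the Component type (Dict["Component", Any]), turning it into a forward reference. A minimal sketch of the pattern this allows, assuming the gradio import is only wanted by static type checkers; the handler name below is illustrative:

from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
    # Needed only by static type checkers; skipped at runtime.
    from gradio.components import Component


def collect(data: Dict["Component", Any]) -> None:
    # The quoted "Component" is a forward reference, so this function can be
    # defined even though Component was not imported at runtime.
    ...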
@@ -95,8 +95,8 @@ class Runner:
         else:
             return finish_info
 
-    def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
-        get = lambda name: data[self.manager.get_elem_by_name(name)]
+    def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         user_config = load_config()
 
         if get("top.adapter_path"):
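Throughout the class, get_elem_by_name(name) becomes get_elem_by_id(elem_id), with components addressed by dotted ids such as top.model_path or train.dataset. A rough sketch of the registry shape this implies; the real Manager in the webui keeps more state, and the add_elems helper here is hypothetical:

from typing import Any, Dict


class Manager:
    def __init__(self) -> None:
        # Maps "tab.name" ids to the Gradio components created for the UI.
        self._id_to_elem: Dict[str, Any] = {}

    def add_elems(self, tab_name: str, elems: Dict[str, Any]) -> None:
        # Hypothetical registration helper: prefix each component with its tab.
        for elem_name, elem in elems.items():
            self._id_to_elem["{}.{}".format(tab_name, elem_name)] = elem

    def get_elem_by_id(self, elem_id: str) -> Any:
        return self._id_to_elem[elem_id]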
@@ -196,8 +196,8 @@ class Runner:
 
         return args
 
-    def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
-        get = lambda name: data[self.manager.get_elem_by_name(name)]
+    def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         user_config = load_config()
 
         if get("top.adapter_path"):
@@ -232,6 +232,7 @@ class Runner:
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
         )
+        args["disable_tqdm"] = True
 
         if get("eval.predict"):
             args["do_predict"] = True
@@ -240,22 +241,20 @@ class Runner:
 
         return args
 
-    def _preview(
-        self, data: Dict[Component, Any], do_train: bool
-    ) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
         error = self._initialize(data, do_train, from_preview=True)
         if error:
             gr.Warning(error)
-            yield error, gr.update(visible=False)
+            yield error, gr.Slider(visible=False)
         else:
             args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
-            yield gen_cmd(args), gr.update(visible=False)
+            yield gen_cmd(args), gr.Slider(visible=False)
 
-    def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
         error = self._initialize(data, do_train, from_preview=False)
         if error:
             gr.Warning(error)
-            yield error, gr.update(visible=False)
+            yield error, gr.Slider(visible=False)
         else:
             args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
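In this hunk gr.update(visible=False) becomes gr.Slider(visible=False): with Gradio 4.x an event handler updates an output component by returning or yielding a component constructor carrying the new properties. Below is a small, self-contained sketch of the same streaming shape; the component names and the demo wiring are illustrative, not the webui's actual layout:

import time

import gradio as gr


def stream_logs():
    # A generator event handler: every yield pushes a new (text, slider) pair
    # to the two outputs, mirroring the (str, gr.Slider) tuples yielded above.
    for step in range(1, 4):
        yield "step {}/3".format(step), gr.Slider(value=step, maximum=3, visible=True)
        time.sleep(1)
    yield "done", gr.Slider(visible=False)


with gr.Blocks() as demo:
    log_box = gr.Textbox(label="log")
    progress = gr.Slider(visible=False)
    gr.Button("Run").click(stream_logs, outputs=[log_box, progress])

demo.launch()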
@@ -264,20 +263,20 @@ class Runner:
             self.thread.start()
             yield from self.monitor()
 
-    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
         yield from self._preview(data, do_train=True)
 
-    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
         yield from self._preview(data, do_train=False)
 
-    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
         yield from self._launch(data, do_train=True)
 
-    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
         yield from self._launch(data, do_train=False)
 
-    def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
+    def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]:
+        get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
         self.running = True
         lang = get("top.lang")
         output_dir = get_save_dir(
@@ -286,13 +285,14 @@ class Runner:
             get("{}.output_dir".format("train" if self.do_train else "eval")),
         )
 
-        while self.thread.is_alive():
-            time.sleep(2)
+        while self.thread is not None and self.thread.is_alive():
             if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+                yield ALERTS["info_aborting"][lang], gr.Slider(visible=False)
             else:
                 yield self.logger_handler.log, update_process_bar(self.trainer_callback)
 
+            time.sleep(2)
+
         if self.do_train:
             if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
                 finish_info = ALERTS["info_finished"][lang]
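The monitor loop now guards against self.thread being None, yields a status update first and sleeps afterwards, so a launch that never started cannot raise AttributeError and the first update is not held back by the sleep. A stripped-down sketch of the same polling shape with a dummy worker; all names here are illustrative:

import threading
import time


def monitor(thread, poll_interval=2.0):
    # Tolerate thread being None (nothing was launched) and emit a status
    # line roughly every poll_interval seconds while the worker is alive.
    while thread is not None and thread.is_alive():
        yield "running..."
        time.sleep(poll_interval)
    yield "finished"


worker = threading.Thread(target=time.sleep, args=(5,))
worker.start()
for status in monitor(worker):
    print(status)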
@@ -304,4 +304,4 @@ class Runner:
             else:
                 finish_info = ALERTS["err_failed"][lang]
 
-        yield self._finalize(lang, finish_info), gr.update(visible=False)
+        yield self._finalize(lang, finish_info), gr.Slider(visible=False)