update webui and add CLIs
Former-commit-id: 1368dda22ab875914c9dd86ee5146a4f6a4736ad
This commit is contained in:
@@ -1,22 +1,19 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator
|
||||
import signal
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
|
||||
import transformers
|
||||
import psutil
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_cuda_available
|
||||
|
||||
from ..extras.callbacks import LogCallback
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
from ..extras.logging import LoggerHandler
|
||||
from ..extras.misc import get_device_count, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from ..train import run_exp
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
|
||||
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
@@ -34,24 +31,18 @@ class Runner:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.trainer: Optional["Popen"] = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
""" Handler """
|
||||
self.logger_handler = LoggerHandler()
|
||||
self.logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
return self.thread is not None
|
||||
|
||||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
if self.trainer is not None:
|
||||
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
|
||||
os.kill(children.pid, signal.SIGABRT)
|
||||
|
||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
@@ -85,13 +76,11 @@ class Runner:
|
||||
if not from_preview and not is_torch_cuda_available():
|
||||
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
||||
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
return ""
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
|
||||
self.thread = None
|
||||
self.trainer = None
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
self.running_data = None
|
||||
@@ -270,11 +259,12 @@ class Runner:
|
||||
gr.Warning(error)
|
||||
yield {output_box: error}
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
env = deepcopy(os.environ)
|
||||
env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
env["LLAMABOARD_ENABLED"] = "1"
|
||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data):
|
||||
@@ -291,9 +281,6 @@ class Runner:
|
||||
|
||||
def monitor(self):
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
|
||||
lang = get("top.lang")
|
||||
model_name = get("top.model_name")
|
||||
finetuning_type = get("top.finetuning_type")
|
||||
@@ -301,28 +288,31 @@ class Runner:
|
||||
output_path = get_save_dir(model_name, finetuning_type, output_dir)
|
||||
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
|
||||
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
|
||||
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
|
||||
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
|
||||
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
while self.trainer is not None:
|
||||
if self.aborted:
|
||||
yield {
|
||||
output_box: ALERTS["info_aborting"][lang],
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
else:
|
||||
running_log, running_progress, running_loss = get_trainer_info(output_path)
|
||||
return_dict = {
|
||||
output_box: self.logger_handler.log,
|
||||
process_bar: update_process_bar(self.trainer_callback),
|
||||
output_box: running_log,
|
||||
progress_bar: running_progress,
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
if self.do_train and running_loss is not None:
|
||||
return_dict[loss_viewer] = running_loss
|
||||
|
||||
yield return_dict
|
||||
|
||||
time.sleep(2)
|
||||
try:
|
||||
self.trainer.wait(2)
|
||||
self.trainer = None
|
||||
except TimeoutExpired:
|
||||
continue
|
||||
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
|
||||
@@ -337,16 +327,11 @@ class Runner:
|
||||
|
||||
return_dict = {
|
||||
output_box: self._finalize(lang, finish_info),
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data):
|
||||
def save_args(self, data: dict):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
|
||||
Reference in New Issue
Block a user