update webui and add CLIs

Former-commit-id: 1368dda22ab875914c9dd86ee5146a4f6a4736ad
This commit is contained in:
hiyouga
2024-05-03 02:58:23 +08:00
parent 2cedb59bee
commit ce8200ad98
65 changed files with 363 additions and 372 deletions

View File

@@ -1,10 +1,13 @@
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple
from yaml import safe_dump
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import smooth
from ..extras.ploting import gen_loss_plot
from .locales import ALERTS
@@ -12,30 +15,6 @@ if is_gradio_available():
import gradio as gr
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
if TYPE_CHECKING:
from ..extras.callbacks import LogCallback
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
if not callback.max_steps:
return gr.Slider(visible=False)
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
)
return gr.Slider(label=label, value=percentage, visible=True)
def get_time() -> str:
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora":
return gr.Dropdown(value="none", interactive=False)
@@ -57,14 +36,19 @@ def check_json_schema(text: str, lang: str) -> None:
gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
args.pop("disable_tqdm", None)
args["plot_loss"] = args.get("do_train", None)
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
for k, v in args.items():
if v is not None and v is not False and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v)))
for k, v in clean_cmd(args).items():
cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_text = "\\\n".join(cmd_lines)
cmd_text = "```bash\n{}\n```".format(cmd_text)
return cmd_text
@@ -76,29 +60,49 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
log_file = os.path.join(output_path, "trainer_log.jsonl")
if not os.path.isfile(log_file) or not is_matplotlib_available():
return
def get_time() -> str:
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
log_info: Dict[str, Any] = json.loads(line)
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0:
return
def get_trainer_info(output_path: os.PathLike) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, "r", encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, "r", encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
if len(trainer_log) != 0:
latest_log = trainer_log[-1]
percentage = latest_log["percentage"]
label = "Running {:d}/{:d}: {} < {}".format(
latest_log["current_steps"],
latest_log["total_steps"],
latest_log["elapsed_time"],
latest_log["remaining_time"],
)
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
def save_cmd(args: Dict[str, Any]) -> str:
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINER_CONFIG)