fix bug in latest gradio

Former-commit-id: 44a962862b4a74e50ef5786c8d5719faaa65f63f
This commit is contained in:
hiyouga
2024-04-04 00:55:31 +08:00
parent 43d134ba29
commit b1986a06b9
8 changed files with 111 additions and 204 deletions

View File

@@ -66,7 +66,7 @@ def check_dependencies() -> None:
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
require_version("gradio>=4.0.0", "To fix: pip install gradio>=4.0.0")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

View File

@@ -21,8 +21,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
@@ -50,7 +48,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn = gr.Button(variant="stop")
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
resume_btn = gr.Checkbox(visible=False, interactive=False)
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
@@ -73,4 +71,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
return elem_dict

View File

@@ -6,7 +6,6 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..utils import gen_plot
if TYPE_CHECKING:
@@ -24,7 +23,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True)
dataset = gr.Dropdown(multiselect=True, scale=4, allow_custom_value=True)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset})
@@ -121,8 +120,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as freeze_tab:
with gr.Row():
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
name_module_trainable = gr.Textbox(value="all", scale=3)
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
name_module_trainable = gr.Textbox(value="all")
input_elems.update({num_layer_trainable, name_module_trainable})
elem_dict.update(
@@ -140,8 +139,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox()
with gr.Row():
use_rslora = gr.Checkbox(scale=1)
use_dora = gr.Checkbox(scale=1)
with gr.Column(scale=1):
use_rslora = gr.Checkbox()
use_dora = gr.Checkbox()
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
@@ -175,10 +176,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
elem_dict.update(
@@ -187,11 +188,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as galore_tab:
with gr.Row():
use_galore = gr.Checkbox(scale=1)
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
galore_target = gr.Textbox(value="all", scale=3)
use_galore = gr.Checkbox()
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
elem_dict.update(
@@ -228,29 +229,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems),
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
@@ -267,15 +245,27 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)
output_box.change(
gen_plot,
[
engine.manager.get_elem_by_id("top.model_name"),
engine.manager.get_elem_by_id("top.finetuning_type"),
output_dir,
],
loss_viewer,
queue=False,
input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar, loss_viewer]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems) + [output_box],
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
return elem_dict

View File

@@ -1344,6 +1344,11 @@ ALERTS = {
"ru": "Аргументы были сохранены по адресу: ",
"zh": "训练参数已保存至:",
},
"info_config_loaded": {
"en": "Arguments have been restored.",
"ru": "Аргументы были восстановлены.",
"zh": "训练参数已载入。",
},
"info_loading": {
"en": "Loading model...",
"ru": "Загрузка модели...",

View File

@@ -2,7 +2,7 @@ import logging
import os
import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator
import gradio as gr
import transformers
@@ -17,7 +17,7 @@ from ..extras.misc import get_device_count, torch_gc
from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
if TYPE_CHECKING:
@@ -239,20 +239,22 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield error, gr.Slider(visible=False)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield gen_cmd(args), gr.Slider(visible=False)
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield error, gr.Slider(visible=False)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
@@ -261,54 +263,80 @@ class Runner:
self.thread.start()
yield from self.monitor()
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=True)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def monitor(self) -> Generator[Dict[Component, Any], None, None]:
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(
get("top.model_name"),
get("top.finetuning_type"),
get("{}.output_dir".format("train" if self.do_train else "eval")),
)
model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type")
output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
output_path = get_save_dir(model_name, finetuning_type, output_dir)
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
while self.thread is not None and self.thread.is_alive():
if self.aborted:
yield ALERTS["info_aborting"][lang], gr.Slider(visible=False)
yield {
output_box: ALERTS["info_aborting"][lang],
process_bar: gr.Slider(visible=False),
}
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
return_dict = {
output_box: self.logger_handler.log,
process_bar: update_process_bar(self.trainer_callback),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
yield return_dict
time.sleep(2)
if self.do_train:
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
else:
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
if os.path.exists(os.path.join(output_path, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.Slider(visible=False)
return_dict = {
output_box: self._finalize(lang, finish_info),
process_bar: gr.Slider(visible=False),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
yield return_dict
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
gr.Warning(error)
return error, gr.Slider(visible=False)
return {output_box: error}
config_dict: Dict[str, Any] = {}
lang = data[self.manager.get_elem_by_id("top.lang")]
@@ -320,15 +348,16 @@ class Runner:
config_dict[elem_id] = value
save_path = save_args(config_path, config_dict)
return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(config_path)
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {self.manager.get_elem_by_id("top.lang"): lang}
return {output_box: ALERTS["err_config_not_found"][lang]}
output_dict: Dict["Component", Any] = {}
output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value

View File

@@ -1,13 +1,12 @@
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional
import gradio as gr
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir
from .locales import ALERTS
@@ -36,7 +35,7 @@ def get_time() -> str:
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora":
return gr.Dropdown(value="None", interactive=False)
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
@@ -74,11 +73,9 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
if not base_model:
return
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
log_file = os.path.join(output_path, "trainer_log.jsonl")
if not os.path.isfile(log_file) or not is_matplotlib_available():
return
plt.close("all")
@@ -88,13 +85,13 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplot
steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
log_info = json.loads(line)
log_info: Dict[str, Any] = json.loads(line)
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0:
return None
return
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")