web UI integrating RLHF

Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
This commit is contained in:
hiyouga
2023-08-14 10:48:47 +08:00
parent 4933ab5956
commit 688e8601ab
11 changed files with 128 additions and 32 deletions

View File

@@ -1,5 +1,5 @@
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab
from llmtuner.webui.components.train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab

View File

@@ -20,22 +20,25 @@ def create_top() -> Dict[str, "Component"]:
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(["", "8", "4"], scale=1)
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
source_prefix = gr.Textbox(scale=2)
lang.change(save_config, [lang, model_name, model_path])
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
) # do not save config since the below line will save
model_path.change(save_config, [model_name, model_path])
model_path.change(save_config, [lang, model_name, model_path])
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -43,7 +46,9 @@ def create_top() -> Dict[str, "Component"]:
can_quantize, [finetuning_type], [quantization_bit]
)
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
refresh_btn.click(
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
)
return dict(
lang=lang,

View File

@@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType
import gradio as gr
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
from llmtuner.webui.runner import Runner
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
@@ -40,7 +40,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
@@ -60,6 +60,20 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
lora_target = gr.Textbox(scale=2)
resume_lora_training = gr.Checkbox(value=True, scale=1)
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None", scale=1)
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
reward_model = gr.Dropdown(scale=2)
refresh_btn = gr.Button(scale=1)
refresh_btn.click(
list_checkpoint,
[top_elems["model_name"], top_elems["finetuning_type"]],
[reward_model],
queue=False
)
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
@@ -79,7 +93,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_list = [
input_components = [
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
@@ -108,16 +122,19 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
lora_dropout,
lora_target,
resume_lora_training,
rlhf_method,
dpo_beta,
reward_model,
output_dir
]
output_list = [
output_components = [
output_box,
process_bar
]
cmd_preview_btn.click(runner.preview_train, input_list, output_list)
start_btn.click(runner.run_train, input_list, output_list)
cmd_preview_btn.click(runner.preview_train, input_components, output_components)
start_btn.click(runner.run_train, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
process_bar.change(
@@ -152,6 +169,11 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
rlhf_tab=rlhf_tab,
rlhf_method=rlhf_method,
dpo_beta=dpo_beta,
reward_model=reward_model,
refresh_btn=refresh_btn,
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,