format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
hiyouga
2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,11 +1,16 @@
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .infer import create_infer_tab
from .top import create_top
from .train import create_train_tab
from .eval import create_eval_tab
from .infer import create_infer_tab
from .export import create_export_tab
from .chatbot import create_chat_box
__all__ = [
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
"create_top",
"create_train_tab",
"create_eval_tab",
"create_infer_tab",
"create_export_tab",
"create_chat_box",
]

View File

@@ -1,6 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from ..utils import check_json_schema
@@ -12,8 +13,7 @@ if TYPE_CHECKING:
def create_chat_box(
engine: "Engine",
visible: Optional[bool] = False
engine: "Engine", visible: Optional[bool] = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
@@ -38,20 +38,23 @@ def create_chat_box(
engine.chatter.predict,
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
lambda: gr.update(value=""), outputs=[query]
)
show_progress=True,
).then(lambda: gr.update(value=""), outputs=[query])
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
system=system,
tools=tools,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature
return (
chat_box,
chatbot,
history,
dict(
system=system,
tools=tools,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
),
)

View File

@@ -1,10 +1,12 @@
import os
import json
import gradio as gr
import os
from typing import TYPE_CHECKING, Any, Dict, Tuple
import gradio as gr
from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING:
from gradio.components import Component
@@ -24,7 +26,7 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
except:
except Exception:
return gr.update(interactive=False)
if (
@@ -48,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f]
else:
data = [line for line in f]
data = [line for line in f] # noqa: C416
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
@@ -67,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
with gr.Row():
preview_samples = gr.JSON(interactive=False)
dataset.change(
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
).then(
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
lambda: 0, outputs=[page_index], queue=False
)
data_preview_btn.click(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
prev_btn.click(
prev_page, [page_index], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
next_btn.click(
next_page, [page_index, preview_count], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return dict(
@@ -102,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
prev_btn=prev_btn,
next_btn=next_btn,
close_btn=close_btn,
preview_samples=preview_samples
preview_samples=preview_samples,
)

View File

@@ -1,9 +1,11 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from ..common import list_dataset, DEFAULT_DATA_DIR
import gradio as gr
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box
if TYPE_CHECKING:
from gradio.components import Component
@@ -31,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
))
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
with gr.Row():
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
@@ -42,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict(
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
))
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
with gr.Row():
cmd_preview_btn = gr.Button()
@@ -59,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_box = gr.Markdown()
output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
))
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
)
)
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
start_btn.click(engine.runner.run_eval, input_elems, output_elems)

View File

@@ -1,10 +1,12 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
import gradio as gr
from ...train import export_model
from ..common import get_save_dir
from ..locales import ALERTS
if TYPE_CHECKING:
from gradio.components import Component
@@ -24,7 +26,7 @@ def save_model(
max_shard_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
export_dir: str
export_dir: str,
) -> Generator[str, None, None]:
error = ""
if not model_name:
@@ -44,7 +46,9 @@ def save_model(
return
if adapter_path:
adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
adapter_name_or_path = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
else:
adapter_name_or_path = None
@@ -56,7 +60,7 @@ def save_model(
export_dir=export_dir,
export_size=max_shard_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset
export_quantization_dataset=export_quantization_dataset,
)
yield ALERTS["info_exporting"][lang]
@@ -86,9 +90,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
max_shard_size,
export_quantization_bit,
export_quantization_dataset,
export_dir
export_dir,
],
[info_box]
[info_box],
)
return dict(
@@ -97,5 +101,5 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_quantization_dataset=export_quantization_dataset,
export_dir=export_dir,
export_btn=export_btn,
info_box=info_box
info_box=info_box,
)

View File

@@ -1,8 +1,10 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
import gradio as gr
from .chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
@@ -23,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
load_btn.click(
engine.chatter.load_model, input_elems, [info_box]
).then(
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
unload_btn.click(
engine.chatter.unload_model, input_elems, [info_box]
).then(
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, history]
).then(
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
return elem_dict

View File

@@ -1,11 +1,13 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
import gradio as gr
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
@@ -30,25 +32,19 @@ def create_top() -> Dict[str, "Component"]:
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
model_name.change(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(
get_template, [model_name], [template], queue=False
) # do not save config since the below line will save
) # do not save config since the below line will save
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
can_quantize, [finetuning_type], [quantization_bit], queue=False
)
refresh_btn.click(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
)
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
return dict(
lang=lang,
@@ -61,5 +57,5 @@ def create_top() -> Dict[str, "Component"]:
quantization_bit=quantization_bit,
template=template,
rope_scaling=rope_scaling,
booster=booster
booster=booster,
)

View File

@@ -1,12 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
import gradio as gr
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..utils import gen_plot
if TYPE_CHECKING:
from gradio.components import Component
@@ -29,9 +31,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
))
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
@@ -41,25 +41,33 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
elem_dict.update(dict(
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
max_samples=max_samples, compute_type=compute_type
))
elem_dict.update(
dict(
cutoff_len=cutoff_len,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
compute_type=compute_type,
)
)
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
elem_dict.update(dict(
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
))
elem_dict.update(
dict(
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
val_size=val_size,
)
)
with gr.Accordion(label="Extra config", open=False) as extra_tab:
with gr.Row():
@@ -73,10 +81,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
upcast_layernorm = gr.Checkbox(value=False)
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
elem_dict.update(dict(
extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha, sft_packing=sft_packing, upcast_layernorm=upcast_layernorm
))
elem_dict.update(
dict(
extra_tab=extra_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha,
sft_packing=sft_packing,
upcast_layernorm=upcast_layernorm,
)
)
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
@@ -87,10 +102,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox(scale=1)
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
elem_dict.update(dict(
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
additional_target=additional_target, create_new_adapter=create_new_adapter
))
elem_dict.update(
dict(
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
additional_target=additional_target,
create_new_adapter=create_new_adapter,
)
)
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
@@ -103,13 +124,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
list_adapters,
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model],
queue=False
queue=False,
)
input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
))
elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
)
with gr.Row():
cmd_preview_btn = gr.Button()
@@ -139,20 +160,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems)
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
))
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)
output_box.change(
gen_plot,
[
engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.finetuning_type"),
output_dir
output_dir,
],
loss_viewer,
queue=False
queue=False,
)
return elem_dict