rename package

Former-commit-id: a07ff0c083558cfe6f474d13027642d3052fee08
This commit is contained in:
hiyouga
2024-05-16 18:39:08 +08:00
parent fe638cf11f
commit dfa686b617
109 changed files with 31 additions and 31 deletions

View File

@@ -0,0 +1,16 @@
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .infer import create_infer_tab
from .top import create_top
from .train import create_train_tab
__all__ = [
"create_chat_box",
"create_eval_tab",
"create_export_tab",
"create_infer_tab",
"create_top",
"create_train_tab",
]

View File

@@ -0,0 +1,74 @@
from typing import TYPE_CHECKING, Dict, Tuple
from ...data import Role
from ...extras.packages import is_gradio_available
from ..utils import check_json_schema
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(show_copy_button=True)
messages = gr.State([])
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
with gr.Column():
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=3)
with gr.Column() as image_box:
image = gr.Image(sources=["upload"], type="numpy")
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1):
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
clear_btn = gr.Button()
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
submit_btn.click(
engine.chatter.append,
[chatbot, messages, role, query],
[chatbot, messages, query],
).then(
engine.chatter.stream,
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
return (
chatbot,
messages,
dict(
chat_box=chat_box,
role=role,
system=system,
tools=tools,
image_box=image_box,
image=image,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
clear_btn=clear_btn,
),
)

View File

@@ -0,0 +1,106 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
PAGE_SIZE = 2
def prev_page(page_index: int) -> int:
return page_index - 1 if page_index > 0 else page_index
def next_page(page_index: int, total_num: int) -> int:
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
except Exception:
return gr.Button(interactive=False)
if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
return gr.Button(interactive=False)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
return gr.Button(interactive=True)
else:
return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> List[Any]:
with open(file_path, "r", encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
elif file_path.endswith(".jsonl"):
return [json.loads(line) for line in f]
else:
return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path):
data = _load_data_file(data_path)
else:
data = []
for file_name in os.listdir(data_path):
data.extend(_load_data_file(os.path.join(data_path, file_name)))
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(value=0, interactive=False, precision=0)
page_index = gr.Number(value=0, interactive=False, precision=0)
with gr.Row():
prev_btn = gr.Button()
next_btn = gr.Button()
close_btn = gr.Button()
with gr.Row():
preview_samples = gr.JSON()
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
lambda: 0, outputs=[page_index], queue=False
)
data_preview_btn.click(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
return dict(
data_preview_btn=data_preview_btn,
preview_count=preview_count,
page_index=page_index,
prev_btn=prev_btn,
next_btn=next_btn,
close_btn=close_btn,
preview_samples=preview_samples,
)

View File

@@ -0,0 +1,79 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
with gr.Row():
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button(variant="primary")
stop_btn = gr.Button(variant="stop")
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
output_box = gr.Markdown()
output_elems = [output_box, progress_bar]
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
progress_bar=progress_bar,
output_box=output_box,
)
)
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
return elem_dict

View File

@@ -0,0 +1,132 @@
from typing import TYPE_CHECKING, Dict, Generator, List
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import get_save_dir
from ..locales import ALERTS
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]
def save_model(
lang: str,
model_name: str,
model_path: str,
adapter_path: List[str],
finetuning_type: str,
template: str,
visual_inputs: bool,
export_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
export_device: str,
export_legacy_format: bool,
export_dir: str,
export_hub_model_id: str,
) -> Generator[str, None, None]:
error = ""
if not model_name:
error = ALERTS["err_no_model"][lang]
elif not model_path:
error = ALERTS["err_no_path"][lang]
elif not export_dir:
error = ALERTS["err_no_export_dir"][lang]
elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
error = ALERTS["err_no_dataset"][lang]
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
error = ALERTS["err_no_adapter"][lang]
elif export_quantization_bit in GPTQ_BITS and adapter_path:
error = ALERTS["err_gptq_lora"][lang]
if error:
gr.Warning(error)
yield error
return
if adapter_path:
adapter_name_or_path = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
else:
adapter_name_or_path = None
args = dict(
model_name_or_path=model_path,
adapter_name_or_path=adapter_name_or_path,
finetuning_type=finetuning_type,
template=template,
visual_inputs=visual_inputs,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None,
export_size=export_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
)
yield ALERTS["info_exporting"][lang]
export_model(args)
torch_gc()
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
export_legacy_format = gr.Checkbox()
with gr.Row():
export_dir = gr.Textbox()
export_hub_model_id = gr.Textbox()
export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
export_btn.click(
save_model,
[
engine.manager.get_elem_by_id("top.lang"),
engine.manager.get_elem_by_id("top.model_name"),
engine.manager.get_elem_by_id("top.model_path"),
engine.manager.get_elem_by_id("top.adapter_path"),
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.template"),
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
export_quantization_bit,
export_quantization_dataset,
export_device,
export_legacy_format,
export_dir,
export_hub_model_id,
],
[info_box],
)
return dict(
export_size=export_size,
export_quantization_bit=export_quantization_bit,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id,
export_btn=export_btn,
info_box=info_box,
)

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from .chatbot import create_chat_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
input_elems.update({infer_backend})
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
)
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, messages]
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
engine.manager.get_elem_by_id("top.visual_inputs").change(
lambda enabled: gr.Column(visible=enabled),
[engine.manager.get_elem_by_id("top.visual_inputs")],
[chat_elems["image_box"]],
)
return elem_dict

View File

@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING, Dict
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
from ..utils import can_quantize
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
visual_inputs = gr.Checkbox(scale=1)
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(get_template, [model_name], [template], queue=False).then(
get_visual, [model_name], [visual_inputs], queue=False
) # do not save config since the below line will save
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
can_quantize, [finetuning_type], [quantization_bit], queue=False
)
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
return dict(
lang=lang,
model_name=model_name,
model_path=model_path,
finetuning_type=finetuning_type,
adapter_path=adapter_path,
refresh_btn=refresh_btn,
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
rope_scaling=rope_scaling,
booster=booster,
visual_inputs=visual_inputs,
)

View File

@@ -0,0 +1,299 @@
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
training_stage = gr.Dropdown(
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_grad_norm = gr.Textbox(value="1.0")
max_samples = gr.Textbox(value="100000")
compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
elem_dict.update(
dict(
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_grad_norm=max_grad_norm,
max_samples=max_samples,
compute_type=compute_type,
)
)
with gr.Row():
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
elem_dict.update(
dict(
cutoff_len=cutoff_len,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
val_size=val_size,
lr_scheduler_type=lr_scheduler_type,
)
)
with gr.Accordion(open=False) as extra_tab:
with gr.Row():
logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
optim = gr.Textbox(value="adamw_torch")
with gr.Row():
with gr.Column():
resize_vocab = gr.Checkbox()
packing = gr.Checkbox()
with gr.Column():
upcast_layernorm = gr.Checkbox()
use_llama_pro = gr.Checkbox()
with gr.Column():
shift_attn = gr.Checkbox()
report_to = gr.Checkbox()
input_elems.update(
{
logging_steps,
save_steps,
warmup_steps,
neftune_alpha,
optim,
resize_vocab,
packing,
upcast_layernorm,
use_llama_pro,
shift_attn,
report_to,
}
)
elem_dict.update(
dict(
extra_tab=extra_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha,
optim=optim,
resize_vocab=resize_vocab,
packing=packing,
upcast_layernorm=upcast_layernorm,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,
report_to=report_to,
)
)
with gr.Accordion(open=False) as freeze_tab:
with gr.Row():
freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1)
freeze_trainable_modules = gr.Textbox(value="all")
freeze_extra_modules = gr.Textbox()
input_elems.update({freeze_trainable_layers, freeze_trainable_modules, freeze_extra_modules})
elem_dict.update(
dict(
freeze_tab=freeze_tab,
freeze_trainable_layers=freeze_trainable_layers,
freeze_trainable_modules=freeze_trainable_modules,
freeze_extra_modules=freeze_extra_modules,
)
)
with gr.Accordion(open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
create_new_adapter = gr.Checkbox()
with gr.Row():
with gr.Column(scale=1):
use_rslora = gr.Checkbox()
use_dora = gr.Checkbox()
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
input_elems.update(
{
lora_rank,
lora_alpha,
lora_dropout,
loraplus_lr_ratio,
create_new_adapter,
use_rslora,
use_dora,
lora_target,
additional_target,
}
)
elem_dict.update(
dict(
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
loraplus_lr_ratio=loraplus_lr_ratio,
create_new_adapter=create_new_adapter,
use_rslora=use_rslora,
use_dora=use_dora,
lora_target=lora_target,
additional_target=additional_target,
)
)
with gr.Accordion(open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
)
with gr.Accordion(open=False) as galore_tab:
with gr.Row():
use_galore = gr.Checkbox()
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
elem_dict.update(
dict(
galore_tab=galore_tab,
use_galore=use_galore,
galore_rank=galore_rank,
galore_update_interval=galore_update_interval,
galore_scale=galore_scale,
galore_target=galore_target,
)
)
with gr.Accordion(open=False) as badam_tab:
with gr.Row():
use_badam = gr.Checkbox()
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)
input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
elem_dict.update(
dict(
badam_tab=badam_tab,
use_badam=use_badam,
badam_mode=badam_mode,
badam_switch_mode=badam_switch_mode,
badam_switch_interval=badam_switch_interval,
badam_update_ratio=badam_update_ratio,
)
)
with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
arg_load_btn = gr.Button()
start_btn = gr.Button(variant="primary")
stop_btn = gr.Button(variant="stop")
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
output_dir = gr.Textbox()
config_path = gr.Textbox()
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
output_box = gr.Markdown()
with gr.Column(scale=1):
loss_viewer = gr.Plot()
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
arg_save_btn=arg_save_btn,
arg_load_btn=arg_load_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
config_path=config_path,
resume_btn=resume_btn,
progress_bar=progress_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)
input_elems.update({output_dir, config_path})
output_elems = [output_box, progress_bar, loss_viewer]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems) + [output_box],
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
return elem_dict