rename package

Former-commit-id: a07ff0c083558cfe6f474d13027642d3052fee08
This commit is contained in:
hiyouga
2024-05-16 18:39:08 +08:00
parent fe638cf11f
commit dfa686b617
109 changed files with 31 additions and 31 deletions

View File

View File

@@ -0,0 +1,145 @@
import json
import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
from numpy.typing import NDArray
from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir
from .locales import ALERTS
if TYPE_CHECKING:
from ..chat import BaseEngine
from .manager import Manager
if is_gradio_available():
import gradio as gr
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.engine: Optional["BaseEngine"] = None
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model
model_name_or_path = os.environ.get("DEMO_MODEL")
template = os.environ.get("DEMO_TEMPLATE")
infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
super().__init__(
dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
)
@property
def loaded(self) -> bool:
return self.engine is not None
def load_model(self, data) -> Generator[str, None, None]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang")
error = ""
if self.loaded:
error = ALERTS["err_exists"][lang]
elif not get("top.model_name"):
error = ALERTS["err_no_model"][lang]
elif not get("top.model_path"):
error = ALERTS["err_no_path"][lang]
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
if error:
gr.Warning(error)
yield error
return
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
)
super().__init__(args)
yield ALERTS["info_loaded"][lang]
def unload_model(self, data) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_id("top.lang")]
if self.demo_mode:
gr.Warning(ALERTS["err_demo"][lang])
yield ALERTS["err_demo"][lang]
return
yield ALERTS["info_unloading"][lang]
self.engine = None
torch_gc()
yield ALERTS["info_unloaded"][lang]
def append(
self,
chatbot: List[List[Optional[str]]],
messages: Sequence[Dict[str, str]],
role: str,
query: str,
) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
def stream(
self,
chatbot: List[List[Optional[str]]],
messages: Sequence[Dict[str, str]],
system: str,
tools: str,
image: Optional[NDArray],
max_new_tokens: int,
top_p: float,
temperature: float,
) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
chatbot[-1][1] = ""
response = ""
for new_text in self.stream_chat(
messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
if tools:
result = self.engine.template.format_tools.extract(response)
else:
result = response
if isinstance(result, tuple):
name, arguments = result
arguments = json.loads(arguments)
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
bot_text = "```json\n" + tool_call + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = result
chatbot[-1][1] = bot_text
yield chatbot, output_messages

View File

@@ -0,0 +1,155 @@
import json
import os
from collections import defaultdict
from typing import Any, Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from yaml import safe_dump, safe_load
from ..extras.constants import (
DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
PEFT_METHODS,
STAGES_USE_PAIR_DATA,
SUPPORTED_MODELS,
TRAINING_STAGES,
VISION_MODELS,
DownloadSource,
)
from ..extras.logging import get_logger
from ..extras.misc import use_modelscope
from ..extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
logger = get_logger(__name__)
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
def get_save_dir(*args) -> os.PathLike:
return os.path.join(DEFAULT_SAVE_DIR, *args)
def get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def get_save_path(config_path: str) -> os.PathLike:
return os.path.join(DEFAULT_CONFIG_DIR, config_path)
def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
if model_name:
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
safe_dump(user_config, f)
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
try:
with open(get_save_path(config_path), "r", encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
with open(get_save_path(config_path), "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
return str(get_save_path(config_path))
def get_model_path(model_name: str) -> str:
user_config = load_config()
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
if (
use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace path
model_path = path_dict.get(DownloadSource.MODELSCOPE)
return model_path
def get_prefix(model_name: str) -> str:
return model_name.split("-")[0]
def get_module(model_name: str) -> str:
return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
def get_template(model_name: str) -> str:
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"
def get_visual(model_name: str) -> bool:
return get_prefix(model_name) in VISION_MODELS
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value=[], choices=[], interactive=False)
adapters = []
if model_name and finetuning_type == "lora":
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for adapter in os.listdir(save_dir):
if os.path.isdir(os.path.join(save_dir, adapter)) and any(
os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
):
adapters.append(adapter)
return gr.Dropdown(value=[], choices=adapters, interactive=True)
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
if dataset_dir == "ONLINE":
logger.info("dataset_dir is ONLINE, using online dataset.")
return {}
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f)
except Exception as err:
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
return {}
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.Dropdown(value=[], choices=datasets)
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt"))

View File

@@ -0,0 +1,16 @@
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .infer import create_infer_tab
from .top import create_top
from .train import create_train_tab
__all__ = [
"create_chat_box",
"create_eval_tab",
"create_export_tab",
"create_infer_tab",
"create_top",
"create_train_tab",
]

View File

@@ -0,0 +1,74 @@
from typing import TYPE_CHECKING, Dict, Tuple
from ...data import Role
from ...extras.packages import is_gradio_available
from ..utils import check_json_schema
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(show_copy_button=True)
messages = gr.State([])
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
with gr.Column():
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=3)
with gr.Column() as image_box:
image = gr.Image(sources=["upload"], type="numpy")
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1):
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
clear_btn = gr.Button()
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
submit_btn.click(
engine.chatter.append,
[chatbot, messages, role, query],
[chatbot, messages, query],
).then(
engine.chatter.stream,
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
return (
chatbot,
messages,
dict(
chat_box=chat_box,
role=role,
system=system,
tools=tools,
image_box=image_box,
image=image,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
clear_btn=clear_btn,
),
)

View File

@@ -0,0 +1,106 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
PAGE_SIZE = 2
def prev_page(page_index: int) -> int:
return page_index - 1 if page_index > 0 else page_index
def next_page(page_index: int, total_num: int) -> int:
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
except Exception:
return gr.Button(interactive=False)
if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
return gr.Button(interactive=False)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
return gr.Button(interactive=True)
else:
return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> List[Any]:
with open(file_path, "r", encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
elif file_path.endswith(".jsonl"):
return [json.loads(line) for line in f]
else:
return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path):
data = _load_data_file(data_path)
else:
data = []
for file_name in os.listdir(data_path):
data.extend(_load_data_file(os.path.join(data_path, file_name)))
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(value=0, interactive=False, precision=0)
page_index = gr.Number(value=0, interactive=False, precision=0)
with gr.Row():
prev_btn = gr.Button()
next_btn = gr.Button()
close_btn = gr.Button()
with gr.Row():
preview_samples = gr.JSON()
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
lambda: 0, outputs=[page_index], queue=False
)
data_preview_btn.click(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
return dict(
data_preview_btn=data_preview_btn,
preview_count=preview_count,
page_index=page_index,
prev_btn=prev_btn,
next_btn=next_btn,
close_btn=close_btn,
preview_samples=preview_samples,
)

View File

@@ -0,0 +1,79 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
with gr.Row():
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button(variant="primary")
stop_btn = gr.Button(variant="stop")
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
output_box = gr.Markdown()
output_elems = [output_box, progress_bar]
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
progress_bar=progress_bar,
output_box=output_box,
)
)
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
return elem_dict

View File

@@ -0,0 +1,132 @@
from typing import TYPE_CHECKING, Dict, Generator, List
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import get_save_dir
from ..locales import ALERTS
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]
def save_model(
lang: str,
model_name: str,
model_path: str,
adapter_path: List[str],
finetuning_type: str,
template: str,
visual_inputs: bool,
export_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
export_device: str,
export_legacy_format: bool,
export_dir: str,
export_hub_model_id: str,
) -> Generator[str, None, None]:
error = ""
if not model_name:
error = ALERTS["err_no_model"][lang]
elif not model_path:
error = ALERTS["err_no_path"][lang]
elif not export_dir:
error = ALERTS["err_no_export_dir"][lang]
elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
error = ALERTS["err_no_dataset"][lang]
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
error = ALERTS["err_no_adapter"][lang]
elif export_quantization_bit in GPTQ_BITS and adapter_path:
error = ALERTS["err_gptq_lora"][lang]
if error:
gr.Warning(error)
yield error
return
if adapter_path:
adapter_name_or_path = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
else:
adapter_name_or_path = None
args = dict(
model_name_or_path=model_path,
adapter_name_or_path=adapter_name_or_path,
finetuning_type=finetuning_type,
template=template,
visual_inputs=visual_inputs,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None,
export_size=export_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
)
yield ALERTS["info_exporting"][lang]
export_model(args)
torch_gc()
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
export_legacy_format = gr.Checkbox()
with gr.Row():
export_dir = gr.Textbox()
export_hub_model_id = gr.Textbox()
export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
export_btn.click(
save_model,
[
engine.manager.get_elem_by_id("top.lang"),
engine.manager.get_elem_by_id("top.model_name"),
engine.manager.get_elem_by_id("top.model_path"),
engine.manager.get_elem_by_id("top.adapter_path"),
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.template"),
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
export_quantization_bit,
export_quantization_dataset,
export_device,
export_legacy_format,
export_dir,
export_hub_model_id,
],
[info_box],
)
return dict(
export_size=export_size,
export_quantization_bit=export_quantization_bit,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id,
export_btn=export_btn,
info_box=info_box,
)

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from .chatbot import create_chat_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
input_elems.update({infer_backend})
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
)
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, messages]
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
engine.manager.get_elem_by_id("top.visual_inputs").change(
lambda enabled: gr.Column(visible=enabled),
[engine.manager.get_elem_by_id("top.visual_inputs")],
[chat_elems["image_box"]],
)
return elem_dict

View File

@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING, Dict
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
from ..utils import can_quantize
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
visual_inputs = gr.Checkbox(scale=1)
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(get_template, [model_name], [template], queue=False).then(
get_visual, [model_name], [visual_inputs], queue=False
) # do not save config since the below line will save
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
can_quantize, [finetuning_type], [quantization_bit], queue=False
)
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
return dict(
lang=lang,
model_name=model_name,
model_path=model_path,
finetuning_type=finetuning_type,
adapter_path=adapter_path,
refresh_btn=refresh_btn,
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
rope_scaling=rope_scaling,
booster=booster,
visual_inputs=visual_inputs,
)

View File

@@ -0,0 +1,299 @@
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
training_stage = gr.Dropdown(
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_grad_norm = gr.Textbox(value="1.0")
max_samples = gr.Textbox(value="100000")
compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
elem_dict.update(
dict(
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_grad_norm=max_grad_norm,
max_samples=max_samples,
compute_type=compute_type,
)
)
with gr.Row():
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
elem_dict.update(
dict(
cutoff_len=cutoff_len,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
val_size=val_size,
lr_scheduler_type=lr_scheduler_type,
)
)
with gr.Accordion(open=False) as extra_tab:
with gr.Row():
logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
optim = gr.Textbox(value="adamw_torch")
with gr.Row():
with gr.Column():
resize_vocab = gr.Checkbox()
packing = gr.Checkbox()
with gr.Column():
upcast_layernorm = gr.Checkbox()
use_llama_pro = gr.Checkbox()
with gr.Column():
shift_attn = gr.Checkbox()
report_to = gr.Checkbox()
input_elems.update(
{
logging_steps,
save_steps,
warmup_steps,
neftune_alpha,
optim,
resize_vocab,
packing,
upcast_layernorm,
use_llama_pro,
shift_attn,
report_to,
}
)
elem_dict.update(
dict(
extra_tab=extra_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha,
optim=optim,
resize_vocab=resize_vocab,
packing=packing,
upcast_layernorm=upcast_layernorm,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,
report_to=report_to,
)
)
with gr.Accordion(open=False) as freeze_tab:
with gr.Row():
freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1)
freeze_trainable_modules = gr.Textbox(value="all")
freeze_extra_modules = gr.Textbox()
input_elems.update({freeze_trainable_layers, freeze_trainable_modules, freeze_extra_modules})
elem_dict.update(
dict(
freeze_tab=freeze_tab,
freeze_trainable_layers=freeze_trainable_layers,
freeze_trainable_modules=freeze_trainable_modules,
freeze_extra_modules=freeze_extra_modules,
)
)
with gr.Accordion(open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
create_new_adapter = gr.Checkbox()
with gr.Row():
with gr.Column(scale=1):
use_rslora = gr.Checkbox()
use_dora = gr.Checkbox()
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
input_elems.update(
{
lora_rank,
lora_alpha,
lora_dropout,
loraplus_lr_ratio,
create_new_adapter,
use_rslora,
use_dora,
lora_target,
additional_target,
}
)
elem_dict.update(
dict(
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
loraplus_lr_ratio=loraplus_lr_ratio,
create_new_adapter=create_new_adapter,
use_rslora=use_rslora,
use_dora=use_dora,
lora_target=lora_target,
additional_target=additional_target,
)
)
with gr.Accordion(open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
)
with gr.Accordion(open=False) as galore_tab:
with gr.Row():
use_galore = gr.Checkbox()
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
elem_dict.update(
dict(
galore_tab=galore_tab,
use_galore=use_galore,
galore_rank=galore_rank,
galore_update_interval=galore_update_interval,
galore_scale=galore_scale,
galore_target=galore_target,
)
)
with gr.Accordion(open=False) as badam_tab:
with gr.Row():
use_badam = gr.Checkbox()
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)
input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
elem_dict.update(
dict(
badam_tab=badam_tab,
use_badam=use_badam,
badam_mode=badam_mode,
badam_switch_mode=badam_switch_mode,
badam_switch_interval=badam_switch_interval,
badam_update_ratio=badam_update_ratio,
)
)
with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
arg_load_btn = gr.Button()
start_btn = gr.Button(variant="primary")
stop_btn = gr.Button(variant="stop")
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
output_dir = gr.Textbox()
config_path = gr.Textbox()
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
output_box = gr.Markdown()
with gr.Column(scale=1):
loss_viewer = gr.Plot()
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
arg_save_btn=arg_save_btn,
arg_load_btn=arg_load_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
config_path=config_path,
resume_btn=resume_btn,
progress_bar=progress_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)
input_elems.update({output_dir, config_path})
output_elems = [output_box, progress_bar, loss_viewer]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems) + [output_box],
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
return elem_dict

View File

@@ -0,0 +1,27 @@
CSS = r"""
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
.modal-box {
position: fixed !important;
top: 50%;
left: 50%;
transform: translate(-50%, -50%); /* center horizontally */
max-width: 1000px;
max-height: 750px;
overflow-y: auto;
background-color: var(--input-background-fill);
flex-wrap: nowrap !important;
border: 2px solid black !important;
z-index: 1000;
padding: 10px;
}
.dark .modal-box {
border: 2px solid white !important;
}
"""

View File

@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING, Any, Dict
from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES
from .manager import Manager
from .runner import Runner
from .utils import get_time
if TYPE_CHECKING:
from gradio.components import Component
class Engine:
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
self.pure_chat = pure_chat
self.manager = Manager()
self.runner = Runner(self.manager, demo_mode)
self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
r"""
Gets the dict to update the components.
"""
output_dict: Dict["Component", "Component"] = {}
for elem_id, elem_attr in input_dict.items():
elem = self.manager.get_elem_by_id(elem_id)
output_dict[elem] = elem.__class__(**elem_attr)
return output_dict
def resume(self):
user_config = load_config() if not self.demo_mode else {}
lang = user_config.get("lang", None) or "en"
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
if not self.pure_chat:
init_dict["train.dataset"] = {"choices": list_dataset().choices}
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
init_dict["infer.image_box"] = {"visible": False}
if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]}
init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
yield self._update_component(init_dict)
if self.runner.running and not self.demo_mode and not self.pure_chat:
yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
if self.runner.do_train:
yield self._update_component({"train.resume_btn": {"value": True}})
else:
yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str):
return {
elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter()
if elem_name in LOCALES
}

View File

@@ -0,0 +1,82 @@
import os
from ..extras.packages import is_gradio_available
from .common import save_config
from .components import (
create_chat_box,
create_eval_tab,
create_export_tab,
create_infer_tab,
create_top,
create_train_tab,
)
from .css import CSS
from .engine import Engine
if is_gradio_available():
import gradio as gr
def create_ui(demo_mode: bool = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False)
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode:
gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
gr.HTML(
'<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
"LLaMA Factory</a> for details.</center></h3>"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.add_elems("top", create_top())
lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
with gr.Tab("Train"):
engine.manager.add_elems("train", create_train_tab(engine))
with gr.Tab("Evaluate & Predict"):
engine.manager.add_elems("eval", create_eval_tab(engine))
with gr.Tab("Chat"):
engine.manager.add_elems("infer", create_infer_tab(engine))
if not demo_mode:
with gr.Tab("Export"):
engine.manager.add_elems("export", create_export_tab(engine))
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
lang.input(save_config, inputs=[lang], queue=False)
return demo
def create_web_demo() -> gr.Blocks:
engine = Engine(pure_chat=True)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.add_elems("top", dict(lang=lang))
_, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elems("infer", chat_elems)
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
lang.input(save_config, inputs=[lang], queue=False)
return demo
def run_web_ui() -> None:
gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_ui().queue().launch(share=gradio_share, server_name=server_name)
def run_web_demo() -> None:
gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_web_demo().queue().launch(share=gradio_share, server_name=server_name)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,64 @@
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
if TYPE_CHECKING:
from gradio.components import Component
class Manager:
def __init__(self) -> None:
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}
def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
r"""
Adds elements to manager.
"""
for elem_name, elem in elem_dict.items():
elem_id = "{}.{}".format(tab_name, elem_name)
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id
def get_elem_list(self) -> List["Component"]:
r"""
Returns the list of all elements.
"""
return list(self._id_to_elem.values())
def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
r"""
Returns an iterator over all elements with their names.
"""
for elem_id, elem in self._id_to_elem.items():
yield elem_id.split(".")[-1], elem
def get_elem_by_id(self, elem_id: str) -> "Component":
r"""
Gets element by id.
Example: top.lang, train.dataset
"""
return self._id_to_elem[elem_id]
def get_id_by_elem(self, elem: "Component") -> str:
r"""
Gets id by element.
"""
return self._elem_to_id[elem]
def get_base_elems(self) -> Set["Component"]:
r"""
Gets the base elements that are commonly used.
"""
return {
self._id_to_elem["top.lang"],
self._id_to_elem["top.model_name"],
self._id_to_elem["top.model_path"],
self._id_to_elem["top.finetuning_type"],
self._id_to_elem["top.adapter_path"],
self._id_to_elem["top.quantization_bit"],
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
self._id_to_elem["top.visual_inputs"],
}

View File

@@ -0,0 +1,369 @@
import os
import signal
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import psutil
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_cuda_available
from ..extras.constants import TRAINING_STAGES
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_gradio_available
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from .manager import Manager
class Runner:
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.trainer: Optional["Popen"] = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
""" State """
self.aborted = False
self.running = False
def set_abort(self) -> None:
self.aborted = True
if self.trainer is not None:
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
os.kill(children.pid, signal.SIGABRT)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
if self.running:
return ALERTS["err_conflict"][lang]
if not model_name:
return ALERTS["err_no_model"][lang]
if not model_path:
return ALERTS["err_no_path"][lang]
if not dataset:
return ALERTS["err_no_dataset"][lang]
if not from_preview and self.demo_mode:
return ALERTS["err_demo"][lang]
if not from_preview and get_device_count() > 1:
return ALERTS["err_device_count"][lang]
if do_train:
stage = TRAINING_STAGES[get("train.training_stage")]
reward_model = get("train.reward_model")
if stage == "ppo" and not reward_model:
return ALERTS["err_no_reward_model"][lang]
if not from_preview and not is_torch_cuda_available():
gr.Warning(ALERTS["warn_no_cuda"][lang])
return ""
def _finalize(self, lang: str, finish_info: str) -> str:
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
self.trainer = None
self.aborted = False
self.running = False
self.running_data = None
torch_gc()
return finish_info
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
args = dict(
stage=TRAINING_STAGES[get("train.training_stage")],
do_train=True,
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")),
cutoff_len=get("train.cutoff_len"),
learning_rate=float(get("train.learning_rate")),
num_train_epochs=float(get("train.num_train_epochs")),
max_samples=int(get("train.max_samples")),
per_device_train_batch_size=get("train.batch_size"),
gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
lr_scheduler_type=get("train.lr_scheduler_type"),
max_grad_norm=float(get("train.max_grad_norm")),
logging_steps=get("train.logging_steps"),
save_steps=get("train.save_steps"),
warmup_steps=get("train.warmup_steps"),
neftune_noise_alpha=get("train.neftune_alpha") or None,
optim=get("train.optim"),
resize_vocab=get("train.resize_vocab"),
packing=get("train.packing"),
upcast_layernorm=get("train.upcast_layernorm"),
use_llama_pro=get("train.use_llama_pro"),
shift_attn=get("train.shift_attn"),
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_badam=get("train.use_badam"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
plot_loss=True,
)
if args["finetuning_type"] == "freeze":
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
elif args["finetuning_type"] == "lora":
args["lora_rank"] = get("train.lora_rank")
args["lora_alpha"] = get("train.lora_alpha")
args["lora_dropout"] = get("train.lora_dropout")
args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
args["create_new_adapter"] = get("train.create_new_adapter")
args["use_rslora"] = get("train.use_rslora")
args["use_dora"] = get("train.use_dora")
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
args["additional_target"] = get("train.additional_target") or None
if args["use_llama_pro"]:
args["num_layer_trainable"] = get("train.num_layer_trainable")
if args["stage"] == "ppo":
args["reward_model"] = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("train.reward_model")
]
)
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
elif args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
args["dpo_ftx"] = get("train.dpo_ftx")
elif args["stage"] == "orpo":
args["orpo_beta"] = get("train.orpo_beta")
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = args["save_steps"]
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
if args["use_galore"]:
args["galore_rank"] = get("train.galore_rank")
args["galore_update_interval"] = get("train.galore_update_interval")
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")
if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
args = dict(
stage="sft",
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("eval.dataset_dir"),
dataset=",".join(get("eval.dataset")),
cutoff_len=get("eval.cutoff_len"),
max_samples=int(get("eval.max_samples")),
per_device_eval_batch_size=get("eval.batch_size"),
predict_with_generate=True,
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
)
if get("eval.predict"):
args["do_predict"] = True
else:
args["do_eval"] = True
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield {output_box: error}
else:
self.do_train, self.running_data = do_train, data
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
env = deepcopy(os.environ)
env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
env["LLAMABOARD_ENABLED"] = "1"
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
yield from self.monitor()
def preview_train(self, data):
yield from self._preview(data, do_train=True)
def preview_eval(self, data):
yield from self._preview(data, do_train=False)
def run_train(self, data):
yield from self._launch(data, do_train=True)
def run_eval(self, data):
yield from self._launch(data, do_train=False)
def monitor(self):
self.aborted = False
self.running = True
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang")
model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type")
output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
output_path = get_save_dir(model_name, finetuning_type, output_dir)
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
while self.trainer is not None:
if self.aborted:
yield {
output_box: ALERTS["info_aborting"][lang],
progress_bar: gr.Slider(visible=False),
}
else:
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
return_dict = {
output_box: running_log,
progress_bar: running_progress,
}
if running_loss is not None:
return_dict[loss_viewer] = running_loss
yield return_dict
try:
self.trainer.wait(2)
self.trainer = None
except TimeoutExpired:
continue
if self.do_train:
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
else:
if os.path.exists(os.path.join(output_path, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
return_dict = {
output_box: self._finalize(lang, finish_info),
progress_bar: gr.Slider(visible=False),
}
yield return_dict
def save_args(self, data: dict):
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
gr.Warning(error)
return {output_box: error}
config_dict: Dict[str, Any] = {}
lang = data[self.manager.get_elem_by_id("top.lang")]
config_path = data[self.manager.get_elem_by_id("train.config_path")]
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
elem_id = self.manager.get_id_by_elem(elem)
if elem_id not in skip_ids:
config_dict[elem_id] = value
save_path = save_args(config_path, config_dict)
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str):
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(config_path)
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {output_box: ALERTS["err_config_not_found"][lang]}
output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value
return output_dict

View File

@@ -0,0 +1,106 @@
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from yaml import safe_dump
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from .locales import ALERTS
if is_gradio_available():
import gradio as gr
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora":
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
def check_json_schema(text: str, lang: str) -> None:
try:
tools = json.loads(text)
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise NotImplementedError("Name not found.")
except NotImplementedError:
gr.Warning(ALERTS["err_tool_name"][lang])
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
cmd_lines = ["CUDA_VISIBLE_DEVICES={} llamafactory-cli train ".format(current_devices)]
for k, v in clean_cmd(args).items():
cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_text = "\\\n".join(cmd_lines)
cmd_text = "```bash\n{}\n```".format(cmd_text)
return cmd_text
def get_eval_results(path: os.PathLike) -> str:
with open(path, "r", encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return "```json\n{}\n```\n".format(result)
def get_time() -> str:
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, "r", encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, "r", encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
if len(trainer_log) != 0:
latest_log = trainer_log[-1]
percentage = latest_log["percentage"]
label = "Running {:d}/{:d}: {} < {}".format(
latest_log["current_steps"],
latest_log["total_steps"],
latest_log["elapsed_time"],
latest_log["remaining_time"],
)
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
def save_cmd(args: Dict[str, Any]) -> str:
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINER_CONFIG)