rename package
Former-commit-id: a07ff0c083558cfe6f474d13027642d3052fee08
This commit is contained in:
0
src/llamafactory/webui/__init__.py
Normal file
0
src/llamafactory/webui/__init__.py
Normal file
145
src/llamafactory/webui/chatter.py
Normal file
145
src/llamafactory/webui/chatter.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import get_save_dir
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..chat import BaseEngine
|
||||
from .manager import Manager
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
    """Chat model driven by the web UI: loads/unloads an engine and streams replies into a chatbot."""

    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        """Keep a handle on the UI manager and optionally initialize the parent chat model.

        Args:
            manager: Resolves element ids such as ``"top.lang"`` to UI components.
            demo_mode: When true, ``load_model``/``unload_model`` refuse to act, and the model named
                by the ``DEMO_MODEL`` and ``DEMO_TEMPLATE`` environment variables is loaded here.
            lazy_init: When false, the parent initializer is invoked immediately without arguments.
        """
        self.manager = manager
        self.demo_mode = demo_mode
        self.engine: Optional["BaseEngine"] = None

        if not lazy_init:  # read arguments from command line
            super().__init__()

        demo_model = os.environ.get("DEMO_MODEL")
        demo_template = os.environ.get("DEMO_TEMPLATE")
        if demo_mode and demo_model and demo_template:  # load demo model
            demo_backend = os.environ.get("DEMO_BACKEND", "huggingface")
            super().__init__(
                dict(model_name_or_path=demo_model, template=demo_template, infer_backend=demo_backend)
            )

    @property
    def loaded(self) -> bool:
        """Whether an engine is currently attached to this model."""
        return self.engine is not None

    def load_model(self, data) -> Generator[str, None, None]:
        """Load a model from the UI form values, yielding localized status messages.

        Args:
            data: Mapping from UI components to their current values.

        Yields:
            An error message (after raising a UI warning) when validation fails, otherwise the
            "loading" message followed by the "loaded" message.
        """

        def get(elem_id):
            # Look up the value of the component registered under ``elem_id``.
            return data[self.manager.get_elem_by_id(elem_id)]

        lang = get("top.lang")

        # First failing check wins; an empty string means the request is valid.
        if self.loaded:
            error = ALERTS["err_exists"][lang]
        elif not get("top.model_name"):
            error = ALERTS["err_no_model"][lang]
        elif not get("top.model_path"):
            error = ALERTS["err_no_path"][lang]
        elif self.demo_mode:
            error = ALERTS["err_demo"][lang]
        else:
            error = ""

        if error:
            gr.Warning(error)
            yield error
            return

        # Selected adapters are expanded to their save directories and joined with commas.
        adapter_name_or_path = None
        if get("top.adapter_path"):
            adapter_dirs = [
                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
                for adapter in get("top.adapter_path")
            ]
            adapter_name_or_path = ",".join(adapter_dirs)

        yield ALERTS["info_loading"][lang]

        quantization_bit = get("top.quantization_bit")
        rope_scaling = get("top.rope_scaling")
        booster = get("top.booster")
        args = dict(
            model_name_or_path=get("top.model_path"),
            adapter_name_or_path=adapter_name_or_path,
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
            template=get("top.template"),
            flash_attn="fa2" if booster == "flashattn2" else "auto",
            use_unsloth=(booster == "unsloth"),
            visual_inputs=get("top.visual_inputs"),
            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
            infer_backend=get("infer.infer_backend"),
        )
        super().__init__(args)

        yield ALERTS["info_loaded"][lang]

    def unload_model(self, data) -> Generator[str, None, None]:
        """Drop the current engine and free memory, yielding localized status messages.

        In demo mode nothing is unloaded: a UI warning is raised and the error text is yielded.
        """
        lang = data[self.manager.get_elem_by_id("top.lang")]

        if self.demo_mode:
            demo_error = ALERTS["err_demo"][lang]
            gr.Warning(demo_error)
            yield demo_error
            return

        yield ALERTS["info_unloading"][lang]
        self.engine = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    def append(
        self,
        chatbot: List[List[Optional[str]]],
        messages: Sequence[Dict[str, str]],
        role: str,
        query: str,
    ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
        """Return new chatbot/message lists extended with the query, plus ``""`` to clear the input box."""
        new_chatbot = chatbot + [[query, None]]
        new_messages = messages + [{"role": role, "content": query}]
        return new_chatbot, new_messages, ""

    def stream(
        self,
        chatbot: List[List[Optional[str]]],
        messages: Sequence[Dict[str, str]],
        system: str,
        tools: str,
        image: Optional[NDArray],
        max_new_tokens: int,
        top_p: float,
        temperature: float,
    ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
        """Stream the reply into the last chatbot row.

        After every generated chunk, yields the (mutated in place) chatbot together with a new
        message list ending in the reply accumulated so far. When ``tools`` is non-empty the
        accumulated text is passed through the template's tool extractor; a tuple result is
        treated as a ``(name, arguments)`` tool call and shown as a fenced JSON block.
        """
        chatbot[-1][1] = ""
        response = ""
        chunks = self.stream_chat(
            messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        )
        for new_text in chunks:
            response += new_text
            result = self.engine.template.format_tools.extract(response) if tools else response

            if isinstance(result, tuple):
                name, arguments = result
                arguments = json.loads(arguments)
                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
                reply_role, reply_content = Role.FUNCTION.value, tool_call
                bot_text = "```json\n" + tool_call + "\n```"
            else:
                reply_role, reply_content = Role.ASSISTANT.value, result
                bot_text = result

            output_messages = messages + [{"role": reply_role, "content": reply_content}]
            chatbot[-1][1] = bot_text
            yield chatbot, output_messages
|
||||
155
src/llamafactory/webui/common.py
Normal file
155
src/llamafactory/webui/common.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from yaml import safe_dump, safe_load
|
||||
|
||||
from ..extras.constants import (
|
||||
DATA_CONFIG,
|
||||
DEFAULT_MODULE,
|
||||
DEFAULT_TEMPLATE,
|
||||
PEFT_METHODS,
|
||||
STAGES_USE_PAIR_DATA,
|
||||
SUPPORTED_MODELS,
|
||||
TRAINING_STAGES,
|
||||
VISION_MODELS,
|
||||
DownloadSource,
|
||||
)
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import use_modelscope
|
||||
from ..extras.packages import is_gradio_available
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
logger = get_logger(__name__)

# Weight file names whose presence marks a directory as a saved adapter (see list_adapters).
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}

# Directory holding the user configuration file.
DEFAULT_CACHE_DIR = "cache"

# Directory holding saved argument files (see save_args / load_args).
DEFAULT_CONFIG_DIR = "config"

# Dataset directory used when none is given to list_dataset.
DEFAULT_DATA_DIR = "data"

# Root directory under which get_save_dir builds its paths.
DEFAULT_SAVE_DIR = "saves"

# File name of the user configuration inside DEFAULT_CACHE_DIR.
USER_CONFIG = "user_config.yaml"
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
    """Join the given path components onto the default save directory."""
    save_dir = os.path.join(DEFAULT_SAVE_DIR, *args)
    return save_dir
|
||||
|
||||
|
||||
def get_config_path() -> os.PathLike:
    """Return the location of the user configuration file inside the cache directory."""
    config_path = os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
    return config_path
|
||||
|
||||
|
||||
def get_save_path(config_path: str) -> os.PathLike:
    """Resolve ``config_path`` relative to the default config directory."""
    save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
    return save_path
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Any]:
    """Load the user configuration, falling back to defaults when it is unusable.

    Returns:
        The mapping parsed from the user config file, or a default configuration with the
        keys ``lang``, ``last_model``, ``path_dict`` and ``cache_dir`` when the file is
        missing, unreadable, malformed, or does not contain a mapping.
    """
    default_config = {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
    try:
        with open(get_config_path(), "r", encoding="utf-8") as f:
            user_config = safe_load(f)
    except Exception:
        # Deliberate best effort: any read or parse failure yields the defaults.
        return default_config

    # An empty YAML document parses to None (and a scalar/list document to a non-mapping);
    # callers subscript the result, so only a real mapping is handed back.
    if not isinstance(user_config, dict):
        return default_config

    return user_config
|
||||
|
||||
|
||||
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
    """Persist the UI language and, when a model name is given, the last model and its path.

    Args:
        lang: Language to store; a falsy value keeps the previously stored language.
        model_name: When truthy, recorded as the last model and used as key for ``model_path``.
        model_path: Path stored under ``model_name`` in the config's ``path_dict``.
    """
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)

    user_config = load_config()
    user_config["lang"] = lang if lang else user_config["lang"]

    if model_name:
        user_config["last_model"] = model_name
        user_config["path_dict"][model_name] = model_path

    with open(get_config_path(), "w", encoding="utf-8") as config_file:
        safe_dump(user_config, config_file)
|
||||
|
||||
|
||||
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    """Read saved arguments from the config directory; return ``None`` if they cannot be loaded."""
    try:
        with open(get_save_path(config_path), "r", encoding="utf-8") as args_file:
            loaded = safe_load(args_file)
    except Exception:
        return None
    else:
        return loaded
|
||||
|
||||
|
||||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    """Dump ``config_dict`` as YAML into the config directory and return the written path as a string."""
    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)

    save_path = get_save_path(config_path)
    with open(save_path, "w", encoding="utf-8") as args_file:
        safe_dump(config_dict, args_file)

    return str(save_path)
|
||||
|
||||
|
||||
def get_model_path(model_name: str) -> str:
    """Resolve the path for ``model_name``.

    A path stored in the user configuration takes precedence over the model's default
    download source. When ModelScope is in use, the model has a ModelScope entry, and the
    resolved path equals the default source path, the ModelScope path is returned instead.
    """
    user_config = load_config()
    path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))

    default_path = path_dict.get(DownloadSource.DEFAULT, None)
    model_path = user_config["path_dict"].get(model_name, None) or default_path

    # Swap in the ModelScope path only when the default source path would have been used.
    if use_modelscope():
        modelscope_path = path_dict.get(DownloadSource.MODELSCOPE)
        if modelscope_path and model_path == path_dict.get(DownloadSource.DEFAULT):
            model_path = modelscope_path

    return model_path
|
||||
|
||||
|
||||
def get_prefix(model_name: str) -> str:
    """Return the part of ``model_name`` before its first ``-`` (the whole name if there is none)."""
    prefix, _, _ = model_name.partition("-")
    return prefix
|
||||
|
||||
|
||||
def get_module(model_name: str) -> str:
    """Look up the default module list for the model's prefix, defaulting to ``"q_proj,v_proj"``."""
    prefix = get_prefix(model_name)
    return DEFAULT_MODULE.get(prefix, "q_proj,v_proj")
|
||||
|
||||
|
||||
def get_template(model_name: str) -> str:
    """Pick the template for ``model_name``.

    Only names ending in ``"Chat"`` whose prefix has a registered default template get that
    template; everything else (including an empty name) gets ``"default"``.
    """
    if not model_name or not model_name.endswith("Chat"):
        return "default"

    prefix = get_prefix(model_name)
    if prefix in DEFAULT_TEMPLATE:
        return DEFAULT_TEMPLATE[prefix]

    return "default"
|
||||
|
||||
|
||||
def get_visual(model_name: str) -> bool:
    """Tell whether the model's prefix is listed among the vision models.

    Args:
        model_name: Selected model name; may be empty or ``None`` when nothing is selected.

    Returns:
        ``True`` if the prefix of ``model_name`` is in ``VISION_MODELS``, else ``False``.
    """
    # Guard against an empty selection (as get_template does): get_prefix would otherwise
    # fail on None with an AttributeError.
    if not model_name:
        return False

    return get_prefix(model_name) in VISION_MODELS
|
||||
|
||||
|
||||
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
    """Build the adapter dropdown for the given model and fine-tuning type.

    For non-PEFT fine-tuning types the dropdown is empty and non-interactive. Otherwise it is
    interactive; for ``"lora"`` its choices are the sub-directories of the model's save
    directory that contain one of the adapter weight files.
    """
    if finetuning_type not in PEFT_METHODS:
        return gr.Dropdown(value=[], choices=[], interactive=False)

    adapters = []
    if model_name and finetuning_type == "lora":
        save_dir = get_save_dir(model_name, finetuning_type)
        if save_dir and os.path.isdir(save_dir):
            for candidate in os.listdir(save_dir):
                candidate_dir = os.path.join(save_dir, candidate)
                if not os.path.isdir(candidate_dir):
                    continue

                # A directory counts as an adapter once it holds any known weight file.
                if any(os.path.isfile(os.path.join(save_dir, candidate, name)) for name in ADAPTER_NAMES):
                    adapters.append(candidate)

    return gr.Dropdown(value=[], choices=adapters, interactive=True)
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
    """Load the dataset description file found in ``dataset_dir``.

    Returns an empty dict when ``dataset_dir`` is the literal ``"ONLINE"`` or when the
    description file cannot be read or parsed (a warning is logged in the latter case).
    """
    if dataset_dir == "ONLINE":
        logger.info("dataset_dir is ONLINE, using online dataset.")
        return {}

    try:
        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as info_file:
            dataset_info = json.load(info_file)
    except Exception as err:
        logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
        dataset_info = {}

    return dataset_info
|
||||
|
||||
|
||||
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
    """Build the dataset dropdown for a training stage.

    Datasets are read from ``dataset_dir`` (the default data directory when ``None``) and kept
    only if their ``ranking`` flag matches whether the stage uses pairwise data.
    """
    if dataset_dir is None:
        dataset_dir = DEFAULT_DATA_DIR

    dataset_info = load_dataset_info(dataset_dir)
    ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA

    datasets = []
    for name, info in dataset_info.items():
        if info.get("ranking", False) == ranking:
            datasets.append(name)

    return gr.Dropdown(value=[], choices=datasets)
|
||||
|
||||
|
||||
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
    """Return a button update whose value is true exactly when the stage maps to ``"pt"``."""
    is_pretraining = TRAINING_STAGES[training_stage] == "pt"
    return gr.Button(value=is_pretraining)
|
||||
16
src/llamafactory/webui/components/__init__.py
Normal file
16
src/llamafactory/webui/components/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .chatbot import create_chat_box
|
||||
from .eval import create_eval_tab
|
||||
from .export import create_export_tab
|
||||
from .infer import create_infer_tab
|
||||
from .top import create_top
|
||||
from .train import create_train_tab
|
||||
|
||||
|
||||
# Builder functions re-exported by this package.
__all__ = [
    "create_chat_box",
    "create_eval_tab",
    "create_export_tab",
    "create_infer_tab",
    "create_top",
    "create_train_tab",
]
|
||||
74
src/llamafactory/webui/components/chatbot.py
Normal file
74
src/llamafactory/webui/components/chatbot.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
from ...data import Role
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..utils import check_json_schema
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_chat_box(
    engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
    """Build the chat panel and wire its events to the engine's chatter.

    Args:
        engine: Provides ``chatter`` (append/stream handlers) and ``manager`` (element lookup).
        visible: Initial visibility of the enclosing column.

    Returns:
        The chatbot component, the message-history state, and a dict of the panel's components.
    """
    # NOTE(review): layout nesting was reconstructed from a flattened listing — confirm.
    with gr.Column(visible=visible) as chat_box:
        chatbot = gr.Chatbot(show_copy_button=True)
        messages = gr.State([])
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Row():
                    with gr.Column():
                        role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
                        system = gr.Textbox(show_label=False)
                        tools = gr.Textbox(show_label=False, lines=3)

                    with gr.Column() as image_box:
                        image = gr.Image(sources=["upload"], type="numpy")

                query = gr.Textbox(show_label=False, lines=8)
                submit_btn = gr.Button(variant="primary")

            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
                top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
                clear_btn = gr.Button()

    # Check the tools text against the JSON schema checker whenever the user edits it.
    lang = engine.manager.get_elem_by_id("top.lang")
    tools.input(check_json_schema, inputs=[tools, lang])

    # Submitting first records the query, then streams the reply.
    append_event = submit_btn.click(
        engine.chatter.append,
        [chatbot, messages, role, query],
        [chatbot, messages, query],
    )
    append_event.then(
        engine.chatter.stream,
        [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
        [chatbot, messages],
    )
    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

    chat_elems = {
        "chat_box": chat_box,
        "role": role,
        "system": system,
        "tools": tools,
        "image_box": image_box,
        "image": image,
        "query": query,
        "submit_btn": submit_btn,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "temperature": temperature,
        "clear_btn": clear_btn,
    }
    return chatbot, messages, chat_elems
|
||||
106
src/llamafactory/webui/components/data.py
Normal file
106
src/llamafactory/webui/components/data.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
|
||||
from ...extras.constants import DATA_CONFIG
|
||||
from ...extras.packages import is_gradio_available
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
# Number of samples shown per preview page.
PAGE_SIZE = 2


def prev_page(page_index: int) -> int:
    """Step back one page, staying put when the index is not positive."""
    if page_index > 0:
        return page_index - 1

    return page_index


def next_page(page_index: int, total_num: int) -> int:
    """Step forward one page, staying put when the next page would start at or past ``total_num``."""
    next_start = (page_index + 1) * PAGE_SIZE
    if next_start < total_num:
        return page_index + 1

    return page_index
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
    """Enable the preview button only when the first selected dataset has local data.

    Args:
        dataset_dir: Directory containing the dataset description file and data files.
        dataset: Currently selected dataset names; only the first one is inspected.

    Returns:
        A button update that is interactive when the dataset's ``file_name`` entry points to
        an existing file or a non-empty directory, and non-interactive otherwise.
    """
    try:
        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
            dataset_info = json.load(f)
    except Exception:
        return gr.Button(interactive=False)

    if not dataset or not isinstance(dataset_info, dict):
        return gr.Button(interactive=False)

    # The selected name may be missing from the description file (e.g. a custom value typed
    # into the dropdown); treat that as "not previewable" instead of raising KeyError.
    entry = dataset_info.get(dataset[0])
    if not isinstance(entry, dict) or "file_name" not in entry:
        return gr.Button(interactive=False)

    data_path = os.path.join(dataset_dir, entry["file_name"])
    if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
        return gr.Button(interactive=True)
    else:
        return gr.Button(interactive=False)
|
||||
|
||||
|
||||
def _load_data_file(file_path: str) -> List[Any]:
    """Load samples from a data file, choosing the parser by file extension.

    Args:
        file_path: Path to the file; read as UTF-8 text.

    Returns:
        The parsed JSON document for ``.json`` files, one parsed object per non-blank line
        for ``.jsonl`` files, and the raw lines for any other extension.

    Raises:
        json.JSONDecodeError: If a ``.json`` document or a non-blank ``.jsonl`` line is invalid.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        if file_path.endswith(".json"):
            return json.load(f)
        elif file_path.endswith(".jsonl"):
            # Skip blank lines (e.g. extra trailing newlines), which are not valid JSON.
            return [json.loads(line) for line in f if line.strip()]
        else:
            return list(f)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
    """Load the first selected dataset and return one page of its samples.

    Returns:
        The total number of samples, the samples of page ``page_index``, and a column update
        that makes the preview box visible.
    """
    info_path = os.path.join(dataset_dir, DATA_CONFIG)
    with open(info_path, "r", encoding="utf-8") as info_file:
        dataset_info = json.load(info_file)

    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
    if os.path.isfile(data_path):
        samples = _load_data_file(data_path)
    else:
        # Not a single file: concatenate the samples of every file in the directory.
        samples = []
        for file_name in os.listdir(data_path):
            samples.extend(_load_data_file(os.path.join(data_path, file_name)))

    page_start = PAGE_SIZE * page_index
    page_end = PAGE_SIZE * (page_index + 1)
    return len(samples), samples[page_start:page_end], gr.Column(visible=True)
|
||||
|
||||
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
    """Build the dataset preview button and its paged preview box.

    Args:
        dataset_dir: Textbox holding the dataset directory.
        dataset: Dropdown holding the selected dataset names.

    Returns:
        A dict of the created components (the enclosing preview column itself is not included).
    """
    # NOTE(review): layout nesting was reconstructed from a flattened listing — confirm.
    data_preview_btn = gr.Button(interactive=False, scale=1)
    with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
        with gr.Row():
            preview_count = gr.Number(value=0, interactive=False, precision=0)
            page_index = gr.Number(value=0, interactive=False, precision=0)

        with gr.Row():
            prev_btn = gr.Button()
            next_btn = gr.Button()
            close_btn = gr.Button()

        with gr.Row():
            preview_samples = gr.JSON()

    preview_inputs = [dataset_dir, dataset, page_index]
    preview_outputs = [preview_count, preview_samples, preview_box]

    # Changing the selection re-evaluates the button and rewinds to the first page.
    selection_event = dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
    selection_event.then(lambda: 0, outputs=[page_index], queue=False)

    data_preview_btn.click(get_preview, preview_inputs, preview_outputs, queue=False)

    # Paging buttons first move the index, then reload the preview.
    prev_event = prev_btn.click(prev_page, [page_index], [page_index], queue=False)
    prev_event.then(get_preview, preview_inputs, preview_outputs, queue=False)

    next_event = next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False)
    next_event.then(get_preview, preview_inputs, preview_outputs, queue=False)

    close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)

    return {
        "data_preview_btn": data_preview_btn,
        "preview_count": preview_count,
        "page_index": page_index,
        "prev_btn": prev_btn,
        "next_btn": next_btn,
        "close_btn": close_btn,
        "preview_samples": preview_samples,
    }
|
||||
79
src/llamafactory/webui/components/eval.py
Normal file
79
src/llamafactory/webui/components/eval.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, list_dataset
|
||||
from .data import create_preview_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
    """Build the evaluation tab and connect its buttons to the engine's runner.

    Returns:
        A dict of the tab's components keyed by element name.
    """
    # NOTE(review): layout nesting was reconstructed from a flattened listing — confirm.
    input_elems = engine.manager.get_base_elems()
    elem_dict = {}

    with gr.Row():
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
        preview_elems = create_preview_box(dataset_dir, dataset)

    input_elems.update({dataset_dir, dataset})
    elem_dict.update(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)

    with gr.Row():
        cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
        max_samples = gr.Textbox(value="100000")
        batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
        predict = gr.Checkbox(value=True)

    input_elems.update({cutoff_len, max_samples, batch_size, predict})
    elem_dict.update(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict)

    with gr.Row():
        max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
        top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
        temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
        output_dir = gr.Textbox()

    input_elems.update({max_new_tokens, top_p, temperature, output_dir})
    elem_dict.update(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir)

    with gr.Row():
        cmd_preview_btn = gr.Button()
        start_btn = gr.Button(variant="primary")
        stop_btn = gr.Button(variant="stop")

    with gr.Row():
        resume_btn = gr.Checkbox(visible=False, interactive=False)
        progress_bar = gr.Slider(visible=False, interactive=False)

    with gr.Row():
        output_box = gr.Markdown()

    output_elems = [output_box, progress_bar]
    elem_dict.update(
        cmd_preview_btn=cmd_preview_btn,
        start_btn=start_btn,
        stop_btn=stop_btn,
        resume_btn=resume_btn,
        progress_bar=progress_bar,
        output_box=output_box,
    )

    runner = engine.runner
    cmd_preview_btn.click(runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
    start_btn.click(runner.run_eval, input_elems, output_elems)
    stop_btn.click(runner.set_abort)
    resume_btn.change(runner.monitor, outputs=output_elems, concurrency_limit=None)

    # Refresh the dataset choices whenever the dataset directory changes.
    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)

    return elem_dict
|
||||
132
src/llamafactory/webui/components/export.py
Normal file
132
src/llamafactory/webui/components/export.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
|
||||
from ...extras.misc import torch_gc
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ...train.tuner import export_model
|
||||
from ..common import get_save_dir
|
||||
from ..locales import ALERTS
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
# Quantization bit choices (as dropdown strings) that select the quantized export path.
GPTQ_BITS = ["8", "4", "3", "2"]


def save_model(
    lang: str,
    model_name: str,
    model_path: str,
    adapter_path: List[str],
    finetuning_type: str,
    template: str,
    visual_inputs: bool,
    export_size: int,
    export_quantization_bit: int,
    export_quantization_dataset: str,
    export_device: str,
    export_legacy_format: bool,
    export_dir: str,
    export_hub_model_id: str,
) -> Generator[str, None, None]:
    """Validate the export form, export the model and yield localized status messages.

    Yields:
        A single error message (after raising a UI warning) when validation fails; otherwise
        the "exporting" message followed by the "exported" message.
    """
    quantize = export_quantization_bit in GPTQ_BITS

    # First failing check wins; an empty string means the request is valid.
    if not model_name:
        error = ALERTS["err_no_model"][lang]
    elif not model_path:
        error = ALERTS["err_no_path"][lang]
    elif not export_dir:
        error = ALERTS["err_no_export_dir"][lang]
    elif quantize and not export_quantization_dataset:
        error = ALERTS["err_no_dataset"][lang]
    elif not quantize and not adapter_path:
        error = ALERTS["err_no_adapter"][lang]
    elif quantize and adapter_path:
        error = ALERTS["err_gptq_lora"][lang]
    else:
        error = ""

    if error:
        gr.Warning(error)
        yield error
        return

    # Selected adapters are expanded to their save directories and joined with commas.
    adapter_name_or_path = None
    if adapter_path:
        adapter_dirs = [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
        adapter_name_or_path = ",".join(adapter_dirs)

    args = dict(
        model_name_or_path=model_path,
        adapter_name_or_path=adapter_name_or_path,
        finetuning_type=finetuning_type,
        template=template,
        visual_inputs=visual_inputs,
        export_dir=export_dir,
        export_hub_model_id=export_hub_model_id or None,
        export_size=export_size,
        export_quantization_bit=int(export_quantization_bit) if quantize else None,
        export_quantization_dataset=export_quantization_dataset,
        export_device=export_device,
        export_legacy_format=export_legacy_format,
    )

    yield ALERTS["info_exporting"][lang]
    export_model(args)
    torch_gc()
    yield ALERTS["info_exported"][lang]
|
||||
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
    """Build the export tab and connect its button to ``save_model``.

    Returns:
        A dict of the tab's components keyed by element name.
    """
    # NOTE(review): layout nesting was reconstructed from a flattened listing — confirm.
    with gr.Row():
        export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
        export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
        export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
        export_legacy_format = gr.Checkbox()

    with gr.Row():
        export_dir = gr.Textbox()
        export_hub_model_id = gr.Textbox()

    export_btn = gr.Button()
    info_box = gr.Textbox(show_label=False, interactive=False)

    # Inputs follow save_model's parameter order: shared top-bar elements first, then this tab's.
    top_elem_ids = (
        "top.lang",
        "top.model_name",
        "top.model_path",
        "top.adapter_path",
        "top.finetuning_type",
        "top.template",
        "top.visual_inputs",
    )
    export_inputs = [engine.manager.get_elem_by_id(elem_id) for elem_id in top_elem_ids]
    export_inputs += [
        export_size,
        export_quantization_bit,
        export_quantization_dataset,
        export_device,
        export_legacy_format,
        export_dir,
        export_hub_model_id,
    ]
    export_btn.click(save_model, export_inputs, [info_box])

    return {
        "export_size": export_size,
        "export_quantization_bit": export_quantization_bit,
        "export_quantization_dataset": export_quantization_dataset,
        "export_device": export_device,
        "export_legacy_format": export_legacy_format,
        "export_dir": export_dir,
        "export_hub_model_id": export_hub_model_id,
        "export_btn": export_btn,
        "info_box": info_box,
    }
|
||||
48
src/llamafactory/webui/components/infer.py
Normal file
48
src/llamafactory/webui/components/infer.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from .chatbot import create_chat_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
    """Build the inference tab: backend selector, load/unload buttons and the chat box.

    Returns:
        A dict of the tab's components (including the chat box components) keyed by element name.
    """
    # NOTE(review): layout nesting was reconstructed from a flattened listing — confirm.
    input_elems = engine.manager.get_base_elems()
    elem_dict = {}

    infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
    with gr.Row():
        load_btn = gr.Button()
        unload_btn = gr.Button()

    info_box = gr.Textbox(show_label=False, interactive=False)

    input_elems.update({infer_backend})
    elem_dict.update(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box)

    chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
    elem_dict.update(chat_elems)

    def chat_box_update() -> "gr.Column":
        # Show the chat box only while a model is loaded.
        return gr.Column(visible=engine.chatter.loaded)

    def empty_history():
        # Fresh values for the chatbot display and the message state.
        return [], []

    load_event = load_btn.click(engine.chatter.load_model, input_elems, [info_box])
    load_event.then(chat_box_update, outputs=[chat_elems["chat_box"]])

    unload_event = unload_btn.click(engine.chatter.unload_model, input_elems, [info_box])
    unload_event = unload_event.then(empty_history, outputs=[chatbot, messages])
    unload_event.then(chat_box_update, outputs=[chat_elems["chat_box"]])

    # The image column follows the "visual inputs" checkbox of the top bar.
    visual_inputs = engine.manager.get_elem_by_id("top.visual_inputs")
    visual_inputs.change(
        lambda enabled: gr.Column(visible=enabled),
        [engine.manager.get_elem_by_id("top.visual_inputs")],
        [chat_elems["image_box"]],
    )

    return elem_dict
|
||||
66
src/llamafactory/webui/components/top.py
Normal file
66
src/llamafactory/webui/components/top.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...data import templates
|
||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
|
||||
from ..utils import can_quantize
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_top() -> Dict[str, "Component"]:
    """Build the shared top bar (language, model, adapters, advanced options) and wire its events.

    Returns:
        A dict of the created components keyed by element name.
    """
    # NOTE(review): layout nesting was reconstructed from a flattened listing — confirm.
    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

    with gr.Row():
        lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
        model_name = gr.Dropdown(choices=available_models, scale=3)
        model_path = gr.Textbox(scale=3)

    with gr.Row():
        finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
        adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
        refresh_btn = gr.Button(scale=1)

    with gr.Accordion(open=False) as advanced_tab:
        with gr.Row():
            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
            visual_inputs = gr.Checkbox(scale=1)

    adapter_inputs = [model_name, finetuning_type]

    # Picking a model refreshes adapters, path, template and the visual flag, in that order.
    # The config is not saved here: the model_path change below takes care of saving.
    model_event = model_name.change(list_adapters, adapter_inputs, [adapter_path], queue=False)
    model_event = model_event.then(get_model_path, [model_name], [model_path], queue=False)
    model_event = model_event.then(get_template, [model_name], [template], queue=False)
    model_event.then(get_visual, [model_name], [visual_inputs], queue=False)

    model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)

    finetuning_event = finetuning_type.change(list_adapters, adapter_inputs, [adapter_path], queue=False)
    finetuning_event.then(can_quantize, [finetuning_type], [quantization_bit], queue=False)

    refresh_btn.click(list_adapters, adapter_inputs, [adapter_path], queue=False)

    return {
        "lang": lang,
        "model_name": model_name,
        "model_path": model_path,
        "finetuning_type": finetuning_type,
        "adapter_path": adapter_path,
        "refresh_btn": refresh_btn,
        "advanced_tab": advanced_tab,
        "quantization_bit": quantization_bit,
        "template": template,
        "rope_scaling": rope_scaling,
        "booster": booster,
        "visual_inputs": visual_inputs,
    }
|
||||
299
src/llamafactory/webui/components/train.py
Normal file
299
src/llamafactory/webui/components/train.py
Normal file
@@ -0,0 +1,299 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
from ...extras.constants import TRAINING_STAGES
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
|
||||
from ..components.data import create_preview_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
    r"""
    Builds the components of the training tab and registers their event handlers.

    Every component that carries a training option is added to `input_elems`, the set handed to
    the runner callbacks as their inputs. All components (including containers and buttons) are
    recorded by name in the returned dict.

    Components are created without labels here; presumably the labels come from the locale table
    applied in `Engine.change_lang` (NOTE(review): confirm against locales.py).
    """
    # Start from the shared top-level elements; the set is extended section by section below.
    input_elems = engine.manager.get_base_elems()
    elem_dict = dict()

    # Training stage and dataset selection, plus the dataset preview box.
    with gr.Row():
        training_stage = gr.Dropdown(
            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
        )
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
        preview_elems = create_preview_box(dataset_dir, dataset)

    input_elems.update({training_stage, dataset_dir, dataset})
    elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

    # Core hyper-parameters. The textbox values are strings; Runner._parse_train_args converts
    # them with float() / int().
    with gr.Row():
        learning_rate = gr.Textbox(value="5e-5")
        num_train_epochs = gr.Textbox(value="3.0")
        max_grad_norm = gr.Textbox(value="1.0")
        max_samples = gr.Textbox(value="100000")
        compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")

    input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
    elem_dict.update(
        dict(
            learning_rate=learning_rate,
            num_train_epochs=num_train_epochs,
            max_grad_norm=max_grad_norm,
            max_samples=max_samples,
            compute_type=compute_type,
        )
    )

    # Sequence length, batching, validation split and scheduler.
    with gr.Row():
        cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
        batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
        gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
        val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")

    input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
    elem_dict.update(
        dict(
            cutoff_len=cutoff_len,
            batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            val_size=val_size,
            lr_scheduler_type=lr_scheduler_type,
        )
    )

    # Collapsed accordion with additional training options.
    with gr.Accordion(open=False) as extra_tab:
        with gr.Row():
            logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
            save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
            warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
            neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
            optim = gr.Textbox(value="adamw_torch")

        with gr.Row():
            with gr.Column():
                resize_vocab = gr.Checkbox()
                packing = gr.Checkbox()

            with gr.Column():
                upcast_layernorm = gr.Checkbox()
                use_llama_pro = gr.Checkbox()

            with gr.Column():
                shift_attn = gr.Checkbox()
                report_to = gr.Checkbox()

    input_elems.update(
        {
            logging_steps,
            save_steps,
            warmup_steps,
            neftune_alpha,
            optim,
            resize_vocab,
            packing,
            upcast_layernorm,
            use_llama_pro,
            shift_attn,
            report_to,
        }
    )
    elem_dict.update(
        dict(
            extra_tab=extra_tab,
            logging_steps=logging_steps,
            save_steps=save_steps,
            warmup_steps=warmup_steps,
            neftune_alpha=neftune_alpha,
            optim=optim,
            resize_vocab=resize_vocab,
            packing=packing,
            upcast_layernorm=upcast_layernorm,
            use_llama_pro=use_llama_pro,
            shift_attn=shift_attn,
            report_to=report_to,
        )
    )

    # Options for freeze tuning.
    with gr.Accordion(open=False) as freeze_tab:
        with gr.Row():
            freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1)
            freeze_trainable_modules = gr.Textbox(value="all")
            freeze_extra_modules = gr.Textbox()

    input_elems.update({freeze_trainable_layers, freeze_trainable_modules, freeze_extra_modules})
    elem_dict.update(
        dict(
            freeze_tab=freeze_tab,
            freeze_trainable_layers=freeze_trainable_layers,
            freeze_trainable_modules=freeze_trainable_modules,
            freeze_extra_modules=freeze_extra_modules,
        )
    )

    # Options for LoRA tuning.
    with gr.Accordion(open=False) as lora_tab:
        with gr.Row():
            lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
            lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
            lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
            loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
            create_new_adapter = gr.Checkbox()

        with gr.Row():
            with gr.Column(scale=1):
                use_rslora = gr.Checkbox()
                use_dora = gr.Checkbox()

            lora_target = gr.Textbox(scale=2)
            additional_target = gr.Textbox(scale=2)

    input_elems.update(
        {
            lora_rank,
            lora_alpha,
            lora_dropout,
            loraplus_lr_ratio,
            create_new_adapter,
            use_rslora,
            use_dora,
            lora_target,
            additional_target,
        }
    )
    elem_dict.update(
        dict(
            lora_tab=lora_tab,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            loraplus_lr_ratio=loraplus_lr_ratio,
            create_new_adapter=create_new_adapter,
            use_rslora=use_rslora,
            use_dora=use_dora,
            lora_target=lora_target,
            additional_target=additional_target,
        )
    )

    # Options for the preference-based stages (DPO / ORPO / PPO reward model).
    with gr.Accordion(open=False) as rlhf_tab:
        with gr.Row():
            dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
            dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
            orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)

    input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
    elem_dict.update(
        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
    )

    # Options for GaLore.
    with gr.Accordion(open=False) as galore_tab:
        with gr.Row():
            use_galore = gr.Checkbox()
            galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
            galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
            galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
            galore_target = gr.Textbox(value="all")

    input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
    elem_dict.update(
        dict(
            galore_tab=galore_tab,
            use_galore=use_galore,
            galore_rank=galore_rank,
            galore_update_interval=galore_update_interval,
            galore_scale=galore_scale,
            galore_target=galore_target,
        )
    )

    # Options for BAdam.
    with gr.Accordion(open=False) as badam_tab:
        with gr.Row():
            use_badam = gr.Checkbox()
            badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
            badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
            badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
            badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)

    input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
    elem_dict.update(
        dict(
            badam_tab=badam_tab,
            use_badam=use_badam,
            badam_mode=badam_mode,
            badam_switch_mode=badam_switch_mode,
            badam_switch_interval=badam_switch_interval,
            badam_update_ratio=badam_update_ratio,
        )
    )

    # Action buttons.
    with gr.Row():
        cmd_preview_btn = gr.Button()
        arg_save_btn = gr.Button()
        arg_load_btn = gr.Button()
        start_btn = gr.Button(variant="primary")
        stop_btn = gr.Button(variant="stop")

    # Output area: paths, progress and log on the left, loss plot on the right.
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                output_dir = gr.Textbox()
                config_path = gr.Textbox()

            with gr.Row():
                # Both start hidden. Engine.resume sets resume_btn to True when a run is in
                # progress, which fires the `change` handler registered below.
                resume_btn = gr.Checkbox(visible=False, interactive=False)
                progress_bar = gr.Slider(visible=False, interactive=False)

            with gr.Row():
                output_box = gr.Markdown()

        with gr.Column(scale=1):
            loss_viewer = gr.Plot()

    elem_dict.update(
        dict(
            cmd_preview_btn=cmd_preview_btn,
            arg_save_btn=arg_save_btn,
            arg_load_btn=arg_load_btn,
            start_btn=start_btn,
            stop_btn=stop_btn,
            output_dir=output_dir,
            config_path=config_path,
            resume_btn=resume_btn,
            progress_bar=progress_bar,
            output_box=output_box,
            loss_viewer=loss_viewer,
        )
    )

    # The output directory and config path are inputs too; the buttons and display widgets are not.
    input_elems.update({output_dir, config_path})
    output_elems = [output_box, progress_bar, loss_viewer]

    # Passing a set of components as inputs makes Gradio hand the callback a single dict keyed by
    # component, which is how the runner's `data` argument is indexed.
    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
    arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
    arg_load_btn.click(
        engine.runner.load_args,
        [engine.manager.get_elem_by_id("top.lang"), config_path],
        list(input_elems) + [output_box],
        concurrency_limit=None,
    )
    # No concurrency_limit override here, unlike the other handlers.
    start_btn.click(engine.runner.run_train, input_elems, output_elems)
    stop_btn.click(engine.runner.set_abort)
    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)

    # Refresh the dataset choices when the directory or stage changes. A stage change also
    # refreshes the reward model choices and the packing checkbox.
    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
        list_adapters,
        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
        [reward_model],
        queue=False,
    ).then(autoset_packing, [training_stage], [packing], queue=False)

    return elem_dict
|
||||
27
src/llamafactory/webui/css.py
Normal file
27
src/llamafactory/webui/css.py
Normal file
@@ -0,0 +1,27 @@
|
||||
CSS = r"""
|
||||
.duplicate-button {
|
||||
margin: auto !important;
|
||||
color: white !important;
|
||||
background: black !important;
|
||||
border-radius: 100vh !important;
|
||||
}
|
||||
|
||||
.modal-box {
|
||||
position: fixed !important;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%); /* center horizontally */
|
||||
max-width: 1000px;
|
||||
max-height: 750px;
|
||||
overflow-y: auto;
|
||||
background-color: var(--input-background-fill);
|
||||
flex-wrap: nowrap !important;
|
||||
border: 2px solid black !important;
|
||||
z-index: 1000;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.dark .modal-box {
|
||||
border: 2px solid white !important;
|
||||
}
|
||||
"""
|
||||
66
src/llamafactory/webui/engine.py
Normal file
66
src/llamafactory/webui/engine.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from .chatter import WebChatModel
|
||||
from .common import get_model_path, list_dataset, load_config
|
||||
from .locales import LOCALES
|
||||
from .manager import Manager
|
||||
from .runner import Runner
|
||||
from .utils import get_time
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
class Engine:
    r"""
    Ties together the element manager, the training runner and the chat model used by the web UI.
    """

    def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
        self.demo_mode = demo_mode
        self.pure_chat = pure_chat
        self.manager = Manager()
        self.runner = Runner(self.manager, demo_mode)
        # Lazy initialization is turned off only in pure-chat mode.
        self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))

    def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
        r"""
        Maps each element id in `input_dict` to a fresh component of the same class,
        constructed from the given attributes.
        """
        lookup = self.manager.get_elem_by_id
        updates: Dict["Component", "Component"] = {}
        for elem_id in input_dict:
            target = lookup(elem_id)
            updates[target] = target.__class__(**input_dict[elem_id])

        return updates

    def resume(self):
        r"""
        Yields the component updates that restore the UI state when the page is loaded.
        """
        # The stored user config is ignored in demo mode.
        user_config = {} if self.demo_mode else load_config()
        lang = user_config.get("lang", None) or "en"

        init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}

        if not self.pure_chat:
            init_dict.update(
                {
                    "train.dataset": {"choices": list_dataset().choices},
                    "eval.dataset": {"choices": list_dataset().choices},
                    "train.output_dir": {"value": "train_{}".format(get_time())},
                    "train.config_path": {"value": "{}.yaml".format(get_time())},
                    "eval.output_dir": {"value": "eval_{}".format(get_time())},
                    "infer.image_box": {"visible": False},
                }
            )

            last_model = user_config.get("last_model", None)
            if last_model:
                init_dict["top.model_name"] = {"value": last_model}
                init_dict["top.model_path"] = {"value": get_model_path(last_model)}

        yield self._update_component(init_dict)

        if self.runner.running and not (self.demo_mode or self.pure_chat):
            # Replay the inputs of the run in progress, then flip the matching hidden resume checkbox.
            yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
            resume_id = "train.resume_btn" if self.runner.do_train else "eval.resume_btn"
            yield self._update_component({resume_id: {"value": True}})

    def change_lang(self, lang: str):
        r"""
        Returns an update for every registered element whose name has an entry in LOCALES.
        """
        localized = {}
        for elem_name, elem in self.manager.get_elem_iter():
            if elem_name not in LOCALES:
                continue

            localized[elem] = elem.__class__(**LOCALES[elem_name][lang])

        return localized
|
||||
82
src/llamafactory/webui/interface.py
Normal file
82
src/llamafactory/webui/interface.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import os
|
||||
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import save_config
|
||||
from .components import (
|
||||
create_chat_box,
|
||||
create_eval_tab,
|
||||
create_export_tab,
|
||||
create_infer_tab,
|
||||
create_top,
|
||||
create_train_tab,
|
||||
)
|
||||
from .css import CSS
|
||||
from .engine import Engine
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def create_ui(demo_mode: bool = False) -> "gr.Blocks":
    r"""
    Builds the full LLaMA Board interface: top bar, Train / Evaluate & Predict / Chat tabs,
    and (outside demo mode) the Export tab.

    The return annotation is a string because `gr` is only imported when gradio is available;
    an unquoted annotation would be evaluated at definition time and break importing this module.
    """
    engine = Engine(demo_mode=demo_mode, pure_chat=False)

    with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
        if demo_mode:
            # Banner and duplicate button are shown only in demo mode.
            gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
            gr.HTML(
                '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
                "LLaMA Factory</a> for details.</center></h3>"
            )
            gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

        # The top elements must be registered first: the tab builders fetch them from the manager.
        engine.manager.add_elems("top", create_top())
        lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")

        with gr.Tab("Train"):
            engine.manager.add_elems("train", create_train_tab(engine))

        with gr.Tab("Evaluate & Predict"):
            engine.manager.add_elems("eval", create_eval_tab(engine))

        with gr.Tab("Chat"):
            engine.manager.add_elems("infer", create_infer_tab(engine))

        if not demo_mode:
            with gr.Tab("Export"):
                engine.manager.add_elems("export", create_export_tab(engine))

        # Restore state on page load, re-localize on language change, and persist the language
        # only when the user picks it (`input`), not on programmatic changes.
        demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
        lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
        lang.input(save_config, inputs=[lang], queue=False)

    return demo
|
||||
|
||||
|
||||
def create_web_demo() -> "gr.Blocks":
    r"""
    Builds the chat-only web demo: a language dropdown plus the chat box.

    The return annotation is a string because `gr` is only imported when gradio is available;
    an unquoted annotation would be evaluated at definition time and break importing this module.
    """
    engine = Engine(pure_chat=True)

    with gr.Blocks(title="Web Demo", css=CSS) as demo:
        lang = gr.Dropdown(choices=["en", "zh"])
        # Registered as "top.lang" so the engine can look it up under the same id as in the full UI.
        engine.manager.add_elems("top", dict(lang=lang))

        _, _, chat_elems = create_chat_box(engine, visible=True)
        engine.manager.add_elems("infer", chat_elems)

        demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
        lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
        lang.input(save_config, inputs=[lang], queue=False)

    return demo
|
||||
|
||||
|
||||
def run_web_ui() -> None:
    r"""
    Launches the full web UI.

    GRADIO_SHARE ("true" or "1", case-insensitive) enables sharing; GRADIO_SERVER_NAME sets the
    bind address and defaults to "0.0.0.0".
    """
    share_flag = os.environ.get("GRADIO_SHARE", "0").lower()
    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    app = create_ui().queue()
    app.launch(share=share_flag in ["true", "1"], server_name=host)
|
||||
|
||||
|
||||
def run_web_demo() -> None:
    r"""
    Launches the chat-only web demo.

    GRADIO_SHARE ("true" or "1", case-insensitive) enables sharing; GRADIO_SERVER_NAME sets the
    bind address and defaults to "0.0.0.0".
    """
    share_flag = os.environ.get("GRADIO_SHARE", "0").lower()
    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    app = create_web_demo().queue()
    app.launch(share=share_flag in ["true", "1"], server_name=host)
|
||||
1524
src/llamafactory/webui/locales.py
Normal file
1524
src/llamafactory/webui/locales.py
Normal file
File diff suppressed because it is too large
Load Diff
64
src/llamafactory/webui/manager.py
Normal file
64
src/llamafactory/webui/manager.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
class Manager:
    r"""
    Two-way registry between element ids of the form "tab.name" and their components.
    """

    def __init__(self) -> None:
        self._id_to_elem: Dict[str, "Component"] = {}
        self._elem_to_id: Dict["Component", str] = {}

    def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
        r"""
        Registers every element of `elem_dict` under the id "<tab_name>.<elem_name>".
        """
        for elem_name in elem_dict:
            component = elem_dict[elem_name]
            full_id = f"{tab_name}.{elem_name}"
            self._id_to_elem[full_id] = component
            self._elem_to_id[component] = full_id

    def get_elem_list(self) -> List["Component"]:
        r"""
        Returns all registered elements in registration order.
        """
        return [component for component in self._id_to_elem.values()]

    def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
        r"""
        Yields (name, element) pairs, where name is the id with its tab prefix removed.
        """
        for full_id, component in self._id_to_elem.items():
            # Keep only the text after the last dot.
            yield full_id.rpartition(".")[2], component

    def get_elem_by_id(self, elem_id: str) -> "Component":
        r"""
        Looks up an element by its full id.

        Example: top.lang, train.dataset
        """
        return self._id_to_elem[elem_id]

    def get_id_by_elem(self, elem: "Component") -> str:
        r"""
        Looks up the full id of a registered element.
        """
        return self._elem_to_id[elem]

    def get_base_elems(self) -> Set["Component"]:
        r"""
        Returns the set of top-bar elements that are commonly used as inputs.
        """
        base_names = (
            "lang",
            "model_name",
            "model_path",
            "finetuning_type",
            "adapter_path",
            "quantization_bit",
            "template",
            "rope_scaling",
            "booster",
            "visual_inputs",
        )
        return {self._id_to_elem["top." + name] for name in base_names}
|
||||
369
src/llamafactory/webui/runner.py
Normal file
369
src/llamafactory/webui/runner.py
Normal file
@@ -0,0 +1,369 @@
|
||||
import os
|
||||
import signal
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
|
||||
import psutil
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_cuda_available
|
||||
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
from ..extras.misc import get_device_count, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from .manager import Manager
|
||||
|
||||
|
||||
class Runner:
|
||||
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
self.trainer: Optional["Popen"] = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
|
||||
def set_abort(self) -> None:
    r"""
    Marks the run as aborted and sends SIGABRT to the direct children of the trainer process.
    """
    self.aborted = True
    if self.trainer is None:
        return

    # abort the child process
    trainer_proc = psutil.Process(self.trainer.pid)
    for child in trainer_proc.children():
        os.kill(child.pid, signal.SIGABRT)
|
||||
|
||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
    r"""
    Validates the submitted inputs before a run or a command preview.

    Returns a localized error message, or an empty string when everything is fine.
    """

    def get(elem_id: str) -> Any:
        return data[self.manager.get_elem_by_id(elem_id)]

    lang = get("top.lang")
    model_name = get("top.model_name")
    model_path = get("top.model_path")
    dataset = get("train.dataset" if do_train else "eval.dataset")
    launching = not from_preview  # some checks only apply when a process is really started

    if self.running:
        return ALERTS["err_conflict"][lang]

    if not model_name:
        return ALERTS["err_no_model"][lang]

    if not model_path:
        return ALERTS["err_no_path"][lang]

    if not dataset:
        return ALERTS["err_no_dataset"][lang]

    if launching and self.demo_mode:
        return ALERTS["err_demo"][lang]

    if launching and get_device_count() > 1:
        return ALERTS["err_device_count"][lang]

    if do_train:
        stage = TRAINING_STAGES[get("train.training_stage")]
        reward_model = get("train.reward_model")
        if stage == "ppo" and not reward_model:
            return ALERTS["err_no_reward_model"][lang]

    # A missing CUDA device only produces a warning; the run still proceeds.
    if launching and not is_torch_cuda_available():
        gr.Warning(ALERTS["warn_no_cuda"][lang])

    return ""
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
    r"""
    Resets the run state and returns the final message.

    If the run was aborted, the localized "aborted" message replaces `finish_info`.
    """
    if self.aborted:
        finish_info = ALERTS["info_aborted"][lang]

    self.trainer, self.running_data = None, None
    self.aborted = self.running = False
    torch_gc()
    return finish_info
|
||||
|
||||
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
    r"""
    Converts the submitted component values into the argument dict for a training run.

    Args:
        data: values of the input components, keyed by component.

    Returns:
        The training arguments. Options that only apply to a given finetuning type, stage or
        optimizer are added only when that feature is selected.
    """
    get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
    user_config = load_config()

    # Selected adapters are given as names; resolve each to its save directory.
    if get("top.adapter_path"):
        adapter_name_or_path = ",".join(
            [
                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
                for adapter in get("top.adapter_path")
            ]
        )
    else:
        adapter_name_or_path = None

    args = dict(
        stage=TRAINING_STAGES[get("train.training_stage")],
        do_train=True,
        model_name_or_path=get("top.model_path"),
        adapter_name_or_path=adapter_name_or_path,
        cache_dir=user_config.get("cache_dir", None),
        preprocessing_num_workers=16,
        finetuning_type=get("top.finetuning_type"),
        quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
        template=get("top.template"),
        rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
        flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
        use_unsloth=(get("top.booster") == "unsloth"),
        visual_inputs=get("top.visual_inputs"),
        dataset_dir=get("train.dataset_dir"),
        dataset=",".join(get("train.dataset")),
        cutoff_len=get("train.cutoff_len"),
        learning_rate=float(get("train.learning_rate")),
        num_train_epochs=float(get("train.num_train_epochs")),
        max_samples=int(get("train.max_samples")),
        per_device_train_batch_size=get("train.batch_size"),
        gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
        lr_scheduler_type=get("train.lr_scheduler_type"),
        max_grad_norm=float(get("train.max_grad_norm")),
        logging_steps=get("train.logging_steps"),
        save_steps=get("train.save_steps"),
        warmup_steps=get("train.warmup_steps"),
        neftune_noise_alpha=get("train.neftune_alpha") or None,
        optim=get("train.optim"),
        resize_vocab=get("train.resize_vocab"),
        packing=get("train.packing"),
        upcast_layernorm=get("train.upcast_layernorm"),
        use_llama_pro=get("train.use_llama_pro"),
        shift_attn=get("train.shift_attn"),
        report_to="all" if get("train.report_to") else "none",
        use_galore=get("train.use_galore"),
        use_badam=get("train.use_badam"),
        output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
        fp16=(get("train.compute_type") == "fp16"),
        bf16=(get("train.compute_type") == "bf16"),
        pure_bf16=(get("train.compute_type") == "pure_bf16"),
        plot_loss=True,
    )

    # Options specific to the finetuning type.
    if args["finetuning_type"] == "freeze":
        args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
        args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
        args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
    elif args["finetuning_type"] == "lora":
        args["lora_rank"] = get("train.lora_rank")
        args["lora_alpha"] = get("train.lora_alpha")
        args["lora_dropout"] = get("train.lora_dropout")
        args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
        args["create_new_adapter"] = get("train.create_new_adapter")
        args["use_rslora"] = get("train.use_rslora")
        args["use_dora"] = get("train.use_dora")
        # Fall back to the model's default target modules when none are given.
        args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
        args["additional_target"] = get("train.additional_target") or None

    if args["use_llama_pro"]:
        # The train tab registers no "num_layer_trainable" element, so looking it up raised
        # KeyError. The layer count comes from the "freeze_trainable_layers" slider.
        args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")

    # Options specific to the training stage.
    if args["stage"] == "ppo":
        args["reward_model"] = ",".join(
            [
                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
                for adapter in get("train.reward_model")
            ]
        )
        args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
    elif args["stage"] == "dpo":
        args["dpo_beta"] = get("train.dpo_beta")
        args["dpo_ftx"] = get("train.dpo_ftx")
    elif args["stage"] == "orpo":
        args["orpo_beta"] = get("train.orpo_beta")

    # Evaluation during training is enabled only for a non-zero validation split outside PPO.
    if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
        args["val_size"] = get("train.val_size")
        args["evaluation_strategy"] = "steps"
        args["eval_steps"] = args["save_steps"]
        args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
        args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]

    if args["use_galore"]:
        args["galore_rank"] = get("train.galore_rank")
        args["galore_update_interval"] = get("train.galore_update_interval")
        args["galore_scale"] = get("train.galore_scale")
        args["galore_target"] = get("train.galore_target")

    if args["use_badam"]:
        args["badam_mode"] = get("train.badam_mode")
        args["badam_switch_mode"] = get("train.badam_switch_mode")
        args["badam_switch_interval"] = get("train.badam_switch_interval")
        args["badam_update_ratio"] = get("train.badam_update_ratio")

    return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
    r"""
    Converts the submitted component values into the argument dict for an evaluation or
    prediction run.
    """

    def get(elem_id: str) -> Any:
        return data[self.manager.get_elem_by_id(elem_id)]

    user_config = load_config()

    # Selected adapters are given as names; resolve each to its save directory.
    adapters = get("top.adapter_path")
    adapter_name_or_path = None
    if adapters:
        adapter_name_or_path = ",".join(
            get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) for adapter in adapters
        )

    quantization_bit = get("top.quantization_bit")
    rope_scaling = get("top.rope_scaling")
    booster = get("top.booster")

    args = {
        "stage": "sft",
        "model_name_or_path": get("top.model_path"),
        "adapter_name_or_path": adapter_name_or_path,
        "cache_dir": user_config.get("cache_dir", None),
        "preprocessing_num_workers": 16,
        "finetuning_type": get("top.finetuning_type"),
        "quantization_bit": int(quantization_bit) if quantization_bit in ["8", "4"] else None,
        "template": get("top.template"),
        "rope_scaling": rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
        "flash_attn": "fa2" if booster == "flashattn2" else "auto",
        "use_unsloth": (booster == "unsloth"),
        "visual_inputs": get("top.visual_inputs"),
        "dataset_dir": get("eval.dataset_dir"),
        "dataset": ",".join(get("eval.dataset")),
        "cutoff_len": get("eval.cutoff_len"),
        "max_samples": int(get("eval.max_samples")),
        "per_device_eval_batch_size": get("eval.batch_size"),
        "predict_with_generate": True,
        "max_new_tokens": get("eval.max_new_tokens"),
        "top_p": get("eval.top_p"),
        "temperature": get("eval.temperature"),
        "output_dir": get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
    }

    # Exactly one of the two mode flags is set.
    mode_flag = "do_predict" if get("eval.predict") else "do_eval"
    args[mode_flag] = True

    return args
|
||||
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
    r"""
    Yields one update for the output box: the validation error, or the generated command.
    """
    tab = "train" if do_train else "eval"
    output_box = self.manager.get_elem_by_id("{}.output_box".format(tab))

    error = self._initialize(data, do_train, from_preview=True)
    if error:
        gr.Warning(error)
        yield {output_box: error}
        return

    parse_args = self._parse_train_args if do_train else self._parse_eval_args
    yield {output_box: gen_cmd(parse_args(data))}
|
||||
|
||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
    r"""
    Validates the inputs, starts the run in a subprocess and relays the monitor's updates.

    On a validation error, a single update with the error message is yielded instead.
    """
    tab = "train" if do_train else "eval"
    output_box = self.manager.get_elem_by_id("{}.output_box".format(tab))

    error = self._initialize(data, do_train, from_preview=False)
    if error:
        gr.Warning(error)
        yield {output_box: error}
        return

    # Remember what is running so the UI can re-attach to it later.
    self.do_train = do_train
    self.running_data = data

    parse_args = self._parse_train_args if do_train else self._parse_eval_args
    args = parse_args(data)

    env = deepcopy(os.environ)
    # Use device "0" unless CUDA_VISIBLE_DEVICES is already set.
    env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    env["LLAMABOARD_ENABLED"] = "1"

    command = "llamafactory-cli train {}".format(save_cmd(args))
    self.trainer = Popen(command, env=env, shell=True)
    yield from self.monitor()
|
||||
|
||||
def preview_train(self, data):
    r"""
    Previews the command for a training run.
    """
    yield from self._preview(data, True)
|
||||
|
||||
def preview_eval(self, data):
    r"""
    Previews the command for an evaluation run.
    """
    yield from self._preview(data, False)
|
||||
|
||||
def run_train(self, data):
    r"""
    Launches a training run.
    """
    yield from self._launch(data, True)
|
||||
|
||||
def run_eval(self, data):
    r"""
    Launches an evaluation run.
    """
    yield from self._launch(data, False)
|
||||
|
||||
def monitor(self):
    """
    Streams UI updates while the trainer subprocess is alive, then a final summary.

    While `self.trainer` is set, each iteration yields either the "aborting"
    notice (when `self.aborted` is true) or the current log, progress bar and,
    if available, loss plot from `get_trainer_info`; it then waits up to two
    seconds for the subprocess to exit. Once the process has exited, one last
    update reports success, failure, or the evaluation results.
    """
    self.aborted = False
    self.running = True

    def lookup(elem_id):
        # Value of a component as captured when the run was launched.
        return self.running_data[self.manager.get_elem_by_id(elem_id)]

    tab_name = "train" if self.do_train else "eval"
    lang = lookup("top.lang")
    model_name = lookup("top.model_name")
    finetuning_type = lookup("top.finetuning_type")
    output_dir = lookup("{}.output_dir".format(tab_name))
    output_path = get_save_dir(model_name, finetuning_type, output_dir)

    output_box = self.manager.get_elem_by_id("{}.output_box".format(tab_name))
    progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format(tab_name))
    loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None

    while self.trainer is not None:
        if self.aborted:
            update = {
                output_box: ALERTS["info_aborting"][lang],
                progress_bar: gr.Slider(visible=False),
            }
        else:
            running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
            update = {
                output_box: running_log,
                progress_bar: running_progress,
            }
            if running_loss is not None:
                update[loss_viewer] = running_loss

        yield update

        try:
            self.trainer.wait(2)
        except TimeoutExpired:
            # Still running: go around again and refresh the UI.
            continue
        else:
            # The subprocess exited; clearing the handle ends the loop.
            self.trainer = None

    if self.do_train:
        # The run is reported as finished only if this file exists in the output directory.
        succeeded = os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME))
        finish_info = ALERTS["info_finished"][lang] if succeeded else ALERTS["err_failed"][lang]
    else:
        results_path = os.path.join(output_path, "all_results.json")
        if os.path.exists(results_path):
            finish_info = get_eval_results(results_path)
        else:
            finish_info = ALERTS["err_failed"][lang]

    yield {
        output_box: self._finalize(lang, finish_info),
        progress_bar: gr.Slider(visible=False),
    }
|
||||
|
||||
def save_args(self, data: dict):
    """
    Saves the current component values to the config file chosen in the UI.

    Returns one update for the training output box: the error from
    `_initialize` if validation fails, otherwise the "config saved" alert
    followed by the path returned by the module-level `save_args`.
    """
    output_box = self.manager.get_elem_by_id("train.output_box")
    error = self._initialize(data, do_train=True, from_preview=True)
    if error:
        gr.Warning(error)
        return {output_box: error}

    lang = data[self.manager.get_elem_by_id("top.lang")]
    config_path = data[self.manager.get_elem_by_id("train.config_path")]

    # These component values are deliberately left out of the saved config.
    excluded_ids = ("top.lang", "top.model_path", "train.output_dir", "train.config_path")
    config_dict: Dict[str, Any] = {}
    for elem, value in data.items():
        elem_id = self.manager.get_id_by_elem(elem)
        if elem_id in excluded_ids:
            continue

        config_dict[elem_id] = value

    save_path = save_args(config_path, config_dict)
    return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
||||
|
||||
def load_args(self, lang: str, config_path: str):
    """
    Loads saved component values from a config file.

    Returns a mapping of components to new values. When the module-level
    `load_args` returns None, only the training output box is updated, with
    the "config not found" alert (also raised as a warning).
    """
    output_box = self.manager.get_elem_by_id("train.output_box")
    config_dict = load_args(config_path)
    if config_dict is None:
        not_found = ALERTS["err_config_not_found"][lang]
        gr.Warning(not_found)
        return {output_box: not_found}

    output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
    # Map every saved element id back to its component.
    output_dict.update(
        (self.manager.get_elem_by_id(elem_id), value) for elem_id, value in config_dict.items()
    )
    return output_dict
|
||||
106
src/llamafactory/webui/utils.py
Normal file
106
src/llamafactory/webui/utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from yaml import safe_dump
|
||||
|
||||
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
|
||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import gen_loss_plot
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
    """
    Builds the dropdown update for the given finetuning type.

    For "lora" the dropdown is made interactive; for any other type it is
    reset to "none" and made non-interactive.
    """
    if finetuning_type == "lora":
        return gr.Dropdown(interactive=True)

    return gr.Dropdown(value="none", interactive=False)
|
||||
|
||||
|
||||
def check_json_schema(text: str, lang: str) -> None:
    """
    Validates the tool definitions typed by the user and warns on problems.

    `text` must be JSON. Empty values (e.g. ``[]`` or ``null``) are accepted.
    Anything else must be a list of objects, each with a "name" key. Problems
    are reported through `gr.Warning` in language `lang`; nothing is raised
    or returned.
    """
    try:
        tools = json.loads(text)
    except (TypeError, ValueError, RecursionError):
        # TypeError: not a str/bytes; ValueError: malformed JSON (JSONDecodeError);
        # RecursionError: pathologically nested input.
        gr.Warning(ALERTS["err_json_schema"][lang])
        return

    if not tools:
        return

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # Requiring dicts also stops a string element such as "my name" from passing
    # the key check through substring matching.
    if not isinstance(tools, list) or not all(isinstance(tool, dict) for tool in tools):
        gr.Warning(ALERTS["err_json_schema"][lang])
    elif any("name" not in tool for tool in tools):
        gr.Warning(ALERTS["err_tool_name"][lang])
|
||||
|
||||
|
||||
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Drops arguments whose value is None, False or the empty string.

    The "packing" key is always kept, whatever its value. Other falsy values
    such as 0 are kept as well.
    """
    always_kept = ["packing"]
    cleaned = {}
    for key, value in args.items():
        if key in always_kept or (value is not None and value is not False and value != ""):
            cleaned[key] = value

    return cleaned
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
    """
    Renders the arguments as a multi-line shell command in a bash code fence.

    The first line sets CUDA_VISIBLE_DEVICES (taken from the environment,
    defaulting to "0") and invokes `llamafactory-cli train`; each argument kept
    by `clean_cmd` follows on its own backslash-continued line.
    """
    devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    header = "CUDA_VISIBLE_DEVICES={} llamafactory-cli train ".format(devices)
    options = [" --{} {} ".format(key, str(value)) for key, value in clean_cmd(args).items()]
    script = "\\\n".join([header] + options)
    return "```bash\n{}\n```".format(script)
|
||||
|
||||
|
||||
def get_eval_results(path: os.PathLike) -> str:
    """Reads the JSON file at `path` and returns it re-indented inside a json code fence."""
    with open(path, "r", encoding="utf-8") as f:
        metrics = json.load(f)

    pretty = json.dumps(metrics, indent=4)
    return "```json\n{}\n```\n".format(pretty)
|
||||
|
||||
|
||||
def get_time() -> str:
    """Returns the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``."""
    now = datetime.now()
    return now.strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
|
||||
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
    """
    Collects the current log text, progress bar and loss plot of a run.

    Args:
        output_path: directory in which the run writes its log files.
        do_train: whether the run is a training run; the loss plot is only
            produced for training runs and when matplotlib is available.

    Returns:
        A tuple of the running log text ("" if the file is absent), a slider
        update (hidden until at least one trainer log entry exists) and the
        loss plot or None.
    """
    running_log = ""
    running_progress = gr.Slider(visible=False)
    running_loss = None

    running_log_path = os.path.join(output_path, RUNNING_LOG)
    if os.path.isfile(running_log_path):
        with open(running_log_path, "r", encoding="utf-8") as f:
            running_log = f.read()

    trainer_log_path = os.path.join(output_path, TRAINER_LOG)
    if os.path.isfile(trainer_log_path):
        trainer_log: List[Dict[str, Any]] = []
        with open(trainer_log_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                try:
                    trainer_log.append(json.loads(line))
                except json.JSONDecodeError:
                    # The file is read while the run is still appending to it, so
                    # a line may be incomplete; skip it instead of crashing the
                    # caller. It is picked up on a later poll once fully written.
                    continue

        if trainer_log:
            latest_log = trainer_log[-1]
            percentage = latest_log["percentage"]
            label = "Running {:d}/{:d}: {} < {}".format(
                latest_log["current_steps"],
                latest_log["total_steps"],
                latest_log["elapsed_time"],
                latest_log["remaining_time"],
            )
            running_progress = gr.Slider(label=label, value=percentage, visible=True)

            if do_train and is_matplotlib_available():
                running_loss = gr.Plot(gen_loss_plot(trainer_log))

    return running_log, running_progress, running_loss
|
||||
|
||||
|
||||
def save_cmd(args: Dict[str, Any]) -> str:
    """
    Writes the cleaned arguments as YAML into the run's output directory.

    The directory `args["output_dir"]` is created if missing. Returns the path
    of the written config file.
    """
    output_dir = args["output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    config_path = os.path.join(output_dir, TRAINER_CONFIG)
    with open(config_path, "w", encoding="utf-8") as f:
        safe_dump(clean_cmd(args), f)

    return config_path
|
||||
Reference in New Issue
Block a user