update webui
Former-commit-id: da30d0fb4abdb825f3383ddd106bb06a84695b7a
src/llmtuner/webui/common.py
@@ -6,7 +6,7 @@ import gradio as gr
 from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 
-from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
+from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
 
 
 DEFAULT_CACHE_DIR = "cache"
@@ -48,20 +48,10 @@ def get_model_path(model_name: str) -> str:
     return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
 
 
-def get_template(
-    model_name: str,
-) -> str:
-    if model_name == "Custom":
-        model_name_or_path = get_model_path(model_name)
-        # get last dir
-        basename = os.path.basename(model_name_or_path)
-        # prefix match
-        for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items():
-            if basename.startswith(k):
-                return v
-        return "default"
-
-    return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")
+def get_template(model_name: str) -> str:
+    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    return "default"
 
 
 def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
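Note on the rewritten helper: get_template no longer prefix-matches DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL against the checkpoint directory name; it applies a family template only when the model name ends in "Chat" and its dash-separated prefix is a key of DEFAULT_TEMPLATE. A minimal standalone sketch of the new lookup (the mapping contents here are illustrative stand-ins, not the real constants):

    # Illustrative stand-in for llmtuner.extras.constants.DEFAULT_TEMPLATE;
    # the real mapping covers every supported model family.
    DEFAULT_TEMPLATE = {"Baichuan": "baichuan", "InternLM": "intern"}

    def get_template(model_name: str) -> str:
        # Chat variants resolve to their family template, e.g.
        # "InternLM-7B-Chat" -> "intern"; everything else falls back.
        if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
            return DEFAULT_TEMPLATE[model_name.split("-")[0]]
        return "default"

    assert get_template("InternLM-7B-Chat") == "intern"
    assert get_template("LLaMA-7B") == "default"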
src/llmtuner/webui/components/top.py
@@ -4,7 +4,7 @@ import gradio as gr
 
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, save_config, get_template
+from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
 from llmtuner.webui.utils import can_quantize
 
 if TYPE_CHECKING:
@@ -36,10 +36,11 @@ def create_top() -> Dict[str, "Component"]:
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
-    )
+    ).then(
+        get_template, [model_name], [template]
+    ) # do not save config since the below line will save
 
     model_path.change(save_config, [lang, model_name, model_path])
-    model_path.change(get_template, [model_name], [template])
 
     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
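Note on the event wiring: the template refresh now rides on the model_name.change chain via Gradio's .then(), which runs handlers sequentially, instead of a separate model_path.change listener. A reduced sketch of the pattern, assuming Gradio 3.x; resolve_path and pick_template are hypothetical stand-ins for get_model_path and get_template:

    import gradio as gr

    def resolve_path(name: str) -> str:
        return f"models/{name}"   # stand-in for get_model_path

    def pick_template(name: str) -> str:
        return "default"          # stand-in for get_template

    with gr.Blocks() as demo:
        model_name = gr.Dropdown(choices=["LLaMA-7B", "Custom"], label="Model name")
        model_path = gr.Textbox(label="Model path")
        template = gr.Textbox(label="Template")
        # .then() queues the next handler after the previous one finishes,
        # so the template is computed only after the path is resolved.
        model_name.change(resolve_path, [model_name], [model_path]).then(
            pick_template, [model_name], [template]
        )

    if __name__ == "__main__":
        demo.launch()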
src/llmtuner/webui/components/train.py
@@ -3,10 +3,10 @@ from transformers.trainer_utils import SchedulerType
 
 import gradio as gr
 
+from llmtuner.extras.constants import STAGES
 from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot
-from llmtuner.extras.constants import STAGES
 
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -15,9 +15,7 @@ if TYPE_CHECKING:
 
 def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
-        stage = gr.Dropdown(choices=STAGES,
-                            value="Supervised Finetuning", scale=2)
-
+        training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
         data_preview_btn = gr.Button(interactive=False, scale=1)
@@ -104,7 +102,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
             top_elems["quantization_bit"],
             top_elems["template"],
             top_elems["source_prefix"],
-            stage,
+            training_stage,
             dataset_dir,
             dataset,
             max_source_length,
@@ -145,7 +143,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
     )
 
     return dict(
-        stage=stage,
+        training_stage=training_stage,
         dataset_dir=dataset_dir,
         dataset=dataset,
         data_preview_btn=data_preview_btn,
src/llmtuner/webui/locales.py
@@ -87,6 +87,16 @@ LOCALES = {
             "info": "默认使用的系统提示词"
         }
     },
+    "training_stage": {
+        "en": {
+            "label": "Stage",
+            "info": "The stage to perform in training."
+        },
+        "zh": {
+            "label": "训练阶段",
+            "info": "目前采用的训练方式。"
+        }
+    },
     "dataset_dir": {
         "en": {
             "label": "Data dir",
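Note on the locale shape: the new training_stage entry follows the two-level layout used throughout LOCALES, component key, then language code, then the keyword arguments for that component. A small sketch of how such an entry resolves; the localize helper is hypothetical and shown only to illustrate the lookup:

    from typing import Any, Dict

    LOCALES: Dict[str, Dict[str, Dict[str, Any]]] = {
        "training_stage": {
            "en": {"label": "Stage", "info": "The stage to perform in training."},
            "zh": {"label": "训练阶段", "info": "目前采用的训练方式。"}
        }
    }

    def localize(component: str, lang: str) -> Dict[str, Any]:
        # Hypothetical helper: fetch the kwargs for one component in one language.
        return LOCALES[component][lang]

    assert localize("training_stage", "en")["label"] == "Stage"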
@@ -343,16 +353,6 @@ LOCALES = {
             "label": "RLHF 参数设置"
         }
     },
-    "rlhf_method": {
-        "en": {
-            "label": "RLHF method",
-            "info": "The RLHF algorithm to adopt."
-        },
-        "zh": {
-            "label": "RLHF 方法",
-            "info": "RLHF 阶段使用的算法。"
-        }
-    },
     "dpo_beta": {
         "en": {
             "label": "DPO beta",
@@ -546,15 +546,7 @@ LOCALES = {
         "zh": {
             "value": "开始导出"
         }
     },
-    "stage": {
-        "en": {
-            "label": "train stage"
-        },
-        "zh": {
-            "label": "训练阶段"
-        }
-    },
 }
src/llmtuner/webui/runner.py
@@ -8,11 +8,11 @@ from transformers.trainer import TRAINING_ARGS_NAME
 from typing import Any, Dict, Generator, List, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL
+from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_model_path, get_save_dir, get_template
+from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -70,7 +70,7 @@ class Runner:
         quantization_bit: str,
         template: str,
         source_prefix: str,
-        stage: str,
+        training_stage: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,
@@ -138,21 +138,21 @@
         )
         args[compute_type] = True
 
-        if stage == "Pretraining":
-            args["stage"] = "pt"
-        if stage == "Reward Modeling":
+        if training_stage == "Reward Modeling":
             args["stage"] = "rm"
             args["resume_lora_training"] = False
-        elif stage == "PPO":
+        elif training_stage == "PPO":
             args["stage"] = "ppo"
             args["resume_lora_training"] = False
             args["reward_model"] = reward_model
             args["padding_side"] = "left"
             val_size = 0
-        elif stage == "DPO":
+        elif training_stage == "DPO":
             args["stage"] = "dpo"
             args["resume_lora_training"] = False
             args["dpo_beta"] = dpo_beta
+        elif training_stage == "Pre-Training":
+            args["stage"] = "pt"
 
         if val_size > 1e-6:
             args["val_size"] = val_size
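Note on the dispatch: the renamed training_stage parameter now drives the whole stage selection, and pre-training moves from a leading if to a trailing elif so the branches are mutually exclusive. A condensed, self-contained sketch of the mapping; resolve_stage_args is a hypothetical wrapper, while the real logic runs inside Runner and mutates a larger args dict:

    from typing import Any, Dict

    def resolve_stage_args(training_stage: str, reward_model: str = "",
                           dpo_beta: float = 0.1, val_size: float = 0.0) -> Dict[str, Any]:
        # Assumed default: any unmatched label means supervised fine-tuning.
        args: Dict[str, Any] = {"stage": "sft"}
        if training_stage == "Reward Modeling":
            args["stage"] = "rm"
            args["resume_lora_training"] = False
        elif training_stage == "PPO":
            args["stage"] = "ppo"
            args["resume_lora_training"] = False
            args["reward_model"] = reward_model
            args["padding_side"] = "left"  # left padding for generation during PPO rollouts
            val_size = 0
        elif training_stage == "DPO":
            args["stage"] = "dpo"
            args["resume_lora_training"] = False
            args["dpo_beta"] = dpo_beta
        elif training_stage == "Pre-Training":
            args["stage"] = "pt"
        if val_size > 1e-6:
            args["val_size"] = val_size
        return args

    assert resolve_stage_args("PPO", reward_model="rm_ckpt")["stage"] == "ppo"
    assert "val_size" not in resolve_stage_args("PPO", val_size=0.1)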