support autogptq in llama board #246

Former-commit-id: fea01226703d1534b5cf511bcb6a49e73bc86ce1
hiyouga
2023-12-16 16:31:30 +08:00
parent 04dc3f4614
commit 9f77e8b025
12 changed files with 123 additions and 65 deletions

View File

@@ -7,6 +7,7 @@ from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
 from llmtuner.extras.constants import (
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
+    PEFT_METHODS,
     SUPPORTED_MODELS,
     TRAINING_STAGES,
     DownloadSource
@@ -77,8 +78,11 @@ def get_template(model_name: str) -> str:
 def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type not in PEFT_METHODS:
+        return gr.update(value=[], choices=[], interactive=False)
+
     adapters = []
-    if model_name and finetuning_type == "lora": # full and freeze have no adapter
+    if model_name and finetuning_type == "lora":
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):
             for adapter in os.listdir(save_dir):
@@ -87,7 +91,7 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
                     and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
                 ):
                     adapters.append(adapter)
-    return gr.update(value=[], choices=adapters)
+    return gr.update(value=[], choices=adapters, interactive=True)


 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
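The two gr.update calls above are what toggle the adapter dropdown: non-PEFT methods (full, freeze) get an empty, disabled list, while LoRA gets the discovered adapters with interactive=True. A minimal standalone sketch of the same pattern; the PEFT_METHODS value is an assumption standing in for the constant imported above:

import gradio as gr

PEFT_METHODS = ["lora"]  # assumption: mirrors llmtuner.extras.constants

def toggle_adapter_dropdown(finetuning_type: str, found_adapters: list):
    # Full and freeze tuning have no adapters: empty the list and grey it out.
    if finetuning_type not in PEFT_METHODS:
        return gr.update(value=[], choices=[], interactive=False)
    # PEFT methods: offer whatever adapters were found on disk.
    return gr.update(value=[], choices=found_adapters, interactive=True)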

View File

@@ -21,8 +21,11 @@ def next_page(page_index: int, total_num: int) -> int:
 def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
-    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
-        dataset_info = json.load(f)
+    try:
+        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+            dataset_info = json.load(f)
+    except:
+        return gr.update(interactive=False)
+
     if (
         len(dataset) > 0
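Reading DATA_CONFIG is now wrapped in try/except, so a missing or malformed dataset config disables the preview button instead of raising inside the callback. The same pattern in isolation; the file name is an assumption standing in for the repo's DATA_CONFIG constant, and the exception list is narrowed where the commit uses a bare except:

import json
import os
import gradio as gr

DATA_CONFIG = "dataset_info.json"  # assumption: the value behind DATA_CONFIG

def load_dataset_config(dataset_dir: str):
    # On any read or parse failure, disable the dependent button rather than crash.
    try:
        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
            return json.load(f), gr.update(interactive=True)
    except (OSError, json.JSONDecodeError):
        return {}, gr.update(interactive=False)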

View File

@@ -10,6 +10,9 @@ if TYPE_CHECKING:
     from llmtuner.webui.engine import Engine


+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
 def save_model(
     lang: str,
     model_name: str,
@@ -18,6 +21,8 @@ def save_model(
     finetuning_type: str,
     template: str,
     max_shard_size: int,
+    export_quantization_bit: int,
+    export_quantization_dataset: str,
     export_dir: str
 ) -> Generator[str, None, None]:
     error = ""
@@ -25,23 +30,32 @@ def save_model(
         error = ALERTS["err_no_model"][lang]
     elif not model_path:
         error = ALERTS["err_no_path"][lang]
-    elif not adapter_path:
-        error = ALERTS["err_no_adapter"][lang]
     elif not export_dir:
         error = ALERTS["err_no_export_dir"][lang]
+    elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
+        error = ALERTS["err_no_dataset"][lang]
+    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+        error = ALERTS["err_no_adapter"][lang]

     if error:
         gr.Warning(error)
         yield error
         return

+    if adapter_path:
+        adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
+    else:
+        adapter_name_or_path = None
+
     args = dict(
         model_name_or_path=model_path,
-        adapter_name_or_path=",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]),
+        adapter_name_or_path=adapter_name_or_path,
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
-        export_size=max_shard_size
+        export_size=max_shard_size,
+        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
+        export_quantization_dataset=export_quantization_dataset
     )

     yield ALERTS["info_exporting"][lang]
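The reworked validation encodes two rules: a GPTQ export (bit width in GPTQ_BITS) requires a calibration dataset, while adapters are only required for a non-quantized export, presumably because a quantized export can start from the base model alone. A standalone sketch of that decision table, with placeholder strings instead of the ALERTS lookups:

GPTQ_BITS = ["8", "4", "3", "2"]

def validate_export(adapter_path, export_dir, quantization_bit, quantization_dataset):
    # Returns an error key, or None if the export may proceed.
    if not export_dir:
        return "err_no_export_dir"
    if quantization_bit in GPTQ_BITS and not quantization_dataset:
        return "err_no_dataset"  # GPTQ needs calibration data
    if quantization_bit not in GPTQ_BITS and not adapter_path:
        return "err_no_adapter"  # a plain export of a PEFT model needs adapters
    return None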
@@ -51,9 +65,11 @@ def save_model(
 def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
-        export_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
+        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
+        export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")

+    export_dir = gr.Textbox()
     export_btn = gr.Button()
     info_box = gr.Textbox(show_label=False, interactive=False)
@@ -67,14 +83,18 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_name("top.finetuning_type"),
             engine.manager.get_elem_by_name("top.template"),
             max_shard_size,
+            export_quantization_bit,
+            export_quantization_dataset,
             export_dir
         ],
         [info_box]
     )

     return dict(
-        export_dir=export_dir,
         max_shard_size=max_shard_size,
+        export_quantization_bit=export_quantization_bit,
+        export_quantization_dataset=export_quantization_dataset,
+        export_dir=export_dir,
         export_btn=export_btn,
         info_box=info_box
     )
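Gradio passes inputs positionally, which is why the two new components are inserted between max_shard_size and export_dir both in this inputs list and in save_model's signature. A minimal, self-contained sketch of the same wiring; the component labels and the stub generator are illustrative only:

import gradio as gr

def fake_save_model(export_dir, quantization_bit, quantization_dataset):
    # Like save_model, a generator: each yield updates the output textbox.
    yield "Exporting..."
    yield "Done."

with gr.Blocks() as demo:
    export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
    export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
    export_dir = gr.Textbox(label="Export dir")
    info_box = gr.Textbox(show_label=False, interactive=False)
    export_btn = gr.Button("Export")
    # Input order must match the callback's parameter order.
    export_btn.click(
        fake_save_model,
        [export_dir, export_quantization_bit, export_quantization_dataset],
        [info_box]
    )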

View File

@@ -20,7 +20,7 @@ def create_top() -> Dict[str, "Component"]:
     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        adapter_path = gr.Dropdown(multiselect=True, scale=5)
+        adapter_path = gr.Dropdown(multiselect=True, scale=5, allow_custom_value=True)
         refresh_btn = gr.Button(scale=1)

     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
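allow_custom_value=True lets a Dropdown hold values that are not (yet) in its choices list; here the adapter list starts empty and is only filled after a refresh, and recent Gradio versions otherwise warn or error on out-of-choices values. In isolation:

import gradio as gr

# Without allow_custom_value, programmatically setting a value outside
# `choices` is rejected; with it, not-yet-listed values are accepted.
adapter_path = gr.Dropdown(choices=[], multiselect=True, scale=5, allow_custom_value=True)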

View File

@@ -94,7 +94,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
             dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            reward_model = gr.Dropdown(scale=3)
+            reward_model = gr.Dropdown(scale=3, allow_custom_value=True)
             refresh_btn = gr.Button(scale=1)

     refresh_btn.click(

View File

@@ -432,11 +432,11 @@ LOCALES = {
     "reward_model": {
         "en": {
             "label": "Reward model",
-            "info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
+            "info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)"
         },
         "zh": {
             "label": "奖励模型",
-            "info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
+            "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"
         }
     },
     "cmd_preview_btn": {
@@ -585,6 +585,36 @@ LOCALES = {
             "label": "温度系数"
         }
     },
+    "max_shard_size": {
+        "en": {
+            "label": "Max shard size (GB)",
+            "info": "The maximum size for a model file."
+        },
+        "zh": {
+            "label": "最大分块大小(GB)",
+            "info": "单个模型文件的最大大小。"
+        }
+    },
+    "export_quantization_bit": {
+        "en": {
+            "label": "Export quantization bit.",
+            "info": "Quantizing the exported model."
+        },
+        "zh": {
+            "label": "导出量化等级",
+            "info": "量化导出模型。"
+        }
+    },
+    "export_quantization_dataset": {
+        "en": {
+            "label": "Export quantization dataset.",
+            "info": "The calibration dataset used for quantization."
+        },
+        "zh": {
+            "label": "导出量化数据集",
+            "info": "量化过程中使用的校准数据集。"
+        }
+    },
     "export_dir": {
         "en": {
             "label": "Export dir",
@@ -595,16 +625,6 @@ LOCALES = {
             "info": "保存导出模型的文件夹路径。"
         }
     },
-    "max_shard_size": {
-        "en": {
-            "label": "Max shard size (GB)",
-            "info": "The maximum size for a model file."
-        },
-        "zh": {
-            "label": "最大分块大小(GB)",
-            "info": "模型文件的最大大小。"
-        }
-    },
     "export_btn": {
         "en": {
             "value": "Export"