remove visual_inputs, fix qlora
Former-commit-id: be30c01c4f1482520ece770bd54c6a4837c26f0a
This commit is contained in:
@@ -90,7 +90,6 @@ class WebChatModel(ChatModel):
|
||||
template=get("top.template"),
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
infer_backend=get("infer.infer_backend"),
|
||||
infer_dtype=get("infer.infer_dtype"),
|
||||
|
||||
@@ -122,16 +122,15 @@ def get_prefix(model_name: str) -> str:
|
||||
return model_name.split("-")[0]
|
||||
|
||||
|
||||
def get_model_info(model_name: str) -> Tuple[str, str, bool]:
|
||||
def get_model_info(model_name: str) -> Tuple[str, str]:
|
||||
r"""
|
||||
Gets the necessary information of this model.
|
||||
|
||||
Returns:
|
||||
model_path (str)
|
||||
template (str)
|
||||
visual (bool)
|
||||
"""
|
||||
return get_model_path(model_name), get_template(model_name), get_visual(model_name)
|
||||
return get_model_path(model_name), get_template(model_name)
|
||||
|
||||
|
||||
def get_template(model_name: str) -> str:
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import get_visual
|
||||
from .chatbot import create_chat_box
|
||||
|
||||
|
||||
@@ -64,9 +65,9 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
lambda: ([], []), outputs=[chatbot, messages]
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
|
||||
|
||||
engine.manager.get_elem_by_id("top.visual_inputs").change(
|
||||
lambda enabled: gr.Column(visible=enabled),
|
||||
[engine.manager.get_elem_by_id("top.visual_inputs")],
|
||||
engine.manager.get_elem_by_id("top.model_name").change(
|
||||
lambda model_name: gr.Column(visible=get_visual(model_name)),
|
||||
[engine.manager.get_elem_by_id("top.model_name")],
|
||||
[chat_elems["image_box"]],
|
||||
)
|
||||
|
||||
|
||||
@@ -48,9 +48,8 @@ def create_top() -> Dict[str, "Component"]:
|
||||
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
|
||||
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3)
|
||||
visual_inputs = gr.Checkbox(scale=1)
|
||||
|
||||
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
|
||||
model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
|
||||
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
|
||||
)
|
||||
model_name.input(save_config, inputs=[lang, model_name], queue=False)
|
||||
@@ -73,5 +72,4 @@ def create_top() -> Dict[str, "Component"]:
|
||||
template=template,
|
||||
rope_scaling=rope_scaling,
|
||||
booster=booster,
|
||||
visual_inputs=visual_inputs,
|
||||
)
|
||||
|
||||
@@ -183,20 +183,6 @@ LOCALES = {
|
||||
"label": "부스터",
|
||||
},
|
||||
},
|
||||
"visual_inputs": {
|
||||
"en": {
|
||||
"label": "Visual inputs",
|
||||
},
|
||||
"ru": {
|
||||
"label": "визуальные входы",
|
||||
},
|
||||
"zh": {
|
||||
"label": "图像输入",
|
||||
},
|
||||
"ko": {
|
||||
"label": "시각적 입력",
|
||||
},
|
||||
},
|
||||
"training_stage": {
|
||||
"en": {
|
||||
"label": "Stage",
|
||||
|
||||
@@ -75,5 +75,4 @@ class Manager:
|
||||
self._id_to_elem["top.template"],
|
||||
self._id_to_elem["top.rope_scaling"],
|
||||
self._id_to_elem["top.booster"],
|
||||
self._id_to_elem["top.visual_inputs"],
|
||||
}
|
||||
|
||||
@@ -116,7 +116,6 @@ class Runner:
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("train.dataset_dir"),
|
||||
dataset=",".join(get("train.dataset")),
|
||||
cutoff_len=get("train.cutoff_len"),
|
||||
@@ -252,7 +251,6 @@ class Runner:
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("eval.dataset_dir"),
|
||||
eval_dataset=",".join(get("eval.dataset")),
|
||||
cutoff_len=get("eval.cutoff_len"),
|
||||
|
||||
Reference in New Issue
Block a user