better llamaboard

* easily resume training from a checkpoint
* support full and freeze checkpoints as well as adapter checkpoints (see the sketch after this list)
* faster UI
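The key change is how a selected checkpoint is turned into loader arguments. Below is a minimal, self-contained sketch of that logic, not the commit's actual code: `resolve_checkpoint` is a hypothetical helper written for illustration, and the `saves/<model>/<finetuning_type>/<checkpoint>` layout inside `get_save_dir` is an assumption; only the names `PEFT_METHODS` and `get_save_dir` mirror identifiers visible in the diff below. PEFT checkpoints are adapters and can be stacked, so they join into `adapter_name_or_path`; full and freeze checkpoints are complete model weights, so they overwrite `model_name_or_path`.

import os
from typing import Dict, List, Union

PEFT_METHODS = {"lora"}  # assumption: stands in for ..extras.constants.PEFT_METHODS

def get_save_dir(model_name: str, finetuning_type: str, ckpt: str) -> str:
    # assumption: checkpoints live under saves/<model>/<finetuning_type>/<name>
    return os.path.join("saves", model_name, finetuning_type, ckpt)

def resolve_checkpoint(
    args: Dict[str, str],
    model_name: str,
    finetuning_type: str,
    checkpoint_path: Union[List[str], str, None],
) -> Dict[str, str]:
    if not checkpoint_path:
        return args  # no checkpoint selected: load the base model as-is
    if finetuning_type in PEFT_METHODS:
        # PEFT checkpoints are adapters; several can be stacked, so the UI
        # passes a list that is joined into one comma-separated argument
        args["adapter_name_or_path"] = ",".join(
            get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoint_path
        )
    else:
        # full/freeze checkpoints are complete models: replace the base path
        args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
    return args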


Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
hiyouga
2024-05-29 23:55:38 +08:00
parent f90c4ca672
commit 87aa332583
14 changed files with 303 additions and 193 deletions


@@ -6,6 +6,7 @@ from numpy.typing import NDArray
 from ..chat import ChatModel
 from ..data import Role
+from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
 from .common import get_save_dir
@@ -44,13 +45,14 @@ class WebChatModel(ChatModel):
     def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
-        elif not get("top.model_name"):
+        elif not model_name:
             error = ALERTS["err_no_model"][lang]
-        elif not get("top.model_path"):
+        elif not model_path:
             error = ALERTS["err_no_path"][lang]
         elif self.demo_mode:
             error = ALERTS["err_demo"][lang]
@@ -60,21 +62,10 @@ class WebChatModel(ChatModel):
             yield error
             return

-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
-            model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
-            finetuning_type=get("top.finetuning_type"),
+            model_name_or_path=model_path,
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -83,8 +74,16 @@ class WebChatModel(ChatModel):
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
         )
-        super().__init__(args)
+
+        if checkpoint_path:
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
+        super().__init__(args)
         yield ALERTS["info_loaded"][lang]

     def unload_model(self, data) -> Generator[str, None, None]:
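For example, driving the hypothetical `resolve_checkpoint` sketch above (model id and checkpoint names are made up; resulting paths shown POSIX-style):

# Two stacked LoRA adapters -> one comma-joined adapter_name_or_path
args = resolve_checkpoint(
    {"model_name_or_path": "meta-llama/Llama-2-7b-hf"},
    "Llama-2-7b-hf", "lora", ["train_1", "train_2"],
)
assert args["adapter_name_or_path"] == "saves/Llama-2-7b-hf/lora/train_1,saves/Llama-2-7b-hf/lora/train_2"

# A full fine-tuning checkpoint is a complete model: it replaces the base path
args = resolve_checkpoint({}, "Llama-2-7b-hf", "full", "train_3")
assert args["model_name_or_path"] == "saves/Llama-2-7b-hf/full/train_3"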