better llamaboard

* easily resume from checkpoint
* support full and freeze checkpoints
* faster ui

Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
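The substantive change is the checkpoint handling at the end of the diff: a checkpoint is either a set of PEFT adapters (resolved and comma-joined into adapter_name_or_path) or a full/freeze snapshot (which replaces model_name_or_path outright). A hedged sketch of that rule follows the diff below.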
@@ -6,6 +6,7 @@ from numpy.typing import NDArray
 
 from ..chat import ChatModel
 from ..data import Role
+from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
 from .common import get_save_dir
@@ -44,13 +45,14 @@ class WebChatModel(ChatModel):
 
     def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
-        elif not get("top.model_name"):
+        elif not model_name:
             error = ALERTS["err_no_model"][lang]
-        elif not get("top.model_path"):
+        elif not model_path:
             error = ALERTS["err_no_path"][lang]
         elif self.demo_mode:
             error = ALERTS["err_demo"][lang]
@@ -60,21 +62,10 @@ class WebChatModel(ChatModel):
             yield error
             return
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
-            model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
-            finetuning_type=get("top.finetuning_type"),
+            model_name_or_path=model_path,
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -83,8 +74,16 @@ class WebChatModel(ChatModel):
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
         )
-        super().__init__(args)
+
+        if checkpoint_path:
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
+        super().__init__(args)
         yield ALERTS["info_loaded"][lang]
 
     def unload_model(self, data) -> Generator[str, None, None]:
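Why the new branch keys on PEFT_METHODS: adapter-style checkpoints (LoRA and friends) are deltas that stack on a frozen base model, so the UI can select several at once and pass them as a comma-joined adapter_name_or_path; full and freeze checkpoints contain the model weights themselves, so a single one replaces model_name_or_path. A minimal standalone sketch of that rule, assuming PEFT_METHODS holds the adapter-style method names and that checkpoints live under saves/<model>/<finetuning_type>/<checkpoint> — both helpers below are illustrative stand-ins for the real ones in ..extras.constants and .common:

import os
from typing import List, Union

PEFT_METHODS = {"lora"}  # stand-in: the real constant lives in ..extras.constants

def get_save_dir(model_name: str, finetuning_type: str, checkpoint: str) -> str:
    # stand-in: assumes checkpoints are laid out as saves/<model>/<finetuning_type>/<checkpoint>
    return os.path.join("saves", model_name, finetuning_type, checkpoint)

def resolve_checkpoints(
    args: dict, model_name: str, finetuning_type: str, checkpoint_path: Union[List[str], str, None]
) -> dict:
    if checkpoint_path:
        if finetuning_type in PEFT_METHODS:  # list of adapter checkpoints
            # adapters stack on the base model, so resolve and join every selected dir
            args["adapter_name_or_path"] = ",".join(
                get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoint_path
            )
        else:  # single full/freeze checkpoint (str)
            # the checkpoint holds complete weights, so it becomes the model path itself
            args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
    return args

print(resolve_checkpoints({}, "llama2-7b", "lora", ["checkpoint-100", "checkpoint-200"]))
# -> {'adapter_name_or_path': 'saves/llama2-7b/lora/checkpoint-100,saves/llama2-7b/lora/checkpoint-200'}
print(resolve_checkpoints({"model_name_or_path": "base"}, "llama2-7b", "full", "checkpoint-300"))
# -> {'model_name_or_path': 'saves/llama2-7b/full/checkpoint-300'}

Overwriting model_name_or_path for full/freeze checkpoints keeps the rest of load_model untouched: downstream code sees an ordinary model path either way, which is what lets one code path serve both adapter and full-weight checkpoints.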