[model] support yarn (#6693)
Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250
@@ -16,12 +16,14 @@ import json
 import os
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 
+from transformers.utils import is_torch_npu_available
+
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
-from .common import QUANTIZATION_BITS, get_save_dir
+from .common import get_save_dir, load_config
 from .locales import ALERTS
 
 
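Note on the import changes: load_config supplies the user's saved Web UI settings (consumed below for cache_dir), and is_torch_npu_available gates double quantization in the last hunk. A minimal sketch of what a load_config-style helper typically does, assuming a JSON settings file; the path and structure here are illustrative, not the project's actual layout:

```python
# Hypothetical load_config-style helper: reads persisted Web UI settings
# from a JSON file and falls back to an empty dict when none exist.
# The path below is an assumption for illustration only.
import json
import os

def load_config_sketch() -> dict:
    config_path = os.path.expanduser("~/.cache/webui/user_config.json")
    try:
        with open(config_path, encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}  # defaults apply, e.g. .get("cache_dir", None) -> None
```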
@@ -59,6 +61,8 @@ class WebChatModel(ChatModel):
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
+        user_config = load_config()
+
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
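The load_config() call fetches the persisted user settings once per load request; the surrounding get() lambda resolves UI values by element id. A toy sketch of that lookup pattern, with plain dicts standing in for the Gradio component manager (illustrative, not the real Manager API):

```python
# Plain-dict stand-ins for the Gradio component registry (illustrative).
manager = {"top.model_name": "component_1", "top.model_path": "component_2"}
data = {"component_1": "Llama-3-8B-Instruct", "component_2": "/models/llama3"}

get = lambda elem_id: data[manager[elem_id]]
assert get("top.model_name") == "Llama-3-8B-Instruct"
```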
@@ -74,26 +78,22 @@ class WebChatModel(ChatModel):
             yield error
             return
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_path,
+            cache_dir=user_config.get("cache_dir", None),
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
-            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
+            trust_remote_code=True,
         )
 
         # checkpoints
         if checkpoint_path:
             if finetuning_type in PEFT_METHODS:  # list
                 args["adapter_name_or_path"] = ",".join(
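This is the change that enables yarn: rope_scaling is now forwarded whenever the dropdown value is not "none", instead of being whitelisted to ["linear", "dynamic"], so newer scaling types such as yarn (and llama3 in recent transformers releases) reach the model loader. A hedged sketch of the downstream mapping; the dict shape follows transformers' rope_scaling convention, and the factor value is illustrative rather than what LLaMA-Factory computes:

```python
# Sketch: turning a UI rope_scaling name into a transformers-style dict.
def build_rope_scaling(name: str, factor: float = 4.0):
    if name == "none":
        return None  # keep the model's original RoPE
    # "linear", "dynamic", "yarn", "llama3", ... all pass through,
    # which is why the old two-entry whitelist had to go.
    return {"rope_type": name, "factor": factor}

print(build_rope_scaling("yarn"))  # {'rope_type': 'yarn', 'factor': 4.0}
print(build_rope_scaling("none"))  # None
```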
@@ -102,6 +102,12 @@ class WebChatModel(ChatModel):
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
 
+        # quantization
+        if get("top.quantization_bit") != "none":
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
+
         super().__init__(args)
         yield ALERTS["info_loaded"][lang]
 
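The quantization arguments move out of the args = dict(...) literal into this late block so they are only set when the user picks a bit width, and double_quantization is disabled on Ascend NPUs, presumably because bitsandbytes-style double quantization is unsupported there. A self-contained sketch of the same gating with get() stubbed out; the UI strings "none"/"8"/"4" and the method name are assumptions:

```python
# Sketch of the quantization gating, decoupled from the Web UI.
from transformers.utils import is_torch_npu_available

def quantization_args(ui_bit: str, ui_method: str) -> dict:
    args = {}
    if ui_bit != "none":                         # only quantize on request
        args["quantization_bit"] = int(ui_bit)   # e.g. 4 or 8
        args["quantization_method"] = ui_method  # e.g. "bitsandbytes"
        # double quantization is skipped on NPU devices
        args["double_quantization"] = not is_torch_npu_available()
    return args

print(quantization_args("4", "bitsandbytes"))
```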