[model] support yarn (#6693)

Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250
This commit is contained in:
hoshi-hiyouga
2025-01-18 13:56:09 +08:00
committed by GitHub
parent e4046bdd1f
commit 87d685b59f
11 changed files with 84 additions and 64 deletions

View File

@@ -16,12 +16,14 @@ import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from transformers.utils import is_torch_npu_available
from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import QUANTIZATION_BITS, get_save_dir
from .common import get_save_dir, load_config
from .locales import ALERTS
@@ -59,6 +61,8 @@ class WebChatModel(ChatModel):
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
user_config = load_config()
error = ""
if self.loaded:
error = ALERTS["err_exists"][lang]
@@ -74,26 +78,22 @@ class WebChatModel(ChatModel):
yield error
return
if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_path,
cache_dir=user_config.get("cache_dir", None),
finetuning_type=finetuning_type,
quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
trust_remote_code=True,
)
# checkpoints
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
@@ -102,6 +102,12 @@ class WebChatModel(ChatModel):
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
# quantization
if get("top.quantization_bit") != "none":
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()
super().__init__(args)
yield ALERTS["info_loaded"][lang]