reenable sdpa and fast tok by default
Former-commit-id: 9e00902dbedc71d55743d1bf237843506a557891
@@ -1,16 +1,23 @@
 import importlib.metadata
 import importlib.util
+from typing import TYPE_CHECKING
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+    from packaging.version import Version
 
 
 def _is_package_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None
 
 
-def _get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> "Version":
     try:
-        return importlib.metadata.version(name)
+        return version.parse(importlib.metadata.version(name))
     except Exception:
-        return "0.0.0"
+        return version.parse("0.0.0")
 
 
 def is_fastapi_availble():
@@ -18,7 +25,7 @@ def is_fastapi_availble():
 
 
 def is_flash_attn2_available():
-    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
+    return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
 
 
 def is_galore_available():
@@ -49,6 +56,10 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
+def is_sdpa_available():
+    return _get_package_version("torch") > version.parse("2.1.1")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
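The helpers above return parsed packaging versions instead of raw strings, so checks like is_flash_attn2_available and is_sdpa_available become real numeric comparisons rather than string prefix tests. A minimal sketch of that idea, separate from this commit and using only importlib.metadata and packaging (the printed values are illustrative):

    import importlib.metadata

    from packaging import version

    def get_version(name: str) -> version.Version:
        # Fall back to 0.0.0 when the package is missing, mirroring _get_package_version above.
        try:
            return version.parse(importlib.metadata.version(name))
        except Exception:
            return version.parse("0.0.0")

    # Parsed versions compare numerically, so 2.10.0 > 2.2.0 holds,
    # while lexicographic string comparison or startswith("2") checks do not capture this.
    print(version.parse("2.10.0") > version.parse("2.2.0"))  # True
    print("2.10.0" > "2.2.0")                                # False (string comparison)
    print(get_version("torch") > version.parse("2.1.1"))     # the SDPA gate used above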
@@ -22,7 +22,7 @@ class ModelArguments:
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
     use_fast_tokenizer: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
     )
     resize_vocab: bool = field(
@@ -61,9 +61,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: bool = field(
-        default=False,
-        metadata={"help": "Enable FlashAttention for faster training."},
+    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+        default="auto",
+        metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
     shift_attn: bool = field(
         default=False,
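With these defaults the fast tokenizer is enabled unless explicitly turned off, and the attention backend is selected by name, with "auto" deferring the choice. A rough sketch of how such bool/Literal fields parse from the command line, assuming the usual HfArgumentParser wiring (DemoArgs is an illustrative stand-in, not the project's ModelArguments):

    from dataclasses import dataclass, field
    from typing import Literal

    from transformers import HfArgumentParser

    @dataclass
    class DemoArgs:  # stand-in dataclass for illustration only
        use_fast_tokenizer: bool = field(default=True)
        flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(default="auto")

    parser = HfArgumentParser(DemoArgs)
    # e.g. passing "--flash_attn sdpa"; omitting the flags keeps the new defaults.
    (args,) = parser.parse_args_into_dataclasses(["--flash_attn", "sdpa"])
    print(args.use_fast_tokenizer, args.flash_attn)  # True sdpa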
@@ -15,7 +15,7 @@ from transformers.utils.versions import require_version
 from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
 from ..extras.logging import get_logger
 from ..extras.misc import get_current_device, infer_optim_dtype
-from ..extras.packages import is_flash_attn2_available
+from ..extras.packages import is_flash_attn2_available, is_sdpa_available
 from ..extras.patches.llama_patch import apply_llama_patch
 from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
 
@@ -62,18 +62,45 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
 
 
 def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if model_args.flash_attn:
-        if not is_flash_attn2_available():
-            logger.warning("FlashAttention2 is not installed.")
-            return
-
-        logger.info("Using FlashAttention-2 for faster training and inference.")
-        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-            setattr(config, "attn_implementation", "flash_attention_2")
-        else:
-            setattr(config, "_attn_implementation", "flash_attention_2")
-    else:
-        setattr(config, "_attn_implementation", "eager")
+    if model_args.flash_attn == "auto":
+        return
+
+    elif model_args.flash_attn == "off":
+        requested_attn_implementation = "eager"
+
+    elif model_args.flash_attn == "sdpa":
+        if not is_sdpa_available():
+            logger.warning("Torch>=2.1.1 is required for SDPA attention.")
+            return
+
+        requested_attn_implementation = "sdpa"
+    elif model_args.flash_attn == "fa2":
+        if not is_flash_attn2_available():
+            logger.warning("FlashAttention-2 is not installed.")
+            return
+
+        requested_attn_implementation = "flash_attention_2"
+    else:
+        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        setattr(config, "attn_implementation", requested_attn_implementation)
+    else:
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+
+
+def _print_attn_implementation(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        attn_implementation = getattr(config, "attn_implementation", None)
+    else:
+        attn_implementation = getattr(config, "_attn_implementation", None)
+
+    if attn_implementation == "flash_attention_2":
+        logger.info("Using FlashAttention-2 for faster training and inference.")
+    elif attn_implementation == "sdpa":
+        logger.info("Using torch SDPA for faster training and inference.")
+    else:
+        logger.info("Using vanilla Attention implementation.")
 
 
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
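The attribute written here (_attn_implementation, or attn_implementation for the custom internlm2 config) is the field transformers consults when choosing attention kernels. A standalone restatement of the same selection mapping, applied to a bare PretrainedConfig for demonstration rather than through the project's helpers:

    from typing import Optional

    from transformers import PretrainedConfig

    def pick_attn_implementation(flash_attn: str) -> Optional[str]:
        # Same mapping as above: "auto" leaves the choice to transformers,
        # "off" -> eager, "sdpa" -> torch SDPA, "fa2" -> FlashAttention-2.
        if flash_attn == "auto":
            return None
        if flash_attn == "off":
            return "eager"
        if flash_attn == "sdpa":
            return "sdpa"
        if flash_attn == "fa2":
            return "flash_attention_2"
        raise NotImplementedError("Unknown attention type: {}".format(flash_attn))

    config = PretrainedConfig()  # placeholder config, for illustration only
    requested = pick_attn_implementation("sdpa")
    if requested is not None:
        setattr(config, "_attn_implementation", requested)
    print(getattr(config, "_attn_implementation", None))  # sdpa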
@@ -365,6 +392,8 @@ def patch_model(
         add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
 
+    _print_attn_implementation(model.config)
+
     try:
         model.add_model_tags(["llama-factory"])
     except Exception:
@@ -33,7 +33,7 @@ def create_top() -> Dict[str, "Component"]:
             quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
             template = gr.Dropdown(choices=list(templates.keys()), value="default")
             rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-            booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
+            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")
 
     model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False
@@ -67,7 +67,7 @@ class Runner:
         if not model_path:
             return ALERTS["err_no_path"][lang]
 
-        if len(dataset) == 0:
+        if not dataset:
             return ALERTS["err_no_dataset"][lang]
 
         if not from_preview and self.demo_mode:
@@ -122,7 +122,7 @@ class Runner:
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
-            flash_attn=(get("top.booster") == "flashattn"),
+            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
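In the web UI, only the "flashattn2" booster maps to an explicit FlashAttention-2 request; "none" and "unsloth" fall back to "auto" and let the patcher decide. A tiny sketch of that mapping, using the booster values from the component above:

    def booster_to_flash_attn(booster: str) -> str:
        # Same expression the runner uses when assembling the training arguments.
        return "fa2" if booster == "flashattn2" else "auto"

    for booster in ("none", "flashattn2", "unsloth"):
        print(booster, "->", booster_to_flash_attn(booster))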