reenable sdpa and fast tok by default

Former-commit-id: 9e00902dbedc71d55743d1bf237843506a557891
hiyouga committed 2024-04-24 02:18:44 +08:00
parent 35c4a2c212
commit d2bb1b3a6b
8 changed files with 64 additions and 27 deletions

@@ -15,7 +15,7 @@ from transformers.utils.versions import require_version
 from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
 from ..extras.logging import get_logger
 from ..extras.misc import get_current_device, infer_optim_dtype
-from ..extras.packages import is_flash_attn2_available
+from ..extras.packages import is_flash_attn2_available, is_sdpa_available
 from ..extras.patches.llama_patch import apply_llama_patch
 from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
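Note: is_sdpa_available is now imported alongside is_flash_attn2_available. As a rough mental model only (the real helper in ..extras.packages may be implemented differently), the check can be thought of as a torch version gate matching the warning emitted in the hunk below:

# Hedged sketch of an SDPA availability check; the project's actual
# is_sdpa_available helper may differ in detail.
import torch
from packaging import version


def is_sdpa_available() -> bool:
    # transformers' "sdpa" attention relies on
    # torch.nn.functional.scaled_dot_product_attention; the warning in the
    # diff below cites torch >= 2.1.1 as the requirement.
    return version.parse(torch.__version__) >= version.parse("2.1.1")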
@@ -62,18 +62,45 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
 def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if model_args.flash_attn:
-        if not is_flash_attn2_available():
-            logger.warning("FlashAttention2 is not installed.")
+    if model_args.flash_attn == "auto":
+        return
+
+    elif model_args.flash_attn == "off":
+        requested_attn_implementation = "eager"
+
+    elif model_args.flash_attn == "sdpa":
+        if not is_sdpa_available():
+            logger.warning("Torch>=2.1.1 is required for SDPA attention.")
             return
 
-        logger.info("Using FlashAttention-2 for faster training and inference.")
-        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-            setattr(config, "attn_implementation", "flash_attention_2")
-        else:
-            setattr(config, "_attn_implementation", "flash_attention_2")
+        requested_attn_implementation = "sdpa"
+
+    elif model_args.flash_attn == "fa2":
+        if not is_flash_attn2_available():
+            logger.warning("FlashAttention-2 is not installed.")
+            return
+
+        requested_attn_implementation = "flash_attention_2"
+
     else:
-        setattr(config, "_attn_implementation", "eager")
+        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        setattr(config, "attn_implementation", requested_attn_implementation)
+    else:
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+
+
+def _print_attn_implementation(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        attn_implementation = getattr(config, "attn_implementation", None)
+    else:
+        attn_implementation = getattr(config, "_attn_implementation", None)
+
+    if attn_implementation == "flash_attention_2":
+        logger.info("Using FlashAttention-2 for faster training and inference.")
+    elif attn_implementation == "sdpa":
+        logger.info("Using torch SDPA for faster training and inference.")
+    else:
+        logger.info("Using vanilla Attention implementation.")
 
 
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
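The net effect of the rewritten dispatch: the user-facing flash_attn option ("auto", "off", "sdpa", "fa2") is mapped onto the attribute that transformers reads (_attn_implementation, or attn_implementation for InternLM2's custom modeling code). A rough usage sketch, assuming _configure_attn_implementation is imported from the patched module and using SimpleNamespace as a stand-in for the project's ModelArguments dataclass:

# Illustrative only; SimpleNamespace stands in for ModelArguments and
# LlamaConfig for whatever config the user actually loads.
from types import SimpleNamespace

from transformers import LlamaConfig

config = LlamaConfig()
model_args = SimpleNamespace(flash_attn="sdpa")

_configure_attn_implementation(config, model_args)
# With torch >= 2.1.1 this sets config._attn_implementation to "sdpa";
# otherwise it logs a warning and leaves the config untouched.
print(getattr(config, "_attn_implementation", None))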
@@ -365,6 +392,8 @@ def patch_model(
         add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
 
+    _print_attn_implementation(model.config)
+
     try:
         model.add_model_tags(["llama-factory"])
     except Exception:
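For comparison (not part of this diff), when loading a model directly with transformers the same backends are selected via the attn_implementation argument to from_pretrained, which reads the same config attribute that _configure_attn_implementation sets above:

# Rough equivalent of flash_attn="sdpa" when bypassing LLaMA-Factory;
# "Qwen/Qwen1.5-0.5B" is just an arbitrary small example checkpoint.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-0.5B",
    attn_implementation="sdpa",  # or "flash_attention_2" / "eager"
)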