support mixtral
Former-commit-id: 75b5b8e36ab1933b2625f11b645f56cbc805fd85
This commit is contained in:
@@ -25,7 +25,6 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.model.adapter import init_adapter
|
||||
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
|
||||
@@ -38,7 +37,7 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
||||
require_version("transformers>=4.36.0", "To fix: pip install transformers>=4.36.0")
|
||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
|
||||
@@ -124,28 +123,22 @@ def load_model_and_tokenizer(
|
||||
|
||||
# Set FlashAttention-2
|
||||
if model_args.flash_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
if is_flash_attn2_available():
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
else:
|
||||
logger.warning("FlashAttention-2 is not installed.")
|
||||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
||||
if not is_flash_attn2_available():
|
||||
logger.warning("FlashAttention-2 is not installed.")
|
||||
elif getattr(config, "model_type", None) == "qwen":
|
||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
||||
else:
|
||||
logger.warning("Current model does not support FlashAttention.")
|
||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
||||
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
||||
setattr(config, "attn_implementation", "flash_attention_2")
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
|
||||
# Set shift short attention (S^2-Attn)
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
|
||||
# if getattr(config, "model_type", None) == "llama":
|
||||
# setattr(config, "group_size_ratio", 0.25)
|
||||
# logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
# else:
|
||||
# logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using gptq or awq)
|
||||
if getattr(config, "quantization_config", None):
|
||||
|
||||
Reference in New Issue
Block a user