add loftq

Former-commit-id: 0b900882ef19ac49604a24fbae8b3254f1bff7ad
hiyouga
2023-12-14 21:53:56 +08:00
parent c32303fc7e
commit 27ef5b1aa7
3 changed files with 90 additions and 14 deletions
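Note: the LoftQ wiring named in the commit title lives in the changed files not shown in this view; only the attention and quantization changes appear below. As a point of reference, this is a minimal sketch of how LoftQ initialization is typically enabled through peft's LoftQConfig and LoraConfig (peft >= 0.7.0); the checkpoint name, bit width, rank, and target modules are illustrative assumptions, not values taken from this commit.

# Illustrative sketch only, not code from this commit.
from peft import LoftQConfig, LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Load the backbone in full precision; LoftQ quantizes it during LoRA init.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint

loftq_config = LoftQConfig(loftq_bits=4)  # quantize frozen backbone weights to 4 bits
lora_config = LoraConfig(
    init_lora_weights="loftq",   # use LoftQ instead of the default LoRA init
    loftq_config=loftq_config,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)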

@@ -119,16 +119,6 @@ def load_model_and_tokenizer(
             model_args.rope_scaling, scaling_factor
         ))
 
-    # Set FlashAttention-2
-    if model_args.flash_attn:
-        if not is_flash_attn2_available():
-            logger.warning("FlashAttention-2 is not installed.")
-        elif getattr(config, "model_type", None) == "qwen":
-            logger.info("Current model automatically enables FlashAttention if installed.")
-        else:
-            setattr(config, "attn_implementation", "flash_attention_2")
-            logger.info("Using FlashAttention-2 for faster training and inference.")
-
     # Set shift short attention (S^2-Attn)
     if is_trainable and model_args.shift_attn:
         logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
@@ -138,10 +128,19 @@ def load_model_and_tokenizer(
     #     else:
     #         logger.warning("Current model does not support shift short attention.")
 
+    # Set FlashAttention-2
+    if model_args.flash_attn:
+        if not is_flash_attn2_available():
+            logger.warning("FlashAttention-2 is not installed.")
+        elif getattr(config, "model_type", None) == "qwen":
+            logger.info("Current model automatically enables FlashAttention if installed.")
+        else:
+            config_kwargs["use_flash_attention_2"] = True
+            logger.info("Using FlashAttention-2 for faster training and inference.")
+
     # Quantization configurations (using gptq or awq)
     if getattr(config, "quantization_config", None):
-        if model_args.quantization_bit is not None:  # remove bnb quantization
-            model_args.quantization_bit = None
+        model_args.quantization_bit = None  # remove bnb quantization
         config_kwargs["device_map"] = {"": get_current_device()}
         quantization_config = getattr(config, "quantization_config", None)
         logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
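For context on the hunk above: the FlashAttention-2 switch now goes through config_kwargs (a load-time use_flash_attention_2 flag) instead of mutating config.attn_implementation, and GPTQ/AWQ pre-quantized checkpoints drop any bitsandbytes quantization_bit setting while pinning device_map to the current device. Below is a minimal sketch of how such kwargs are typically consumed downstream, assuming a transformers version that accepts use_flash_attention_2 in from_pretrained; the checkpoint name and device index are placeholders, not taken from this commit.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
config = AutoConfig.from_pretrained(model_name)

config_kwargs = {}
# Mirrors the added branch above: request FlashAttention-2 at load time.
config_kwargs["use_flash_attention_2"] = True

# Mirrors the pre-quantized (GPTQ/AWQ) branch: keep the checkpoint's own
# quantization_config and map the whole model onto one device.
if getattr(config, "quantization_config", None):
    config_kwargs["device_map"] = {"": 0}  # stand-in for get_current_device()

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.float16,
    **config_kwargs,
)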