Improve the LoRA+ implementation.

Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
This commit is contained in:
hiyouga
2024-03-13 23:32:51 +08:00
parent 73f4513c84
commit 46f99ff277
12 changed files with 165 additions and 169 deletions

View File

@@ -18,6 +18,7 @@ from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
from .utils import QuantizationMethod
if TYPE_CHECKING:
@@ -173,10 +174,10 @@ def _configure_quantization(
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == "gptq":
if quant_method == QuantizationMethod.GPTQ:
quantization_config["use_exllama"] = False # disable exllama
if quant_method == "aqlm":
if quant_method == QuantizationMethod.AQLM:
require_version(
"transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
@@ -205,7 +206,7 @@ def _configure_quantization(
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")