optimize aqlm training

Former-commit-id: 8b42660e4039b3d6475f502f397686ba6b140627
This commit is contained in:
hiyouga
2024-03-05 18:35:41 +08:00
parent 3553e301dd
commit a10bead9b5
2 changed files with 28 additions and 13 deletions

View File

@@ -159,25 +159,25 @@ def _configure_quantization(
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
quant_method = quantization_config.get("quant_method", "")
if quant_method == "gptq":
quantization_config["use_exllama"] = False # disable exllama
if quantization_config.get("quant_method", None) == "aqlm":
if quant_method == "aqlm":
require_version(
"transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
quantization_config["bits"] = 2
logger.info(
"Loading {}-bit {}-quantized model.".format(
quantization_config.get("bits", "?"), str(quantization_config.get("quant_method", "")).upper()
)
)
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
@@ -213,6 +213,7 @@ def _configure_quantization(
bnb_4bit_quant_type=model_args.quantization_type,
)
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -285,10 +286,13 @@ def patch_config(
init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = True
if is_trainable:
init_kwargs["device_map"] = {"": get_current_device()}
elif model_args.export_dir is None:
init_kwargs["device_map"] = "auto"
if "device_map" not in init_kwargs:
if is_trainable:
init_kwargs["device_map"] = {"": get_current_device()}
elif model_args.export_dir is None:
init_kwargs["device_map"] = "auto"
else:
init_kwargs["device_map"] = {"": "cpu"}
def patch_model(