support infer 4bit model on GPUs #3023

Former-commit-id: 950a9dab9055839990656b2b40956792b253573d
This commit is contained in:
hiyouga
2024-04-01 17:34:04 +08:00
parent 61eb3a3d46
commit e7f13098c6
2 changed files with 14 additions and 6 deletions

View File

@@ -208,11 +208,6 @@ def _configure_quantization(
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
@@ -227,7 +222,16 @@ def _configure_quantization(
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
)
init_kwargs["device_map"] = {"": get_current_device()}
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))