support loftq

Former-commit-id: e7ac2eb7f7daae17525a278ffbe2f82c0fbd8093
This commit is contained in:
hiyouga
2023-12-12 22:47:06 +08:00
parent d9a50bf93f
commit e39bbdd287
5 changed files with 42 additions and 19 deletions

View File

@@ -144,28 +144,32 @@ def load_model_and_tokenizer(
model_args.quantization_bit = None
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
# Quantization configurations (using bitsandbytes library)
# Quantization configurations (using bitsandbytes)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if finetuning_args.loftq_init:
require_version("peft>=0.7.1.dev0", "To fix: pip install git+https://github.com/hiyouga/peft.git")
logger.info("Skip bnb quantization because using loftq.")
else:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(