release v0.5.3

Former-commit-id: f6bc89581b3cd129448da2defc23848de6f494ed
This commit is contained in:
hiyouga
2024-02-29 00:34:19 +08:00
parent a2c881fa08
commit 544e7a491b
10 changed files with 116 additions and 67 deletions

View File

@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import FinetuningArguments, ModelArguments
from ..hparams import ModelArguments
logger = get_logger(__name__)
@@ -157,7 +157,7 @@ def _configure_quantization(
config_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq
if is_deepspeed_zero3_enabled():
@@ -167,7 +167,15 @@ def _configure_quantization(
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
quantization_config["use_exllama"] = False # disable exllama
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
if quantization_config.get("quant_method", None) == "aqlm":
quantization_config["bits"] = 2
logger.info(
"Loading {}-bit {}-quantized model.".format(
quantization_config.get("bits", "?"), quantization_config.get("quant_method", None)
)
)
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
@@ -253,7 +261,6 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
config_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
@@ -274,9 +281,6 @@ def patch_config(
_configure_quantization(config, tokenizer, model_args, config_kwargs)
if finetuning_args.use_dora:
config_kwargs["device_map"] = {"": get_current_device()}
def patch_model(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool