support quantization in export model
Former-commit-id: f32500ae6edccab7d14df4c92467e15986866def
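In effect, the new branch in configure_quantization() maps the export_quantization_* fields of FinetuningArguments onto a transformers GPTQConfig, so that the AutoGPTQ calibration runs while the model is loaded for export. A minimal standalone sketch of that flow, assuming a placeholder model id and hand-written calibration texts rather than anything taken from this commit:

    # Sketch of the GPTQ export path added below; the model id and calibration
    # texts are placeholders, not llmtuner defaults.
    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

    model_id = "facebook/opt-125m"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # In the diff, these strings are produced by get_quantization_dataset().
    calibration_samples = ["a reasonably long piece of representative text ..."] * 128

    gptq_config = GPTQConfig(bits=4, dataset=calibration_samples, tokenizer=tokenizer)

    # configure_quantization() likewise sets device_map="auto" next to the config.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=gptq_config,
        device_map="auto",
    )

This path needs optimum and auto-gptq at roughly the versions pinned in the diff (optimum>=1.16.0, auto_gptq>=0.5.0).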
@@ -1,12 +1,16 @@
+import os
 import math
 import torch
+import random
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, List
+from datasets import load_dataset
 
-from transformers import BitsAndBytesConfig, PreTrainedModel, PreTrainedTokenizerBase
+from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 
+from llmtuner.extras.constants import FILEEXT2TYPE
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import get_current_device, infer_optim_dtype
 from llmtuner.extras.packages import is_flash_attn2_available
@@ -14,7 +18,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments
+    from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)
@@ -36,7 +40,13 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
         logger.warning("Current model does not support shift short attention.")
 
 
-def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
+def configure_quantization(
+    config: "PretrainedConfig",
+    config_kwargs: Dict[str, Any],
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments"
+):
     if getattr(config, "quantization_config", None):  # gptq or awq
         model_args.quantization_bit = None  # remove bnb quantization
         config_kwargs["device_map"] = {"": get_current_device()}
@@ -63,6 +73,16 @@ def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str,
         config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
+    if finetuning_args.export_quantization_bit is not None:  # gptq
+        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
+        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        config_kwargs["quantization_config"] = GPTQConfig(
+            bits=finetuning_args.export_quantization_bit,
+            dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
+        )
+        config_kwargs["device_map"] = "auto"
+        logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
+
 
 def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
     if model_args.rope_scaling is not None:
@@ -91,6 +111,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         ))
 
 
+def get_quantization_dataset(
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments"
+) -> List[str]:
+    r"""
+    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
+    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
+    """
+    if os.path.isfile(finetuning_args.export_quantization_dataset):
+        data_path = FILEEXT2TYPE.get(finetuning_args.export_quantization_dataset.split(".")[-1], None)
+        data_files = finetuning_args.export_quantization_dataset
+    else:
+        data_path = finetuning_args.export_quantization_dataset
+        data_files = None
+
+    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
+    maxlen = finetuning_args.export_quantization_maxlen
+
+    samples = []
+    for _ in range(finetuning_args.export_quantization_nsamples):
+        while True:
+            sample_idx = random.randint(0, len(dataset) - 1)
+            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+            if sample["input_ids"].size(1) >= maxlen:
+                break  # TODO: fix large maxlen
+
+        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
+        input_ids = sample["input_ids"][:, word_idx:word_idx+maxlen]
+        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
+
+    return samples
+
+
 def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
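The call sites are not part of this diff, but with the new signature the model loader presumably has to thread the tokenizer and the finetuning arguments through, roughly:

    # Assumed call inside the model loading path; argument order follows the new signature above.
    configure_quantization(config, config_kwargs, tokenizer, model_args, finetuning_args)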