[misc] fix packing and eval plot (#7623)
@@ -18,7 +18,6 @@
 import os
 import random
-from enum import Enum, unique
 from typing import TYPE_CHECKING, Any
 
 import torch
@@ -28,7 +27,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ...extras import logging
-from ...extras.constants import FILEEXT2TYPE
+from ...extras.constants import FILEEXT2TYPE, QuantizationMethod
 from ...extras.misc import check_version, get_current_device
 
 
@@ -41,19 +40,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-@unique
-class QuantizationMethod(str, Enum):
-    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
-
-    BITS_AND_BYTES = "bitsandbytes"
-    GPTQ = "gptq"
-    AWQ = "awq"
-    AQLM = "aqlm"
-    QUANTO = "quanto"
-    EETQ = "eetq"
-    HQQ = "hqq"
-
-
 def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
     r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
     if os.path.isfile(model_args.export_quantization_dataset):
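Note on the hunk above: the removed `QuantizationMethod` enum is not deleted outright. The import change in the previous hunk now pulls `QuantizationMethod` from `...extras.constants`, and the comparisons in the hunks below reference a `BNB` member. A minimal sketch of how the relocated definition in `extras/constants.py` presumably looks; the `BNB` member name and the unchanged values are assumptions inferred from this diff, not the actual destination file:

    # Hypothetical relocation sketch for extras/constants.py, inferred from this diff.
    from enum import Enum, unique


    @unique
    class QuantizationMethod(str, Enum):
        r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""

        BNB = "bitsandbytes"  # assumed rename of BITS_AND_BYTES, based on QuantizationMethod.BNB below
        GPTQ = "gptq"
        AWQ = "awq"
        AQLM = "aqlm"
        QUANTO = "quanto"
        EETQ = "eetq"
        HQQ = "hqq"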
@@ -145,7 +131,7 @@ def configure_quantization(
         logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
 
     elif model_args.quantization_bit is not None:  # on-the-fly
-        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+        if model_args.quantization_method == QuantizationMethod.BNB:
             if model_args.quantization_bit == 8:
                 check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
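For context on the bitsandbytes branch above: the `quantization_config` stored in `init_kwargs` is ultimately forwarded to the model's `from_pretrained` call. A standalone sketch of the equivalent plain `transformers` call, for illustration only (the model id is a placeholder, and the explicit `device_map` mirrors the `{"": get_current_device()}` assignment in the next hunk):

    # Illustration only: on-the-fly 8-bit bitsandbytes quantization,
    # roughly what the init_kwargs assembled above amount to.
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map={"": 0},  # single-device map, analogous to {"": get_current_device()}
    )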
@@ -173,7 +159,7 @@ def configure_quantization(
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
 
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
-        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+        elif model_args.quantization_method == QuantizationMethod.HQQ:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
 
@@ -185,7 +171,7 @@ def configure_quantization(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
-        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+        elif model_args.quantization_method == QuantizationMethod.EETQ:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
 
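Why dropping `.value` from these comparisons is safe: the removed definition above subclasses `str`, so (assuming the relocated enum keeps that base) each member compares equal to its raw string value, and `model_args.quantization_method` can remain a plain string. A quick self-contained check:

    # Assumes the relocated QuantizationMethod still subclasses str, as the removed one did.
    from enum import Enum, unique


    @unique
    class QuantizationMethod(str, Enum):
        BNB = "bitsandbytes"  # member name inferred from the diff above

    # A str-based Enum member equals its raw value, so "== QuantizationMethod.BNB"
    # behaves the same as the old "== QuantizationMethod.BITS_AND_BYTES.value".
    assert QuantizationMethod.BNB == "bitsandbytes"
    assert "bitsandbytes" == QuantizationMethod.BNB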