[model] switch to gptqmodel (#8108)

Author: hoshi-hiyouga
Date: 2025-05-19 22:25:40 +08:00
Committed by: GitHub
Parent: bc7f00f2c7
Commit: 45030ff803
9 changed files with 78 additions and 62 deletions

View File

@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
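
Dropping the `is_trainable` gate means the Gemma 2 FlashAttention-2 check now applies at both training and inference time. For context, a minimal sketch (the checkpoint name and fallback choice are illustrative, not part of this commit) of how a resolved attention backend is typically handed to Transformers at load time:

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available

# Prefer FlashAttention-2 when the kernel is installed, otherwise fall back to SDPA.
attn_impl = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",            # placeholder checkpoint
    attn_implementation=attn_impl,  # the backend knob that helpers like configure_attn_implementation select
    torch_dtype=torch.bfloat16,
)
```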

View File

@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.moe_aux_loss_coef:
+        return
+
     model_type = getattr(config, "model_type", None)
-    if model_args.moe_aux_loss_coef is not None:
-        if model_type in [
-            "dbrx",
-            "granitemoe",
-            "jamba",
-            "jetmoe",
-            "llama4",
-            "mixtral",
-            "olmoe",
-            "phimoe",
-            "qwen2_moe",
-            "qwen3_moe",
-        ]:
-            setattr(config, "output_router_logits", is_trainable)
+    if model_type in [
+        "dbrx",
+        "granitemoe",
+        "jamba",
+        "jetmoe",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
+        setattr(config, "output_router_logits", True)

-        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
-            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+        setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)

-        elif model_type == "deepseek":
-            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+    elif model_type == "deepseek":
+        setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)

-        elif model_type == "jetmoe":
-            setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+    elif model_type == "jetmoe":
+        setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)

View File

@@ -97,7 +97,7 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")

         if quant_method == QuantizationMethod.GPTQ:
-            check_version("auto_gptq>=0.5.0", mandatory=True)
+            check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama
@@ -111,12 +111,12 @@ def configure_quantization(
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

-    elif model_args.export_quantization_bit is not None:  # auto-gptq
+    elif model_args.export_quantization_bit is not None:  # gptqmodel
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

-        check_version("optimum>=1.17.0", mandatory=True)
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("optimum>=1.24.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         from accelerate.utils import get_max_memory

         if getattr(config, "model_type", None) == "chatglm":
@@ -142,7 +142,8 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")

     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BNB:
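
As a rough sketch of the export path configured above (base model, calibration dataset, and output directory are placeholders; with optimum >= 1.24.0 and gptqmodel >= 2.0.0 installed, the GPTQ quantizer dispatches to the GPTQModel backend):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "meta-llama/Meta-Llama-3-8B"  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 4-bit GPTQ quantization with a calibration dataset.
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=gptq_config,
    device_map="auto",          # matches init_kwargs["device_map"] = "auto" above
    torch_dtype=torch.float16,  # GPTQ kernels run in fp16, hence the forced compute_dtype
)
model.save_pretrained("llama3-8b-gptq-4bit")      # placeholder output directory
tokenizer.save_pretrained("llama3-8b-gptq-4bit")
```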

View File

@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.warning_rank0("Current model does not support RoPE scaling.")
return
rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)} # handle enum
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
if hasattr(config, "max_position_embeddings"):
old_max_length = getattr(config, "max_position_embeddings", None)
else:
logger.warning_rank0("Cannot find the max position embeddings in the config.")
return
if model_args.model_max_length is not None: # training
if model_args.model_max_length <= old_max_length:
logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
return
if model_args.rope_scaling == RopeScaling.DYNAMIC:
logger.warning_rank0(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if (not current_max_length) or model_args.model_max_length <= current_max_length:
logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
return
rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
else: # inference
rope_factor = 2.0
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
rope_kwargs["original_max_position_embeddings"] = current_max_length
elif model_args.rope_scaling == RopeScaling.LLAMA3:
rope_kwargs["original_max_position_embeddings"] = current_max_length
rope_kwargs["low_freq_factor"] = 1.0
rope_kwargs["high_freq_factor"] = 4.0
else:
rope_kwargs["factor"] = 2.0
rope_kwargs = {
"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling), # handle enum
"factor": rope_factor,
}
setattr(config, "max_position_embeddings", old_max_length * rope_factor)
logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
rope_kwargs["original_max_position_embeddings"] = old_max_length
elif model_args.rope_scaling == RopeScaling.LLAMA3:
rope_kwargs["original_max_position_embeddings"] = old_max_length
rope_kwargs["low_freq_factor"] = 1.0
rope_kwargs["high_freq_factor"] = 4.0
setattr(config, "rope_scaling", rope_kwargs)
logger.info_rank0(
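
A worked example of the new factor computation (numbers chosen for illustration): with max_position_embeddings = 4096 and model_max_length = 10000, rope_factor = ceil(10000 / 4096) = 3.0, the context window is enlarged to 4096 * 3 = 12288, and LLaMA-3-style scaling additionally records the original length plus the frequency factors. A standalone sketch of that logic, simplified from the hunk above (the input-length guard is omitted):

```python
import math
from typing import Optional


def build_rope_kwargs(rope_type: str, old_max_length: int, model_max_length: Optional[int]) -> dict:
    """Simplified mirror of the factor logic in the new configure_rope."""
    if model_max_length is not None:  # training
        rope_factor = float(math.ceil(model_max_length / old_max_length))
    else:  # inference
        rope_factor = 2.0

    rope_kwargs = {"rope_type": rope_type, "factor": rope_factor}
    if rope_type in ("dynamic", "yarn"):
        rope_kwargs["original_max_position_embeddings"] = old_max_length
    elif rope_type == "llama3":
        rope_kwargs["original_max_position_embeddings"] = old_max_length
        rope_kwargs["low_freq_factor"] = 1.0
        rope_kwargs["high_freq_factor"] = 4.0

    return rope_kwargs


print(build_rope_kwargs("llama3", old_max_length=4096, model_max_length=10000))
# {'rope_type': 'llama3', 'factor': 3.0, 'original_max_position_embeddings': 4096,
#  'low_freq_factor': 1.0, 'high_freq_factor': 4.0}
```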