[model] switch to gptqmodel (#8108)
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
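Note: the hunk above drops `is_trainable` from `configure_attn_implementation`, so the gemma2 special case now applies at both training and inference time. A minimal, self-contained sketch of the new guard; the `SimpleNamespace` objects and plain strings stand in for `PretrainedConfig`, `ModelArguments`, and the `AttentionFunction` enum and are not part of the commit:

# Illustration only: the gemma2 guard no longer consults is_trainable, so it also applies
# at inference time. SimpleNamespace objects and plain strings replace the real classes/enum.
from types import SimpleNamespace

config = SimpleNamespace(model_type="gemma2")
model_args = SimpleNamespace(flash_attn="fa2")

if getattr(config, "model_type", None) == "gemma2":  # previously: == "gemma2" and is_trainable
    if model_args.flash_attn in ("auto", "fa2"):
        print("gemma2 detected: FlashAttention-2 handling is evaluated")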
@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.moe_aux_loss_coef:
+        return
+
     model_type = getattr(config, "model_type", None)
-    if model_args.moe_aux_loss_coef is not None:
-        if model_type in [
-            "dbrx",
-            "granitemoe",
-            "jamba",
-            "jetmoe",
-            "llama4",
-            "mixtral",
-            "olmoe",
-            "phimoe",
-            "qwen2_moe",
-            "qwen3_moe",
-        ]:
-            setattr(config, "output_router_logits", is_trainable)
+    if model_type in [
+        "dbrx",
+        "granitemoe",
+        "jamba",
+        "jetmoe",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
+        setattr(config, "output_router_logits", True)
 
-        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
-            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+        setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
-        elif model_type == "deepseek":
-            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+    elif model_type == "deepseek":
+        setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
 
-        elif model_type == "jetmoe":
-            setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+    elif model_type == "jetmoe":
+        setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
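Note: the hunk above moves the trainability/coefficient check into an early return, so the body can set the router options unconditionally. A condensed sketch of the net effect for a Mixtral-style config during training; the stand-in objects and the 0.001 coefficient are illustrative only:

# Illustration only: the net effect of the new configure_moe for a Mixtral-style config
# during training (stand-in objects; the 0.001 coefficient is a placeholder).
from types import SimpleNamespace

config = SimpleNamespace(model_type="mixtral")
model_args = SimpleNamespace(moe_aux_loss_coef=0.001)
is_trainable = True

if is_trainable and model_args.moe_aux_loss_coef:  # early return is skipped
    setattr(config, "output_router_logits", True)                          # "mixtral" is in the first list
    setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)  # and in the second list

print(config.output_router_logits, config.router_aux_loss_coef)  # True 0.001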
@@ -97,7 +97,7 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")
 
         if quant_method == QuantizationMethod.GPTQ:
-            check_version("auto_gptq>=0.5.0", mandatory=True)
+            check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama
 
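Note: the load path now requires gptqmodel instead of auto_gptq when a GPTQ-quantized checkpoint is detected. A rough sketch of what such a version gate amounts to, using the standard library rather than the project's `check_version` helper (whose implementation is not shown in this diff):

# Illustration only: approximately what a gptqmodel>=2.0.0 requirement amounts to, using the
# standard library instead of the project's check_version helper (not shown in this diff).
from importlib.metadata import PackageNotFoundError, version


def require_gptqmodel(min_version: str = "2.0.0") -> None:
    try:
        installed = version("gptqmodel")
    except PackageNotFoundError as exc:
        raise ImportError("GPTQ-quantized checkpoints now require the gptqmodel package.") from exc

    # Naive comparison for the sketch; real code would use packaging.version.parse.
    if tuple(map(int, installed.split(".")[:3])) < tuple(map(int, min_version.split("."))):
        raise ImportError(f"gptqmodel>={min_version} is required, found {installed}.")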
@@ -111,12 +111,12 @@ def configure_quantization(
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
 
-    elif model_args.export_quantization_bit is not None:  # auto-gptq
+    elif model_args.export_quantization_bit is not None:  # gptqmodel
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 
-        check_version("optimum>=1.17.0", mandatory=True)
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("optimum>=1.24.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         from accelerate.utils import get_max_memory
 
         if getattr(config, "model_type", None) == "chatglm":
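Note: for the export path, the minimum versions move to optimum>=1.24.0 and gptqmodel>=2.0.0. A sketch of the general shape of a GPTQ export through transformers/optimum, which dispatches to the gptqmodel backend; the model id, dataset, and bit width below are placeholders, not values taken from this commit:

# Illustration only: the general shape of a GPTQ export via transformers/optimum, which
# now dispatches to the gptqmodel backend. Model id, dataset, and bit width are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

# Passing quantization_config triggers calibration and quantization on load.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=gptq_config,
    device_map="auto",
)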
@@ -142,7 +142,8 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
 
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BNB:
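Note: the added line pins the compute dtype to fp16 before quantizing, per the inline comment ("force fp16 for gptqmodel"). A trivial sketch of that assignment with a stand-in `model_args`; the prior bfloat16 value is hypothetical:

# Illustration only: what the added line does; the prior bfloat16 value is hypothetical.
import torch
from types import SimpleNamespace

model_args = SimpleNamespace(compute_dtype=torch.bfloat16)
model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
assert model_args.compute_dtype == torch.float16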
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return
 
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
-    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
-    if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
+    if hasattr(config, "max_position_embeddings"):
+        old_max_length = getattr(config, "max_position_embeddings", None)
+    else:
+        logger.warning_rank0("Cannot find the max position embeddings in the config.")
+        return
+
+    if model_args.model_max_length is not None:  # training
+        if model_args.model_max_length <= old_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )
 
-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if (not current_max_length) or model_args.model_max_length <= current_max_length:
-            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
-            return
+        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
+    else:  # inference
+        rope_factor = 2.0
 
-        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-        setattr(config, "max_position_embeddings", model_args.model_max_length)
-        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == RopeScaling.LLAMA3:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-            rope_kwargs["low_freq_factor"] = 1.0
-            rope_kwargs["high_freq_factor"] = 4.0
-    else:
-        rope_kwargs["factor"] = 2.0
+    rope_kwargs = {
+        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
+        "factor": rope_factor,
+    }
+    setattr(config, "max_position_embeddings", old_max_length * rope_factor)
+    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
+
+    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+    elif model_args.rope_scaling == RopeScaling.LLAMA3:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+        rope_kwargs["low_freq_factor"] = 1.0
+        rope_kwargs["high_freq_factor"] = 4.0
 
     setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
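Note: a worked example of the rewritten training branch. With a hypothetical model whose `max_position_embeddings` is 4096, a `model_max_length` of 16384, and the YaRN rope type, the factor becomes ceil(16384 / 4096) = 4.0 and the config length is enlarged accordingly; all numbers and the rope type below are illustrative:

# Illustration only: the rope_kwargs the new code would produce for a hypothetical model
# with max_position_embeddings=4096 trained at model_max_length=16384 using YaRN.
import math

old_max_length = 4096
model_max_length = 16384
rope_factor = float(math.ceil(model_max_length / old_max_length))  # 4.0

rope_kwargs = {
    "rope_type": "yarn",
    "factor": rope_factor,
    "original_max_position_embeddings": old_max_length,
}
new_max_position_embeddings = old_max_length * rope_factor  # 16384.0
print(rope_kwargs, new_max_position_embeddings)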