Merge branch 'main' into minicpmv

Former-commit-id: d8840ae416660e23f1d615ffd404f519360151d9
Zhangchi Feng
2025-01-10 20:12:07 +08:00
committed by GitHub
41 changed files with 647 additions and 357 deletions

View File

@@ -15,9 +15,9 @@
 from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-from transformers.utils.versions import require_version
 from ...extras import logging
+from ...extras.misc import check_version
 if TYPE_CHECKING:
@@ -35,8 +35,8 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
-                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
+                check_version("transformers>=4.42.4")
+                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != "fa2":
                     logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = "fa2"
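
Note: every hunk in this merge applies the same substitution, replacing the raw require_version calls from transformers.utils.versions with a project-local check_version helper in ...extras.misc, so the "To fix: pip install ..." hints no longer have to be spelled out at each call site. A minimal sketch of what such a wrapper can look like, assuming it derives the hint from the requirement string, supports the mandatory flag used later in this diff, and lets users opt out of non-mandatory checks via an environment variable (the variable name is an assumption, not confirmed by the diff):

    import os

    from transformers.utils.versions import require_version


    def check_version(requirement: str, mandatory: bool = False) -> None:
        # Skip optional checks if the user opted out (env var name is hypothetical).
        if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ("true", "1") and not mandatory:
            return

        # Build the hint automatically instead of repeating it at every call site.
        require_version(requirement, f"To fix: run `pip install {requirement}`.")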

View File

@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
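
Note: warning_once, which every process would emit, becomes warning_rank0_once, which the project attaches to the standard logger so a message is printed at most once and only by the main process. A rough sketch of how such helpers can be bolted onto logging.Logger (reading the rank from LOCAL_RANK and the set-based de-duplication are assumptions for illustration; the project's extras.logging may differ):

    import logging
    import os

    _emitted_once = set()


    def warning_rank0(self: logging.Logger, msg: str, *args) -> None:
        # Only the main process (local rank 0) emits the warning.
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            self.warning(msg, *args)


    def warning_rank0_once(self: logging.Logger, msg: str, *args) -> None:
        # Emit each distinct message at most once per process.
        if msg not in _emitted_once:
            _emitted_once.add(msg)
            warning_rank0(self, msg, *args)


    # Expose the helpers as logger methods, mirroring the call site above.
    logging.Logger.warning_rank0 = warning_rank0
    logging.Logger.warning_rank0_once = warning_rank0_once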

View File

@@ -31,10 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils.versions import require_version
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
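
Note: _apply_llama_patch works by plain attribute assignment on transformers' attention classes, which is why the guard pins both a lower and an upper bound: the replacement forwards mirror library internals that can change between releases. The general shape of such a version-guarded monkey patch, sketched with transformers' own require_version so the snippet stays self-contained (the patch body is a no-op placeholder, not the project's shifted sparse attention implementation):

    from transformers.models.llama.modeling_llama import LlamaAttention
    from transformers.utils.versions import require_version

    _original_forward = LlamaAttention.forward


    def patched_forward(self, *args, **kwargs):
        # A real patch would reimplement attention here; this placeholder just
        # delegates to the original implementation saved above.
        return _original_forward(self, *args, **kwargs)


    def apply_patch() -> None:
        # Fail fast if the installed transformers falls outside the tested range.
        require_version("transformers>=4.41.2,<=4.46.1")
        LlamaAttention.forward = patched_forward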

View File

@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
+from ...extras.misc import check_version
 if TYPE_CHECKING:
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
 def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules # type: ignore
     set_z3_leaf_modules(model, leaf_modules)
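
Note: marking a module class as a ZeRO-3 "leaf" stops DeepSpeed from installing per-parameter gather hooks inside it, which matters for MoE blocks whose experts are only activated for a subset of tokens. A hypothetical usage with a tiny randomly initialized Mixtral (assumes deepspeed>=0.13.0 is installed; the real caller passes whichever leaf classes match the loaded model):

    from deepspeed.utils import set_z3_leaf_modules  # available since deepspeed 0.13.0
    from transformers import MixtralConfig, MixtralForCausalLM
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    # Tiny random Mixtral so the example does not download any weights.
    config = MixtralConfig(
        hidden_size=64, intermediate_size=128, num_hidden_layers=2,
        num_attention_heads=4, num_key_value_heads=2, num_local_experts=4,
    )
    model = MixtralForCausalLM(config)

    # Gather each sparse MoE block's parameters as one unit under ZeRO-3.
    set_z3_leaf_modules(model, [MixtralSparseMoeBlock])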

View File

@@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import torch.nn.functional as F
-from transformers.utils.versions import require_version
 from ...extras import logging
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
-    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    check_version("transformers>=4.43.0,<=4.46.1")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
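
Note: the patched _get_unpad_data is what makes packed training attend block-diagonally: the attention mask of a packed batch labels its sub-sequences 1, 2, 3, ... (0 for padding), and the replacement derives per-sequence lengths and cumulative offsets from those labels so FlashAttention handles each sub-sequence independently. A simplified sketch of the idea (not the project's exact implementation):

    from typing import Tuple

    import torch
    import torch.nn.functional as F


    def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
        # Collect the length of every labeled sub-sequence, row by row.
        seqlens = []
        for row in attention_mask:
            counts = torch.bincount(row)                # counts[k] = number of tokens labeled k
            seqlens.append(counts[1:][counts[1:] > 0])  # drop padding (label 0) and unused labels
        seqlens_in_batch = torch.cat(seqlens).to(torch.int32)

        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = int(seqlens_in_batch.max())
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        return indices, cu_seqlens, max_seqlen_in_batch


    # One row packing two sequences of lengths 3 and 2, plus one pad token.
    mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
    print(get_unpad_data(mask))  # lengths [3, 2] -> cu_seqlens [0, 3, 5]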

View File

@@ -26,11 +26,10 @@ from datasets import load_dataset
 from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version
 from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.misc import get_current_device
+from ...extras.misc import check_version, get_current_device
 if TYPE_CHECKING:
@@ -118,15 +117,15 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")
         if quant_method == QuantizationMethod.GPTQ:
-            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            check_version("auto_gptq>=0.5.0", mandatory=True)
             quantization_config.pop("disable_exllama", None) # remove deprecated args
             quantization_config["use_exllama"] = False # disable exllama
         if quant_method == QuantizationMethod.AWQ:
-            require_version("autoawq", "To fix: pip install autoawq")
+            check_version("autoawq", mandatory=True)
         if quant_method == QuantizationMethod.AQLM:
-            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+            check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2
         quant_bits = quantization_config.get("bits", "?")
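
Note: the backend checks above pass mandatory=True because loading a GPTQ/AWQ/AQLM checkpoint is impossible without the corresponding package, so these checks should never be silenced. Also note that require_version (and therefore the check_version wrapper sketched earlier) accepts a bare distribution name, which the "autoawq", "hqq" and "eetq" checks rely on: only the package's presence is verified, not any particular version. For example:

    from transformers.utils.versions import require_version

    require_version("packaging")  # a transformers dependency, so this passes
    require_version("autoawq", "To fix: pip install autoawq")  # raises if autoawq is not installed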
@@ -136,8 +135,8 @@ def configure_quantization(
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
-        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        check_version("optimum>=1.17.0", mandatory=True)
+        check_version("auto_gptq>=0.5.0", mandatory=True)
         from accelerate.utils import get_max_memory
         if getattr(config, "model_type", None) == "chatglm":
@@ -154,10 +153,10 @@ def configure_quantization(
     elif model_args.quantization_bit is not None: # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
             if model_args.quantization_bit == 8:
-                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             elif model_args.quantization_bit == 4:
-                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                check_version("bitsandbytes>=0.39.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(
                     load_in_4bit=True,
                     bnb_4bit_compute_dtype=model_args.compute_dtype,
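
Note: the surrounding context keeps building the on-the-fly 4-bit setup through transformers' BitsAndBytesConfig, with the compute dtype and related options coming from model_args. For reference, a typical 4-bit NF4 configuration using the same API (the values below are common choices, not necessarily the ones this project derives from its arguments):

    import torch
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16 while weights stay 4-bit
        bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
        bnb_4bit_use_double_quant=True,         # quantize the quantization constants too
    )
    # Passed to AutoModelForCausalLM.from_pretrained(..., quantization_config=bnb_config),
    # which is what the surrounding code does through init_kwargs.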
@@ -175,7 +174,7 @@ def configure_quantization(
                 if model_args.quantization_bit != 4:
                     raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
-                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+                check_version("bitsandbytes>=0.43.0", mandatory=True)
             else:
                 init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
@@ -187,7 +186,7 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
-            require_version("hqq", "To fix: pip install hqq")
+            check_version("hqq", mandatory=True)
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             ) # use ATEN kernel (axis=0) for performance
@@ -199,6 +198,6 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
-            require_version("eetq", "To fix: pip install eetq")
+            check_version("eetq", mandatory=True)
             init_kwargs["quantization_config"] = EetqConfig()
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")