[misc] upgrade format to py39 (#7256)

Author: hoshi-hiyouga
Committed by: GitHub
Date: 2025-03-12 00:08:41 +08:00
Parent: 5995800bce
Commit: 264538cb26
113 changed files with 984 additions and 1407 deletions
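Most of the 113 touched files follow the same pattern, so a minimal before/after sketch may help orient the hunks below. The commit replaces `typing.Dict/List/Tuple/Set` with the PEP 585 builtin generics available since Python 3.9, moves `Sequence` to `collections.abc`, rewrites `.format()` calls as f-strings, and collapses docstring openers onto the first line. The function and variable names here are invented for illustration, not taken from the repo:

```python
# Before: py3.8-compatible style used throughout the old code
from typing import Dict, List, Optional, Sequence, Tuple

def summarize_scores(scores: Dict[str, float], tags: Sequence[str]) -> Tuple[List[str], Optional[float]]:
    """
    Summarizes the scores.
    """
    best = max(scores.values()) if scores else None
    return list(tags), best


# After: PEP 585 builtin generics, collections.abc, single-line docstring opener
from collections.abc import Sequence
from typing import Optional

def summarize_scores(scores: dict[str, float], tags: Sequence[str]) -> tuple[list[str], Optional[float]]:
    """Summarize the scores."""
    best = max(scores.values()) if scores else None
    return list(tags), best
```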


@@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params
__all__ = [
"QuantizationMethod",
"find_all_linear_modules",
"load_config",
"load_model",
"load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params",
]


@@ -81,9 +81,8 @@ def _setup_freeze_tuning(
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
f"`num_layers` {num_layers} should be "
f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
)
stride = num_layers // finetuning_args.freeze_trainable_layers
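
The hunk above only shows the divisibility check and the `stride` computation for LLaMA Pro-style freeze tuning; how the stride is then used is not visible here. A hedged sketch of one plausible selection rule (the helper name and the exact rule are assumptions, not the repo's code):

```python
def pick_trainable_layers(num_layers: int, num_layer_trainable: int) -> list[int]:
    """Hypothetical helper: choose every `stride`-th block for LLaMA Pro-style freeze tuning."""
    if num_layers % num_layer_trainable != 0:
        raise ValueError(
            f"`num_layers` {num_layers} should be "
            f"divisible by `num_layer_trainable` {num_layer_trainable}."
        )
    stride = num_layers // num_layer_trainable
    # e.g. num_layers=32, num_layer_trainable=8 -> stride=4 -> layers 3, 7, 11, ..., 31
    return list(range(stride - 1, num_layers, stride))
```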
@@ -178,7 +177,7 @@ def _setup_lora_tuning(
}
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
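
The loop above merges any adapters listed in `adapter_to_merge` back into the base weights before a trainable adapter is attached. The same peft pattern in isolation (model and adapter paths are placeholders):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # placeholder path
for adapter in ["path/to/adapter-a", "path/to/adapter-b"]:          # placeholder adapter paths
    model = PeftModel.from_pretrained(model, adapter)
    model = model.merge_and_unload()  # bake the LoRA deltas into the base weights
```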
@@ -263,8 +262,7 @@ def init_adapter(
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
r"""Initialize the adapters.
Support full-parameter, freeze and LoRA training.


@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
@@ -51,9 +51,8 @@ class TokenizerModule(TypedDict):
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r"""Get arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
@@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer and optionally loads processor.
r"""Load pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
@@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
r"""Load model config."""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
@@ -124,9 +120,7 @@ def load_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
r"""Load pretrained model."""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
@@ -194,8 +188,9 @@ def load_model(
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
param_stats = (
f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
)
else:
param_stats = f"all params: {all_param:,}"


@@ -21,7 +21,7 @@
import inspect
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
@@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
r"""Saves VRAM by smartly offloading to RAM."""
@staticmethod
@torch.cuda.amp.custom_fwd
@@ -77,13 +75,11 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
r"""Only applies gradient checkpointing to trainable layers."""
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
module: torch.nn.Module = func.__self__
has_grad = False
if any(param.requires_grad for param in module.parameters()):
@@ -103,11 +99,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
def _gradient_checkpointing_enable(
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""
Activates gradient checkpointing for the current model.
r"""Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
@@ -134,17 +129,18 @@ def _gradient_checkpointing_enable(
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
r"""Prepare the model before training.
Include:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32.
"""
if model_args.upcast_layernorm:
logger.info_rank0("Upcasting layernorm weights in float32.")
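
`_fp32_forward_post_hook` above just casts a module's output to fp32. A sketch of how the three steps listed in the `prepare_model_for_training` docstring could be wired with plain PyTorch hooks; this illustrates the idea under stated assumptions, not the repo's implementation:

```python
import torch

def prepare_for_training_sketch(model: torch.nn.Module) -> None:
    # (1) cast layernorm-style (1-D) weights to fp32
    for name, param in model.named_parameters():
        if param.ndim == 1 and "norm" in name.lower():
            param.data = param.data.to(torch.float32)

    # (2) make the input-embedding output require grads so gradient checkpointing
    #     has a grad-requiring tensor to start from
    if hasattr(model, "get_input_embeddings"):
        model.get_input_embeddings().register_forward_hook(
            lambda module, args, output: output.requires_grad_(True)
        )

    # (3) upcast the lm_head output to fp32, as `_fp32_forward_post_hook` above does
    if hasattr(model, "lm_head"):
        model.lm_head.register_forward_hook(lambda module, args, output: output.to(torch.float32))
```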


@@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
r"""Resize token embeddings."""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
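
Only the call site of `_noisy_mean_initialization` appears in this hunk. A plausible reading of the name (an assumption, not the repo's code) is that newly added token rows are filled with the mean of the existing embeddings plus small Gaussian noise:

```python
import torch

def noisy_mean_initialization_sketch(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
    """Assumed behavior: fill the new rows with the mean of the old rows plus small Gaussian noise."""
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise = torch.empty_like(embed_weight[-num_new_tokens:])
    noise.normal_(mean=0.0, std=1.0 / (embedding_dim**0.5))
    embed_weight[-num_new_tokens:] = avg_weight + noise
```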


@@ -18,7 +18,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
@@ -54,14 +54,14 @@ def llama_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
attn_output: torch.Tensor = _flash_attention_forward(
query_states,
key_states,
value_states,
@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
@@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
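
Every attention variant patched in this file begins with the same projection-and-reshape into a per-head layout. A compact, runnable restatement with made-up sizes, just to make the shapes explicit:

```python
import torch

bsz, q_len, hidden_size, num_heads = 2, 16, 64, 4
head_dim = hidden_size // num_heads
hidden_states = torch.randn(bsz, q_len, hidden_size)
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim)

query_states: torch.Tensor = q_proj(hidden_states)  # (bsz, q_len, num_heads * head_dim)
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)  # (bsz, num_heads, q_len, head_dim)
```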


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING
from ...extras import logging
from .visual import COMPOSITE_MODELS
@@ -25,10 +25,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply LoRA, GaLore or APOLLO.
"""
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
if model_type == "chatglm":
@@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
r"""Find the modules in the expanded blocks to apply lora."""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")
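
A minimal sketch of what `find_all_linear_modules` boils down to for a plain transformers model: collect the trailing name of every `nn.Linear`, minus forbidden heads such as `lm_head` (the real function also special-cases vision towers and quantized layer types, omitted here):

```python
import torch.nn as nn

FORBIDDEN_MODULES = {"lm_head"}  # mirrors the forbidden set built in the hunk above

def find_linear_module_names(model: nn.Module) -> list[str]:
    module_names = set()
    for name, module in model.named_modules():
        if any(forbidden in name for forbidden in FORBIDDEN_MODULES):
            continue
        if isinstance(module, nn.Linear):
            module_names.add(name.split(".")[-1])  # e.g. "q_proj", "v_proj"
    return list(module_names)
```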


@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -34,9 +35,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
if not is_deepspeed_zero3_enabled():
return
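
`_set_z3_leaf_modules` is not expanded in this hunk. Under DeepSpeed ZeRO-3, marking MoE blocks as leaf modules keeps their parameters from being partitioned mid-forward. A hedged sketch using `deepspeed.utils.set_z3_leaf_modules` (assumed to be the underlying call; the repo's wrapper may differ):

```python
from transformers.integrations import is_deepspeed_zero3_enabled

def add_z3_leaf_module_sketch(model, leaf_module_classes) -> None:
    """Hypothetical stand-in for `_set_z3_leaf_modules`."""
    if not is_deepspeed_zero3_enabled():
        return
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    # e.g. leaf_module_classes = [MixtralSparseMoeBlock] for a Mixtral-style MoE model
    set_z3_leaf_modules(model, leaf_module_classes)
```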


@@ -37,7 +37,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
@@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""
Gets the sequence lengths in the current batch.
r"""Get the sequence lengths in the current batch.
e.g.
```python
@@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask).item()
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
@@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
return seqlens
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
r"""
Prepares the indices and seqlens for flash attn varlen function.
def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
r"""Prepare the indices and seqlens for flash attn varlen function.
Returns:
indices: indices of non-masked tokens from the flattened sequence.
@@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
[0, 2, 5, 6, 8, 11]
3
```
"""
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
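
The docstring's worked example can be reproduced end to end. The `cu_seqlens` padding step is not visible in the hunk, so that line is an assumption about the omitted tail of `get_unpad_data`; the expected values in the comments come from the docstring above:

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
bsz, max_num = attention_mask.size(0), attention_mask.max().item()

counts = torch.zeros((bsz, max_num), dtype=attention_mask.dtype)
for i in range(max_num):
    counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

seqlens = counts.flatten()
seqlens_in_batch = seqlens[seqlens.nonzero(as_tuple=True)]                    # tensor([2, 3, 1, 2, 3])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # non-padding positions
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0), (1, 0))             # tensor([0, 2, 5, 6, 8, 11])
max_seqlen_in_batch = seqlens_in_batch.max().item()                           # 3
```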


@@ -19,7 +19,7 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any
import torch
from datasets import load_dataset
@@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
@@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
@@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
n_try += 1
if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen
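
A self-contained sketch of the calibration sampling that `_get_quantization_dataset` performs for AutoGPTQ export: draw random rows, tokenize, and keep a sample once it exceeds the export max length, storing plain Python lists so the result stays JSON-serializable (as the docstring above notes). The dataset, tokenizer, and the truncation step are placeholders/assumptions, since the tail of the loop is not shown here:

```python
import random

from datasets import load_dataset
from transformers import AutoTokenizer

maxlen = 1024                                                            # placeholder for export_quantization_maxlen
tokenizer = AutoTokenizer.from_pretrained("gpt2")                        # placeholder tokenizer
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")   # placeholder corpus

sample = None
for _ in range(100):                                                     # bounded retries, like the n_try counter above
    text = dataset[random.randint(0, len(dataset) - 1)]["text"]
    input_ids = tokenizer(text)["input_ids"]
    if len(input_ids) > maxlen:
        # assumption about the omitted tail: truncate and keep plain lists
        sample = {"input_ids": input_ids[:maxlen], "attention_mask": [1] * maxlen}
        break
```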
@@ -101,11 +97,9 @@ def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
init_kwargs: dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None:
logger.warning_rank0("`quantization_bit` will not affect the PTQ-quantized models.")
@@ -113,7 +107,7 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.misc import get_current_device
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
@@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Optionally load pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
@@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Get the peft model for the pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {
"model": model,
@@ -82,10 +78,8 @@ def get_unsloth_peft_model(
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
r"""Load peft model with unsloth. Used in both training and inference."""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
import torch
from transformers.utils import cached_file
@@ -30,9 +30,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
r"""Load value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""


@@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import transformers
@@ -40,9 +41,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
class CompositeModel:
model_type: str
projector_key: str
vision_model_keys: List[str]
language_model_keys: List[str]
lora_conflict_keys: List[str]
vision_model_keys: list[str]
language_model_keys: list[str]
lora_conflict_keys: list[str]
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."):
@@ -51,15 +52,15 @@ class CompositeModel:
return module
COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
def _register_composite_model(
model_type: str,
projector_key: Optional[str] = None,
vision_model_keys: Optional[List[str]] = None,
language_model_keys: Optional[List[str]] = None,
lora_conflict_keys: Optional[List[str]] = None,
vision_model_keys: Optional[list[str]] = None,
language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[list[str]] = None,
):
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
@@ -116,12 +117,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
r"""Cast projector output to half precision for fine-tuning quantized VLMs."""
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
@@ -137,9 +136,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
def configure_visual_model(config: "PretrainedConfig") -> None:
r"""
Patches VLMs before loading them.
"""
r"""Patch VLMs before loading them."""
if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
# required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
@@ -149,10 +146,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
r"""Freeze vision tower and language model for VLM full/freeze tuning."""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in COMPOSITE_MODELS:
@@ -175,9 +170,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
r"""Compute the number of special tokens per image."""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
@@ -192,17 +185,13 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Computes the patch size of the vit.
"""
r"""Compute the patch size of the vit."""
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Get the vision_feature_select_strategy.
"""
r"""Get the vision_feature_select_strategy."""
vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
)
@@ -211,10 +200,8 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
def patch_target_modules(
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> List[str]:
r"""
Freezes vision tower for VLM LoRA tuning.
"""
) -> list[str]:
r"""Freezes vision tower for VLM LoRA tuning."""
model_type = getattr(model.config, "model_type", None)
if model_type in COMPOSITE_MODELS:
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
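
The `CompositeModel` dataclass and `COMPOSITE_MODELS` registry above drive the rest of this file. A small sketch of registering one entry and deriving the forbidden modules for freeze tuning; the key values below are illustrative, not the repo's actual registrations:

```python
from dataclasses import dataclass

@dataclass
class CompositeModelSketch:
    model_type: str
    projector_key: str
    vision_model_keys: list[str]
    language_model_keys: list[str]
    lora_conflict_keys: list[str]

COMPOSITE_MODELS_SKETCH: dict[str, CompositeModelSketch] = {}

def register_composite_model_sketch(model_type: str, **keys) -> None:
    COMPOSITE_MODELS_SKETCH[model_type] = CompositeModelSketch(model_type=model_type, **keys)

# hypothetical registration; the real key names live in the repo's `_register_composite_model` calls
register_composite_model_sketch(
    "llava",
    projector_key="multi_modal_projector",
    vision_model_keys=["vision_tower"],
    language_model_keys=["language_model"],
    lora_conflict_keys=[],
)

# freeze tuning then simply forbids the vision-side keys, as `get_forbidden_modules` does above
forbidden_modules = set(COMPOSITE_MODELS_SKETCH["llava"].vision_model_keys)
```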


@@ -13,7 +13,7 @@
# limitations under the License.
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
import torch
from peft import PeftModel
@@ -93,7 +93,7 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
init_kwargs: dict[str, Any],
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
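
The final context line notes the dtype priority bf16 > fp16 > fp32. A sketch of how that fallback could be resolved at runtime (the repo's actual resolution helper is not shown in this diff):

```python
import torch

def infer_compute_dtype() -> torch.dtype:
    """Assumed resolution of the bf16 > fp16 > fp32 priority noted above."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if torch.cuda.is_available():
        return torch.float16
    return torch.float32
```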