fix mixed mm inputs and rlhf-v

Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
hiyouga
2024-09-01 20:52:47 +08:00
parent 1d8e9c7897
commit 7e4c5d4bb3
20 changed files with 306 additions and 277 deletions

View File

@@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
 from .model_utils.valuehead import load_valuehead_params
+from .model_utils.visual import get_image_seqlen
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
@@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     Note: including inplace operation of model_args.
     """
     init_kwargs = _get_init_kwargs(model_args)
+    config = load_config(model_args)
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
@@ -96,6 +98,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         setattr(processor, "tokenizer", tokenizer)
+        setattr(processor, "image_seqlen", get_image_seqlen(config))
     except Exception:
         processor = None
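
The processor now carries both the tokenizer and the per-image token count, so downstream code can expand image placeholders when text and image inputs are mixed. The snippet below is a minimal illustrative sketch, not the repository's plugin code: the "<image>" placeholder literal and the function name are assumptions; it only shows how processor.image_seqlen could be used to expand each placeholder before tokenization.

def expand_image_placeholders(prompt: str, image_seqlen: int, image_token: str) -> str:
    """Replace each "<image>" marker with image_seqlen copies of the model's image token."""
    if image_seqlen < 0:  # variable-length case (e.g. qwen2_vl): leave expansion to the processor
        return prompt
    return prompt.replace("<image>", image_token * image_seqlen)

# e.g. expand_image_placeholders("Describe <image> briefly.", processor.image_seqlen, "<image_pad>")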

View File

@@ -82,7 +82,7 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
 def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
     r"""
-    Casts projector output to half precision for quantized VLMs.
+    Casts projector output to half precision for fine-tuning quantized VLMs.
     """
     def _mm_projector_forward_post_hook(
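
For context, this hook exists because quantized VLM fine-tuning typically keeps the multimodal projector in full precision while the language model runs in half precision, so the projector's output has to be cast before it reaches the LM. Below is a hedged sketch of that idea only; the attribute name multi_modal_projector and the compute_dtype argument are assumptions, not the file's exact implementation.

import torch

def cast_projector_output(model: "torch.nn.Module", compute_dtype: "torch.dtype") -> None:
    # Forward hook: cast the projector's output so downstream half-precision
    # language-model layers receive a matching dtype.
    def _post_hook(module, args, output):
        return output.to(compute_dtype)

    projector = getattr(model, "multi_modal_projector", None)  # assumed attribute name
    if projector is not None:
        projector.register_forward_hook(_post_hook)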
@@ -136,6 +136,22 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     return forbidden_modules
 
 
+def get_image_seqlen(config: "PretrainedConfig") -> int:
+    r"""
+    Computes the number of special tokens per image.
+    """
+    if getattr(config, "model_type", None) == "llava":
+        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
+        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
+            image_seqlen += 1
+    elif getattr(config, "model_type", None) == "paligemma":
+        image_seqlen = config.vision_config.num_image_tokens
+    elif getattr(config, "model_type", None) == "qwen2_vl":  # variable length
+        image_seqlen = -1
+
+    return image_seqlen
+
+
 def patch_target_modules(
     config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
 ) -> Union[str, List[str]]:
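
As a worked example of the llava branch above: a CLIP ViT-L/14 tower at 336x336 resolution (image_size=336, patch_size=14, values typical of LLaVA-1.5 and used here purely for illustration) yields (336 // 14) ** 2 = 576 image tokens, or 577 when vision_feature_select_strategy is "full" and the [CLS] token is kept.

from types import SimpleNamespace

# Stand-in config with illustrative LLaVA-1.5-like values (not read from this commit).
config = SimpleNamespace(
    model_type="llava",
    vision_config=SimpleNamespace(image_size=336, patch_size=14),
    vision_feature_select_strategy="default",
)

image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if config.vision_feature_select_strategy == "full":  # "full" keeps the [CLS] token
    image_seqlen += 1

print(image_seqlen)  # 576 -> each image expands to 576 placeholder tokens in the prompt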