fix inputs

Former-commit-id: 7d535bb8cdf7e81edda81152e63c8cfe6c9dcc9f
This commit is contained in:
hiyouga
2024-11-23 18:25:45 +00:00
parent cd2485f28d
commit 5003820a6a
14 changed files with 148 additions and 95 deletions

View File

@@ -26,7 +26,7 @@ from ...extras import logging
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
@@ -159,27 +159,25 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
image_seqlen = config.vision_config.num_image_tokens
else:
image_seqlen = -1
elif model_type == "mllama":
image_seqlen = (
(config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1
) * config.vision_config.max_num_tiles
return image_seqlen
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
    r"""
    Computes the patch size of the vit.

    Lookup order:
        1. ``config.vision_config.patch_size``
        2. ``processor.patch_size``
        3. ``-1`` when neither object defines the attribute.

    Args:
        config: model config; must expose a ``vision_config`` attribute
            (an ``AttributeError`` is raised otherwise).
        processor: multimodal processor used as a fallback source. Any object
            works here, including ``None``, since it is only probed via ``getattr``.

    Returns:
        The patch size, or ``-1`` if it cannot be determined.
    """
    # The processor value is only a fallback: the vision config wins when both define it.
    fallback_patch_size = getattr(processor, "patch_size", -1)
    patch_size = getattr(config.vision_config, "patch_size", fallback_patch_size)
    return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> str:
    r"""
    Get the vision_feature_select_strategy.

    Lookup order:
        1. ``config.vision_feature_select_strategy``
        2. ``processor.vision_feature_select_strategy``
        3. ``"default"`` when neither object defines the attribute.

    Args:
        config: model config probed first for the strategy.
        processor: multimodal processor used as a fallback source. Any object
            works here, including ``None``, since it is only probed via ``getattr``.

    Returns:
        The strategy value found, or ``"default"`` if none is set.
    """
    # The processor value is only a fallback: the model config wins when both define it.
    fallback_strategy = getattr(processor, "vision_feature_select_strategy", "default")
    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", fallback_strategy)
    return vision_feature_select_strategy