fix inputs
Former-commit-id: 7d535bb8cdf7e81edda81152e63c8cfe6c9dcc9f
@@ -26,7 +26,7 @@ from ...extras import logging
 
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
 
     from ...hparams import FinetuningArguments, ModelArguments
 
@@ -159,27 +159,25 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
             image_seqlen = config.vision_config.num_image_tokens
         else:
             image_seqlen = -1
     elif model_type == "mllama":
         image_seqlen = (
             (config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1
         ) * config.vision_config.max_num_tiles
 
     return image_seqlen
 
 
-def get_patch_size(config: "PretrainedConfig") -> int:
+def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Computes the patch size of the vit.
     """
-    patch_size = getattr(config.vision_config, "patch_size", -1)
+    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
     return patch_size
 
 
-def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
+def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Get the vision_feature_select_strategy.
     """
-    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    vision_feature_select_strategy = getattr(
+        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
+    )
     return vision_feature_select_strategy
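Note (not part of the diff): a minimal sketch of how the patched helpers resolve their values, assuming a LLaVA-style checkpoint; the model id below is only an illustrative placeholder. The lookup order mirrors the new code: the model config is consulted first, the processor is used as a fallback, and -1 / "default" are the final defaults.

from transformers import AutoConfig, AutoProcessor

model_id = "llava-hf/llava-1.5-7b-hf"  # hypothetical example checkpoint
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Same fallback chain as get_patch_size: config.vision_config.patch_size -> processor.patch_size -> -1
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))

# Same fallback chain as get_vision_feature_select_strategy: config -> processor -> "default"
vision_feature_select_strategy = getattr(
    config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
)
print(patch_size, vision_feature_select_strategy)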