[data] gemma3 plugin pan and scan (#7294)
* gemma3 pan and scan * add test case * fix test
This commit is contained in:
@@ -290,7 +290,18 @@ class MMPluginMixin:
|
||||
if imglens is not None:
|
||||
images = _make_batched_images(images, imglens)
|
||||
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
image_processor_kwargs = {}
|
||||
if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor
|
||||
image_processor_kwargs.update(
|
||||
{
|
||||
"do_pan_and_scan": True,
|
||||
"pan_and_scan_min_crop_size": 256,
|
||||
"pan_and_scan_max_num_crops": 4,
|
||||
"pan_and_scan_min_ratio_to_activate": 1.2,
|
||||
}
|
||||
)
|
||||
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
|
||||
|
||||
if len(videos) != 0:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
@@ -401,10 +412,23 @@ class Gemma3Plugin(BasePlugin):
|
||||
boi_token: str = getattr(processor, "boi_token")
|
||||
full_image_sequence: str = getattr(processor, "full_image_sequence")
|
||||
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
|
||||
|
||||
do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
|
||||
if do_pan_and_scan:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
if do_pan_and_scan:
|
||||
image_placeholder_str = (
|
||||
"Here is the original image {{image}} and here are some crops to help you see better "
|
||||
+ " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
|
||||
)
|
||||
else:
|
||||
image_placeholder_str = "{{image}}"
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", image_str)
|
||||
|
||||
@@ -1263,6 +1263,7 @@ register_template(
|
||||
format_user=StringFormatter(slots=["{{content}}\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
|
||||
template_class=Llama2Template,
|
||||
)
|
||||
|
||||
|
||||
@@ -1277,6 +1278,7 @@ register_template(
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<end_of_turn>"],
|
||||
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
|
||||
template_class=Llama2Template,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -218,6 +218,10 @@ class ProcessorArguments:
|
||||
default=32 * 32,
|
||||
metadata={"help": "The minimum number of pixels of image inputs."},
|
||||
)
|
||||
image_do_pan_and_scan: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use pan and scan to process image for gemma3."},
|
||||
)
|
||||
video_max_pixels: int = field(
|
||||
default=256 * 256,
|
||||
metadata={"help": "The maximum number of pixels of video inputs."},
|
||||
@@ -235,6 +239,13 @@ class ProcessorArguments:
|
||||
metadata={"help": "The maximum number of sampled frames for video inputs."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
    """Validate pixel-bound arguments right after dataclass initialization.

    Raises:
        ValueError: if a maximum pixel bound is smaller than the
            corresponding minimum pixel bound.
    """
    image_bounds_ok = self.image_max_pixels >= self.image_min_pixels
    if not image_bounds_ok:
        raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")

    video_bounds_ok = self.video_max_pixels >= self.video_min_pixels
    if not video_bounds_ok:
        raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportArguments:
|
||||
@@ -342,6 +353,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
|
||||
|
||||
def __post_init__(self):
    """Run each parent argument group's post-init hook in declaration order.

    Dataclass inheritance does not chain ``__post_init__`` automatically,
    so every mixin's hook is invoked explicitly here.
    """
    for parent_cls in (BaseModelArguments, ProcessorArguments, ExportArguments, VllmArguments):
        parent_cls.__post_init__(self)
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
|
||||
tokenizer.model_max_length = model_args.model_max_length
|
||||
if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
|
||||
tokenizer.model_max_length = model_args.model_max_length # enlarge the tokenizer max length
|
||||
|
||||
if model_args.new_special_tokens is not None:
|
||||
num_added_tokens = tokenizer.add_special_tokens(
|
||||
@@ -72,6 +72,7 @@ def patch_processor(
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
|
||||
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
|
||||
setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
|
||||
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
|
||||
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
|
||||
Reference in New Issue
Block a user