[model] add qwen2.5 vl models (#6779)
Former-commit-id: ed46fb4f6194c30060b908092464dded12e5787c
@@ -176,7 +176,10 @@ class HuggingfaceEngine(BaseEngine):
             if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)
 
-            gen_kwargs[key] = value.to(model.device)
+            if key == "second_per_grid_ts":  # qwen2.5vl special case
+                gen_kwargs[key] = value.tolist()
+            else:
+                gen_kwargs[key] = value.to(model.device)
 
         if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
             gen_kwargs["input_ids"] = inputs
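The hunk above keeps `second_per_grid_ts` as a plain Python list of floats during generation, since Qwen2.5-VL consumes it as timing metadata rather than as a tensor, while every other multimodal input is still moved to the model's device. A minimal, self-contained sketch of that dispatch (with stand-in `model_dtype`/`model_device` instead of a real model):

```python
import torch


def prepare_gen_kwargs(mm_inputs, model_dtype=torch.float16, model_device="cpu"):
    """Move multimodal tensors to the model device, but keep
    second_per_grid_ts as a plain list of floats (Qwen2.5-VL special case)."""
    gen_kwargs = {}
    for key, value in mm_inputs.items():
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value)

        if torch.is_floating_point(value):  # cast data dtype, e.g. for paligemma
            value = value.to(model_dtype)

        if key == "second_per_grid_ts":  # qwen2.5vl special case
            gen_kwargs[key] = value.tolist()
        else:
            gen_kwargs[key] = value.to(model_device)

    return gen_kwargs


# second_per_grid_ts stays a list; pixel_values becomes a tensor on the target device.
print(prepare_gen_kwargs({"second_per_grid_ts": [1.0, 1.0], "pixel_values": torch.rand(2, 3)}))
```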
@@ -135,12 +135,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
 
         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
-            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
-                input_ids=features["input_ids"],
-                image_grid_thw=mm_inputs.get("image_grid_thw", None),
-                video_grid_thw=mm_inputs.get("video_grid_thw", None),
-                attention_mask=features["attention_mask"],
-            )
+            rope_index_kwargs = {
+                "input_ids": features["input_ids"],
+                "image_grid_thw": mm_inputs.get("image_grid_thw"),
+                "video_grid_thw": mm_inputs.get("video_grid_thw"),
+                "attention_mask": features["attention_mask"],
+            }
+            if "second_per_grid_ts" in mm_inputs:
+                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+
+            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
 
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
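Building `rope_index_kwargs` as a dict and adding `second_per_grid_ts` only when the processor emitted it lets the same call serve Qwen2-VL (whose `get_rope_index` does not take that argument) and Qwen2.5-VL (whose variant uses it to space temporal position ids). A stand-alone sketch of the pattern, with placeholder values in place of the real batch features:

```python
# mm_inputs is a stand-in for the processor output; only qwen2.5vl processors emit second_per_grid_ts.
mm_inputs = {"video_grid_thw": [[8, 18, 24]], "second_per_grid_ts": [1.0]}

rope_index_kwargs = {
    "input_ids": None,        # placeholder for features["input_ids"]
    "image_grid_thw": mm_inputs.get("image_grid_thw"),
    "video_grid_thw": mm_inputs.get("video_grid_thw"),
    "attention_mask": None,   # placeholder for features["attention_mask"]
}
if "second_per_grid_ts" in mm_inputs:
    rope_index_kwargs["second_per_grid_ts"] = mm_inputs["second_per_grid_ts"]

# model.get_rope_index(**rope_index_kwargs) then receives exactly the arguments
# that the loaded model's signature supports.
print(sorted(rope_index_kwargs))
```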
@@ -178,16 +178,16 @@ class BasePlugin:
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+                image_resolution=getattr(processor, "image_resolution", 768 * 768),
             )
             input_dict["images"] = images
 
         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                image_resolution=getattr(processor, "video_resolution", 256 * 256),
                 video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             input_dict["videos"] = videos
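These hunks (and the MiniCPM-V and Mllama ones below) only raise the default budgets: images are regularized against a 768 × 768 = 589,824-pixel cap (up from 512 × 512), video frames against 256 × 256 (up from 128 × 128), and at most 128 frames are sampled per video (up from 64). Purely as an illustration (not the repository's `_regularize_images`), a pixel-budget resize could look like:

```python
from PIL import Image


def shrink_to_pixel_budget(image: Image.Image, max_pixels: int = 768 * 768) -> Image.Image:
    """Downscale so that width * height stays within max_pixels, keeping the aspect ratio."""
    width, height = image.size
    if width * height <= max_pixels:
        return image

    scale = (max_pixels / (width * height)) ** 0.5
    return image.resize((int(width * scale), int(height * scale)))
```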
@@ -501,7 +501,7 @@ class MiniCPMVPlugin(BasePlugin):
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+                image_resolution=getattr(processor, "image_resolution", 768 * 768),
             )
             if "valid_image_nums_ls" in kwargs:
                 valid_image_nums_ls = kwargs["valid_image_nums_ls"]
@@ -521,9 +521,9 @@ class MiniCPMVPlugin(BasePlugin):
         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                image_resolution=getattr(processor, "video_resolution", 256 * 256),
                 video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
             mm_inputs.update(video_inputs)
@@ -610,7 +610,7 @@ class MllamaPlugin(BasePlugin):
         """
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         imglens: List[int] = kwargs["imglens"]
-        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768))
         batch_images = []
         for image_length in imglens:
             batch_images.append(images[:image_length])
@@ -875,7 +875,15 @@ class Qwen2vlPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
+            video_fps = getattr(processor, "video_fps", 2.0)
+            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / video_fps] * len(
+                mm_inputs["video_grid_thw"]
+            )
+
+        return mm_inputs
 
 
 class VideoLlavaPlugin(BasePlugin):
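Here `second_per_grid_ts` is derived from the image processor itself: each temporal grid step merges `temporal_patch_size` sampled frames, and frames are sampled at `video_fps`, so one grid step spans `temporal_patch_size / video_fps` seconds, repeated once per video. Worked example with typical numbers (the grid shapes below are made up for illustration):

```python
temporal_patch_size = 2                      # frames merged into one temporal grid step
video_fps = 2.0                              # sampling rate used when extracting frames
video_grid_thw = [(8, 18, 24), (4, 18, 24)]  # illustrative (t, h, w) grids for two videos

second_per_grid_ts = [temporal_patch_size / video_fps] * len(video_grid_thw)
print(second_per_grid_ts)  # [1.0, 1.0] -> each temporal grid step covers one second of video
```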
@@ -1928,6 +1928,14 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct",
         },
+        "Qwen2.5-7B-Instruct-1M": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-1M",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-1M",
+        },
+        "Qwen2.5-14B-Instruct-1M": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-1M",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-1M",
+        },
         "Qwen2.5-0.5B-Instruct-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
@@ -2149,6 +2157,18 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/QVQ-72B-Preview",
             DownloadSource.MODELSCOPE: "Qwen/QVQ-72B-Preview",
         },
+        "Qwen2.5-VL-3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-3B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-3B-Instruct",
+        },
+        "Qwen2.5-VL-7B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct",
+        },
+        "Qwen2.5-VL-72B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct",
+        },
     },
     template="qwen2_vl",
     vision=True,
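These registry entries expose the Qwen2.5-VL checkpoints under the existing `qwen2_vl` template with vision enabled. Outside of this repository, loading one of them directly requires a transformers release that ships Qwen2.5-VL support; a hedged sketch:

```python
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
```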
@@ -59,19 +59,19 @@ class ProcessorArguments:
     """
 
     image_resolution: int = field(
-        default=512 * 512,
-        metadata={"help": "Keeps the number of pixels of image below this resolution."},
+        default=768 * 768,
+        metadata={"help": "The maximum number of pixels of image inputs."},
     )
     video_resolution: int = field(
-        default=128 * 128,
-        metadata={"help": "Keeps the number of pixels of video below this resolution."},
+        default=256 * 256,
+        metadata={"help": "The maximum number of pixels of video inputs."},
     )
     video_fps: float = field(
         default=2.0,
         metadata={"help": "The frames to sample per second for video inputs."},
     )
     video_maxlen: int = field(
-        default=64,
+        default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
 
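The new defaults in `ProcessorArguments` match the plugin changes above. Spelled out:

```python
image_budget = 768 * 768   # 589,824 pixels per image (was 512 * 512 = 262,144)
video_budget = 256 * 256   # 65,536 pixels per sampled video frame (was 128 * 128 = 16,384)
video_fps, video_maxlen = 2.0, 128

print(image_budget, video_budget, video_maxlen / video_fps)  # 128 frames at 2 fps cover up to 64 s
```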
@@ -286,3 +286,11 @@ _register_composite_model(
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
 )
+
+
+_register_composite_model(
+    model_type="qwen2_5_vl",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["model", "lm_head"],
+)
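Registering `qwen2_5_vl` as a composite model declares which parameter prefixes belong to the vision tower, the multimodal projector, and the language model, mirroring the registration immediately above it. Purely as an illustration of how such prefixes can be used (this is not the repository's helper), one could group parameters like so:

```python
def group_params_by_prefix(named_parameters,
                           vision_keys=("visual.patch_embed", "visual.blocks"),
                           projector_key="visual.merger",
                           language_keys=("model", "lm_head")):
    """Bucket parameters by name prefix, e.g. to freeze the vision tower while tuning the rest."""
    groups = {"vision": [], "projector": [], "language": [], "other": []}
    for name, param in named_parameters:
        if name.startswith(projector_key):
            groups["projector"].append(param)
        elif any(name.startswith(k) for k in vision_keys):
            groups["vision"].append(param)
        elif any(name.startswith(k) for k in language_keys):
            groups["language"].append(param)
        else:
            groups["other"].append(param)
    return groups
```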