Merge pull request #5581 from Kuangdd01/pixtral-patch

[WIP] Support Pixtral-12B

Former-commit-id: fcddf4ec5c2914f73e23eeda2dbf67b048246669
Authored by hoshi-hiyouga on 2024-10-29 22:29:10 +08:00; committed by GitHub.
9 changed files with 129 additions and 6 deletions

src/llamafactory/chat/hf_engine.py

@@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine):
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
+
             gen_kwargs[key] = value.to(model.device)

         return gen_kwargs, prompt_length
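
A minimal sketch of the new list branch, assuming (per the inline comment) that Pixtral's processor yields pixel_values as a list of same-sized per-image tensors, which torch.stack can batch:

    import torch

    # assumption: two images preprocessed to the same resolution
    pixel_values = [torch.rand(3, 512, 512), torch.rand(3, 512, 512)]
    batched = torch.stack(pixel_values)  # raises if the tensors differed in shape
    print(batched.shape)  # torch.Size([2, 3, 512, 512])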

src/llamafactory/data/collator.py

@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
+            features = features.data  # use default_collate() instead of BatchEncoding.to()
+
         return features
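
Why the unwrapping helps, in a minimal sketch (assuming transformers' BatchEncoding, whose .to() moves tensors but cannot handle a list of tensors):

    import torch
    from transformers import BatchEncoding

    enc = BatchEncoding({"pixel_values": [torch.rand(3, 64, 64), torch.rand(3, 32, 32)]})
    features = enc.data  # the underlying plain dict; tensors can then be moved one by one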

src/llamafactory/data/mm_plugin.py

@@ -448,6 +448,70 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs


+class PixtralPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        patch_size = getattr(processor, "patch_size")
+        image_token = getattr(processor, "image_token")
+        image_break_token = getattr(processor, "image_break_token")
+        image_end_token = getattr(processor, "image_end_token")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        image_input_sizes = mm_inputs.get("image_sizes", None)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if image_input_sizes is None:
+                    raise ValueError(
+                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
+                    )
+
+                image_size = image_input_sizes[0][num_image_tokens]
+                height, width = image_size
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
+                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                replace_tokens[-1] = image_end_token
+                replace_str = "".join(replace_tokens)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if mm_inputs.get("pixel_values"):
+            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
+
+        mm_inputs.pop("image_sizes", None)
+        return mm_inputs
+
+
 class Qwen2vlPlugin(BasePlugin):
     @override
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
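
A worked example of the placeholder expansion in process_messages above (image size, patch size, and the [IMG]/[IMG_BREAK]/[IMG_END] token strings are assumed values for illustration):

    # a 512x512 image with patch_size 16 -> 32 rows of 32 image tokens
    patch_size = 16
    height, width = 512, 512
    num_height_tokens = height // patch_size  # 32
    num_width_tokens = width // patch_size  # 32
    rows = [["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"]] * num_height_tokens
    tokens = [tok for row in rows for tok in row]  # flatten, as in the plugin
    tokens[-1] = "[IMG_END]"  # the last break marks the end of the image
    assert len(tokens) == num_height_tokens * (num_width_tokens + 1)  # 1056 tokens
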
@@ -610,6 +674,7 @@ PLUGINS = {
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
     "paligemma": PaliGemmaPlugin,
+    "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
 }
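
With the entry above, the plugin resolves by name through this module's factory; a usage sketch (the message/image variables are assumed to be in scope):

    plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
    messages = plugin.process_messages(messages, images, videos, processor)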

src/llamafactory/data/template.py

@@ -935,6 +935,14 @@ _register_template(
 )


+_register_template(
+    name="pixtral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
+
+
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
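
A sketch of what the pixtral template renders for one user turn (the BOS string is an assumption about the tokenizer; the expanded [IMG] grid comes from PixtralPlugin):

    bos = "<s>"  # format_prefix contributes only the BOS token
    user = "[IMG]...[IMG_END] Describe this image."  # placeholder already expanded
    prompt = bos + "[INST] {} [/INST]".format(user)  # per format_user above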

src/llamafactory/extras/constants.py

@@ -1178,6 +1178,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Pixtral-12B-Chat": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="pixtral",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "Qwen-1.8B": {

src/llamafactory/model/model_utils/visual.py

@@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
         elif model_type == "qwen2_vl":
             mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
@@ -113,6 +113,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         "llava_next",
         "llava_next_video",
         "paligemma",
+        "pixtral",
         "video_llava",
     ]:  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
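
The hidden_size copy exists because multimodal configs nest the language model's config, while DeepSpeed ZeRO-3 and the valuehead wrapper read it from the top level; in effect (a one-line sketch):

    config.hidden_size = config.text_config.hidden_size
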
@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")
@@ -186,7 +187,7 @@ def patch_target_modules(
     """
     model_type = getattr(config, "model_type", None)
     if finetuning_args.freeze_vision_tower:
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
@@ -195,5 +196,7 @@ def patch_target_modules(
     else:
         if model_type == "qwen2_vl":
             return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "pixtral":
+            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules
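
A worked example of the new pixtral branch (module names are assumptions based on the transformers Pixtral layout): the negative lookahead keeps LoRA off the conv patch embedding while other projection layers remain targetable.

    import re

    target_modules = ["q_proj", "v_proj"]
    pattern = "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
    assert re.match(pattern, "language_model.model.layers.0.self_attn.q_proj")
    assert re.match(pattern, "vision_tower.transformer.layers.0.attention.q_proj")
    assert not re.match(pattern, "vision_tower.patch_conv")  # excluded by the lookahead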