add pixtral template

Former-commit-id: e0bcaa6c6e902e29361438a6d215bbc2535b648f
This commit is contained in:
Kingsley
2024-09-26 12:11:58 +08:00
parent de72d1f0e7
commit 300feb3245
4 changed files with 61 additions and 0 deletions

View File

@@ -323,6 +323,12 @@ class PaliGemmaPlugin(BasePlugin):
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
class PixtralPlugin(BasePlugin):
#TODO preprocess according to Pixtral hf
from transformers import LlavaForConditionalGeneration
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
pass
class Qwen2vlPlugin(BasePlugin):
@override

View File

@@ -821,6 +821,13 @@ _register_template(
replace_eos=True,
)
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]")
)
_register_template(
name="qwen",