[model] add gemma3n (#8509)
This commit is contained in:
@@ -388,7 +388,7 @@ class MMPluginMixin:
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@@ -509,6 +509,36 @@ class Gemma3Plugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Gemma3nPlugin(Gemma3Plugin):
    """Multimodal plugin for Gemma 3n that expands image and audio placeholders."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        """Return a copy of ``messages`` with multimodal placeholders substituted.

        Every ``IMAGE_PLACEHOLDER`` / ``AUDIO_PLACEHOLDER`` found in a message's
        ``content`` is replaced by the processor's full image / audio sequence when
        ``self.expand_mm_tokens`` is truthy, otherwise by the corresponding
        begin-of-modality token.

        Args:
            messages: Chat messages; each must carry a string under ``"content"``.
            images: Image inputs, forwarded to the validation helpers.
            videos: Video inputs, forwarded to the validation helpers.
            audios: Audio inputs, forwarded to the validation helpers.
            processor: Processor supplying ``boi_token``, ``full_image_sequence``
                and ``full_audio_sequence`` (and, when available, ``boa_token``).

        Returns:
            A deep copy of ``messages`` with the placeholders replaced; the input
            list is left untouched.
        """
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)  # never mutate the caller's messages

        boi_token: str = getattr(processor, "boi_token")
        # Audio must be marked by the begin-of-audio token, not the image one.
        # Fall back to ``boi_token`` (the previous behavior) if the processor
        # does not expose ``boa_token``.
        boa_token: str = getattr(processor, "boa_token", boi_token)
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        full_audio_sequence: str = getattr(processor, "full_audio_sequence")
        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)

            # Strings are immutable: the substituted text has to be stored back,
            # otherwise the returned messages still contain the raw placeholders.
            message["content"] = content

        return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternVLPlugin(BasePlugin):
|
||||
@override
|
||||
@@ -1845,6 +1875,7 @@ PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"glm4v": GLM4VPlugin,
|
||||
"gemma3n": Gemma3nPlugin,
|
||||
"intern_vl": InternVLPlugin,
|
||||
"kimi_vl": KimiVLPlugin,
|
||||
"llama4": Llama4Plugin,
|
||||
|
||||
@@ -984,6 +984,22 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# Chat template for Gemma 3n: turns are delimited by <start_of_turn>/<end_of_turn>,
# and the "gemma3n" multimodal plugin handles the image and audio soft tokens.
register_template(
    name="gemma3n",
    # A user turn ends by opening the model turn so generation continues from there.
    format_user=StringFormatter(
        slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"],
    ),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    # The system prompt is emitted as plain text followed by a blank line.
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    # Tool observations get their own "tool" turn, then the model turn is reopened.
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"],
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        "gemma3n",
        image_token="<image_soft_token>",
        audio_token="<audio_soft_token>",
    ),
    template_class=Llama2Template,
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="glm4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
|
||||
Reference in New Issue
Block a user