[model] add llama4 (#7611)
@@ -466,6 +466,73 @@ class Gemma3Plugin(BasePlugin):
        return mm_inputs


@dataclass
class Llama4Plugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
                num_patches_per_chunk = int(
                    (image_height // processor.patch_size)
                    * (image_width // processor.patch_size)
                    // processor.downsample_ratio
                )
                aspect_ratios = mm_inputs.pop("aspect_ratios")

        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            placeholder_count = content.count(IMAGE_PLACEHOLDER)
            if self.expand_mm_tokens:
                prompt_splits = content.split(IMAGE_PLACEHOLDER)
                new_content = []
                for local_image_index, split_part in enumerate(prompt_splits):
                    new_content.append(split_part)
                    if local_image_index < placeholder_count:
                        tokens_for_this_image = processor._prompt_split_image(
                            aspect_ratios[num_image_tokens], num_patches_per_chunk
                        )
                        num_image_tokens += 1
                        new_content.append(tokens_for_this_image)

                content = "".join(new_content)

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("aspect_ratios", None)
        return mm_inputs


@dataclass
class LlavaPlugin(BasePlugin):
    @override
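Note on the hunk above: num_patches_per_chunk is the number of expansion tokens each image chunk contributes when the <|image|> placeholder is rewritten. A minimal sketch of that arithmetic follows; the concrete values (chunk size 336, patch size 14, downsample ratio 4) are assumptions for illustration only, since the real ones come from the Hugging Face Llama 4 processor config.

    # Sketch with assumed values; in the plugin these come from `processor`.
    image_height = image_width = 336  # assumed chunk resolution
    patch_size = 14                   # assumed ViT patch size
    downsample_ratio = 4              # assumed pixel-shuffle downsampling

    num_patches_per_chunk = int(
        (image_height // patch_size) * (image_width // patch_size) // downsample_ratio
    )
    print(num_patches_per_chunk)  # (336 // 14) ** 2 // 4 = 576 // 4 = 144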
@@ -1485,6 +1552,7 @@ class VideoLlavaPlugin(BasePlugin):

PLUGINS = {
    "base": BasePlugin,
    "gemma3": Gemma3Plugin,
    "llama4": Llama4Plugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
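For context on how this registry is consumed: get_mm_plugin(name=...) in the template hunk below resolves the name through this dict. A minimal sketch of that lookup, assuming the helper simply instantiates the registered class with its placeholder tokens (PLUGINS and BasePlugin are the module's own names; the exact signature is an assumption):

    from typing import Optional

    def get_mm_plugin(
        name: str,
        image_token: Optional[str] = None,
        video_token: Optional[str] = None,
        audio_token: Optional[str] = None,
    ) -> "BasePlugin":
        # Resolve the plugin class registered above and instantiate it.
        if name not in PLUGINS:
            raise ValueError(f"Multimodal plugin `{name}` not found.")
        return PLUGINS[name](image_token, video_token, audio_token)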
@@ -968,6 +968,26 @@ register_template(
)


register_template(
    name="llama4",
    format_user=StringFormatter(
        slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
    format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            "<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot|>", "<|eom|>"],
    mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
)


# copied from llama3 template
register_template(
    name="mllama",
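To illustrate the template registered above: a single system + user turn renders roughly as below, substituting {{content}} by hand. The slot strings are taken verbatim from the registration; the bos_token value is an assumption based on the Llama tokenizer family and is read from the tokenizer in practice.

    bos_token = "<|begin_of_text|>"  # assumed; supplied by the tokenizer at runtime
    system = "<|header_start|>system<|header_end|>\n\nYou are a helpful assistant.<|eot|>"
    user = (
        "<|header_start|>user<|header_end|>\n\nHello!<|eot|>"
        "<|header_start|>assistant<|header_end|>\n\n"
    )
    prompt = bos_token + system + user
    # Decoding stops at "<|eot|>" or "<|eom|>", matching stop_words above.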