refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
@@ -15,9 +15,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .mm_plugin import BasePlugin, get_mm_plugin
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -41,11 +43,9 @@ class Template:
|
||||
format_prefix: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
image_token: str
|
||||
vision_start_token: str
|
||||
vision_end_token: str
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
mm_plugin: "BasePlugin"
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
@@ -207,11 +207,9 @@ def _register_template(
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: Sequence[str] = [],
|
||||
image_token: str = "<image>",
|
||||
vision_start_token: str = "<|vision_start|>",
|
||||
vision_end_token: str = "<|vision_end|>",
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
|
||||
) -> None:
|
||||
r"""
|
||||
Registers a chat template.
|
||||
@@ -258,11 +256,9 @@ def _register_template(
|
||||
format_prefix=format_prefix or default_prefix_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
image_token=image_token,
|
||||
vision_start_token=vision_start_token,
|
||||
vision_end_token=vision_end_token,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
mm_plugin=mm_plugin,
|
||||
)
|
||||
|
||||
|
||||
@@ -722,6 +718,17 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
@@ -766,6 +773,19 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="paligemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="phi",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
@@ -790,17 +810,15 @@ _register_template(
|
||||
|
||||
|
||||
_register_template(
|
||||
name="qwen2vl",
|
||||
name="qwen2_vl",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
image_token="<|image_pad|>",
|
||||
vision_start_token="<|vision_start|>",
|
||||
vision_end_token="<|vision_end|>",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
|
||||
)
|
||||
|
||||
|
||||
@@ -915,6 +933,7 @@ _register_template(
|
||||
),
|
||||
stop_words=["###"],
|
||||
efficient_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user