Refactor multimodal (mm) training

Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
hiyouga
2024-08-30 02:14:31 +08:00
parent 77c2c7076b
commit c62a6ca59d
29 changed files with 499 additions and 312 deletions

View File

@@ -15,9 +15,11 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import BasePlugin, get_mm_plugin
if TYPE_CHECKING:
@@ -41,11 +43,9 @@ class Template:
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
vision_start_token: str
vision_end_token: str
efficient_eos: bool
replace_eos: bool
mm_plugin: "BasePlugin"
def encode_oneturn(
self,
@@ -207,11 +207,9 @@ def _register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Sequence[str] = [],
image_token: str = "<image>",
vision_start_token: str = "<|vision_start|>",
vision_end_token: str = "<|vision_end|>",
efficient_eos: bool = False,
replace_eos: bool = False,
mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
) -> None:
r"""
Registers a chat template.
@@ -258,11 +256,9 @@ def _register_template(
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
vision_start_token=vision_start_token,
vision_end_token=vision_end_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
mm_plugin=mm_plugin,
)
@@ -722,6 +718,17 @@ _register_template(
)
# "llava" chat template: each user turn is rendered as "USER: ... ASSISTANT:"
# behind a fixed default system prompt. Multimodal handling is delegated to
# the "llava" plugin, configured with "<image>" as its image token.
_register_template(
    name="llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    # Two adjacent literals are concatenated into one system prompt string.
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -766,6 +773,19 @@ _register_template(
)
# "paligemma" chat template: user and tool turns are wrapped in
# <start_of_turn>/<end_of_turn> markers and each ends by opening a "model"
# turn. Multimodal handling is delegated to the "paligemma" plugin,
# configured with "<image>" as its image token.
_register_template(
    name="paligemma",
    format_user=StringFormatter(
        slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
    # NOTE(review): the slot is a set literal, not a string; presumably the
    # formatter resolves it to the tokenizer's BOS token -- confirm.
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -790,17 +810,15 @@ _register_template(
# "qwen2_vl" chat template: system, user and tool turns are wrapped in
# <|im_start|>/<|im_end|> markers, and user/tool turns end by opening an
# "assistant" turn.
#
# The template is registered under a single name, "qwen2_vl", matching the
# plugin name; passing `name` twice would be a SyntaxError. The image
# placeholder ("<|image_pad|>") is supplied only through the multimodal
# plugin rather than through separate image/vision token keyword arguments.
_register_template(
    name="qwen2_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
)
@@ -915,6 +933,7 @@ _register_template(
),
stop_words=["###"],
efficient_eos=True,
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)