Fix mixed multimodal (mm) inputs and RLHF-V

Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
This commit is contained in:
hiyouga
2024-09-01 20:52:47 +08:00
parent 1d8e9c7897
commit 7e4c5d4bb3
20 changed files with 306 additions and 277 deletions

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from .data_utils import Role
@@ -347,6 +349,11 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
tool_format: Optional[str] = None,
) -> Template:
if name == "qwen2_vl":
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
if name is None:
template = TEMPLATES["empty"] # placeholder
else:
@@ -357,8 +364,8 @@ def get_template_and_fix_tokenizer(
if tool_format is not None:
logger.info("Using tool format: {}.".format(tool_format))
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_tools = ToolFormatter(tool_format=tool_format)
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
template.format_tools = ToolFormatter(tool_format=tool_format)
stop_words = template.stop_words
if template.replace_eos: