update get template

Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
This commit is contained in:
hiyouga
2024-09-04 22:36:20 +08:00
parent 5d85be31ca
commit af178cbcd1
17 changed files with 57 additions and 56 deletions

View File

@@ -27,6 +27,7 @@ from .mm_plugin import get_mm_plugin
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from ..hparams import DataArguments
from .formatter import SLOTS, Formatter
from .mm_plugin import BasePlugin
@@ -344,28 +345,27 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
return jinja_template
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
tool_format: Optional[str] = None,
) -> Template:
if name in ["llava", "paligemma", "qwen2_vl"]:
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
if name is None:
if data_args.template is None:
template = TEMPLATES["empty"] # placeholder
else:
template = TEMPLATES.get(name, None)
template = TEMPLATES.get(data_args.template, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
raise ValueError("Template {} does not exist.".format(data_args.template))
if tool_format is not None:
logger.info("Using tool format: {}.".format(tool_format))
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None:
logger.info("Using tool format: {}.".format(data_args.tool_format))
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
template.format_tools = ToolFormatter(tool_format=tool_format)
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words
if template.replace_eos: