update get template
Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
src/llamafactory/data/loader.py
@@ -14,7 +14,7 @@
 
 import os
 import sys
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
@@ -27,7 +27,6 @@ from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer
 
 
 if TYPE_CHECKING:
@@ -179,9 +178,6 @@ def _get_preprocessed_dataset(
             load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Running tokenizer on dataset",
         )
-        if data_args.dataset_map_batch_size:
-            # Set the batch size conditionally without considering the default variable of the batch size in the map function
-            kwargs.update(batch_size=data_args.dataset_map_batch_size)
 
     dataset = dataset.map(
         preprocess_func,
@@ -205,17 +201,14 @@ def _get_preprocessed_dataset(
 
 
 def get_dataset(
+    template: "Template",
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
-) -> Tuple["DatasetModule", "Template"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
+) -> "DatasetModule":
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
@@ -233,7 +226,7 @@ def get_dataset(
             if data_args.streaming:
                 dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
 
-            return dataset_module, template
+            return dataset_module
 
         if data_args.streaming:
             raise ValueError("Turn off `streaming` when saving dataset to disk.")
@@ -280,7 +273,8 @@ def get_dataset(
         dataset_module = {}
         if "train" in dataset_dict:
             dataset_module["train_dataset"] = dataset_dict["train"]
+
         if "validation" in dataset_dict:
             dataset_module["eval_dataset"] = dataset_dict["validation"]
 
-        return dataset_module, template
+        return dataset_module
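
Note: with this change, `get_dataset` no longer resolves the template itself or returns it; callers build the template once up front and pass it in. A minimal sketch of the updated call site (the argument objects are assumed to come from the usual hparams parsing, and the module paths follow the file layout above):

    from llamafactory.data.loader import get_dataset
    from llamafactory.data.template import get_template_and_fix_tokenizer

    # Resolve the template once; this also patches the tokenizer
    # (stop words, EOS replacement) as the template requires.
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    # The template is now an explicit input, and only the dataset module is returned.
    dataset_module = get_dataset(
        template, model_args, data_args, training_args, stage="sft", tokenizer=tokenizer
    )
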
src/llamafactory/data/template.py
@@ -27,6 +27,7 @@ from .mm_plugin import get_mm_plugin
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
+    from ..hparams import DataArguments
     from .formatter import SLOTS, Formatter
     from .mm_plugin import BasePlugin
 
@@ -344,28 +345,27 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template
 
 
-def get_template_and_fix_tokenizer(
-    tokenizer: "PreTrainedTokenizer",
-    name: Optional[str] = None,
-    tool_format: Optional[str] = None,
-) -> Template:
-    if name in ["llava", "paligemma", "qwen2_vl"]:
+def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
         require_version(
             "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
         )
 
-    if name is None:
+    if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(name, None)
+        template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(name))
+            raise ValueError("Template {} does not exist.".format(data_args.template))
 
-    if tool_format is not None:
-        logger.info("Using tool format: {}.".format(tool_format))
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    if data_args.tool_format is not None:
+        logger.info("Using tool format: {}.".format(data_args.tool_format))
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
-        template.format_tools = ToolFormatter(tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
     stop_words = template.stop_words
     if template.replace_eos:
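
Note: `get_template_and_fix_tokenizer` now takes the whole `DataArguments` object instead of separate `name` and `tool_format` parameters, so the `train_on_prompt` guard runs at template-resolution time rather than inside `get_dataset`. A hedged sketch of calling the new signature (the model name and field values are illustrative only; `template`, `tool_format`, and `train_on_prompt` are the `DataArguments` fields used in the diff above):

    from transformers import AutoTokenizer

    from llamafactory.data.template import get_template_and_fix_tokenizer
    from llamafactory.hparams import DataArguments

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # illustrative model
    data_args = DataArguments(template="qwen", tool_format=None, train_on_prompt=False)

    # Raises ValueError for an unknown template name, or when train_on_prompt
    # is combined with a template that uses efficient_eos.
    template = get_template_and_fix_tokenizer(tokenizer, data_args)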