update get template

Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
This commit is contained in:
hiyouga
2024-09-04 22:36:20 +08:00
parent 5d85be31ca
commit af178cbcd1
17 changed files with 57 additions and 56 deletions

View File

@@ -14,7 +14,7 @@
import os
import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
@@ -27,7 +27,6 @@ from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
@@ -179,9 +178,6 @@ def _get_preprocessed_dataset(
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Running tokenizer on dataset",
)
if data_args.dataset_map_batch_size:
# Set the batch size conditionally without considering the default variable of the batch size in the map function
kwargs.update(batch_size=data_args.dataset_map_batch_size)
dataset = dataset.map(
preprocess_func,
@@ -205,17 +201,14 @@ def _get_preprocessed_dataset(
def get_dataset(
template: "Template",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Tuple["DatasetModule", "Template"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
) -> "DatasetModule":
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
@@ -233,7 +226,7 @@ def get_dataset(
if data_args.streaming:
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
return dataset_module, template
return dataset_module
if data_args.streaming:
raise ValueError("Turn off `streaming` when saving dataset to disk.")
@@ -280,7 +273,8 @@ def get_dataset(
dataset_module = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
return dataset_module, template
return dataset_module