Former-commit-id: 05a3be4853b941909e7d193c31e8d62c8c5f879b
hiyouga
2024-06-13 02:48:21 +08:00
parent 103a507b39
commit 49b58fd6af
9 changed files with 19 additions and 19 deletions


@@ -18,8 +18,7 @@ from .template import get_template_and_fix_tokenizer
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments, ModelArguments
 
     from .parser import DatasetAttr
@@ -32,6 +31,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +123,7 @@ def load_single_dataset(
         max_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(max_samples))
 
-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)
 
 
 def get_dataset(
@@ -157,7 +157,8 @@ def get_dataset(
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +170,7 @@ def get_dataset(
         if not data_args.streaming:
             kwargs = dict(
                 num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                 desc="Running tokenizer on dataset",
             )
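
The last hunk is the behavioral core of the change, and threading training_args through load_single_dataset and align_dataset in the earlier hunks appears to exist mainly to make local_process_index available where the map() kwargs are built. Below is a minimal sketch of the pattern, not the repository's code: tokenize_fn and overwrite_cache are hypothetical stand-ins for the real preprocessing function and data_args.overwrite_cache. Inside main_process_first(), local rank 0 runs Dataset.map first and writes the Arrow cache; the (training_args.local_process_index != 0) term then forces every other rank to load that cache instead of re-tokenizing, even when overwrite_cache is set.

# A minimal sketch of the main-process-first caching pattern; not the
# repository's code. `tokenize_fn` and `overwrite_cache` are hypothetical
# stand-ins for the real preprocessing function and data_args.overwrite_cache.
from datasets import Dataset
from transformers import Seq2SeqTrainingArguments


def preprocess(
    dataset: "Dataset",
    training_args: "Seq2SeqTrainingArguments",
    tokenize_fn,
    overwrite_cache: bool = False,
) -> "Dataset":
    # Rank 0 enters the block first and writes the Arrow cache; the other
    # ranks enter only after it exits, so the cache is guaranteed to exist.
    with training_args.main_process_first(desc="pre-process dataset"):
        return dataset.map(
            tokenize_fn,
            batched=True,
            # Rank 0 honors overwrite_cache; non-main ranks always reuse
            # the cache rank 0 just wrote instead of re-tokenizing.
            load_from_cache_file=(not overwrite_cache)
            or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )

Without the extra term, setting overwrite_cache on a multi-GPU run would make every local rank rebuild the cache, duplicating work and risking concurrent writes to the same cache files.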