Former-commit-id: 05a3be4853b941909e7d193c31e8d62c8c5f879b
This commit is contained in:
hiyouga
2024-06-13 02:48:21 +08:00
parent 103a507b39
commit 49b58fd6af
9 changed files with 19 additions and 19 deletions

View File

@@ -10,6 +10,7 @@ from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
+from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .parser import DatasetAttr
@@ -175,7 +176,10 @@ def convert_sharegpt(
def align_dataset(
-dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+dataset: Union["Dataset", "IterableDataset"],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
@@ -208,7 +212,7 @@ def align_dataset(
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
-load_from_cache_file=(not data_args.overwrite_cache),
+load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)