[data] shard the dataset to allow multiprocessing when streaming is enabled (#7530)

* Shard the dataset when streaming to allow multiprocessing

* Allow user to not set dataset_shards to ensure backward compatibility
This commit is contained in:
Billy Cao
2025-04-01 15:36:23 +08:00
committed by GitHub
parent d70b3b4bc5
commit 00409ff28a
4 changed files with 12 additions and 4 deletions

View File

@@ -101,10 +101,12 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=data_args.streaming,
use_streaming=data_args.streaming and not data_args.dataset_shards, # only set to True when user specified streaming but do not want dataset to be sharded
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
if data_args.streaming and data_args.dataset_shards:
dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
elif dataset_attr.load_from == "om_hub":
check_version("openmind>=0.8.0", mandatory=True)
@@ -131,10 +133,12 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
streaming=data_args.streaming and not data_args.dataset_shards,
)
if data_args.streaming and data_args.dataset_shards:
dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples

View File

@@ -83,6 +83,10 @@ class DataArguments:
default=None,
metadata={"help": "The number of processes to use for the pre-processing."},
)
dataset_shards: Optional[int] = field(
default=None,
metadata={"help": "The number of shards to split the dataset into. Only used in streaming mode. This should be set to the same as dataloader_num_workers. Not setting this while streaming data will cause the dataset to be non-sharded and thus only can be processed using one worker."},
)
max_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},