[data] shard the dataset to allow multiprocessing when streaming is enabled (#7530)
* Shard the dataset when streaming to allow multiprocessing * Allow user to not set dataset_shards to ensure backward compatibility
This commit is contained in:
@@ -101,10 +101,12 @@ def _load_single_dataset(
|
||||
split=dataset_attr.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=data_args.streaming,
|
||||
use_streaming=data_args.streaming and not data_args.dataset_shards, # only set to True when user specified streaming but do not want dataset to be sharded
|
||||
)
|
||||
if isinstance(dataset, MsDataset):
|
||||
dataset = dataset.to_hf_dataset()
|
||||
if data_args.streaming and data_args.dataset_shards:
|
||||
dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
|
||||
|
||||
elif dataset_attr.load_from == "om_hub":
|
||||
check_version("openmind>=0.8.0", mandatory=True)
|
||||
@@ -131,10 +133,12 @@ def _load_single_dataset(
|
||||
split=dataset_attr.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=data_args.streaming,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
streaming=data_args.streaming and not data_args.dataset_shards,
|
||||
)
|
||||
if data_args.streaming and data_args.dataset_shards:
|
||||
dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
|
||||
Reference in New Issue
Block a user