support ms dataset
Former-commit-id: 98638b35dc24045ac17b9b01d08d3a02372acef3
This commit is contained in:
@@ -24,7 +24,7 @@ def get_dataset(
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
if dataset_attr.load_from in ("hf_hub", "ms_hub"):
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
@@ -53,15 +53,22 @@ def get_dataset(
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
||||
)
|
||||
if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and dataset_attr.load_from == "ms_hub":
|
||||
from modelscope import MsDataset
|
||||
dataset = MsDataset.load(
|
||||
dataset_name=data_path,
|
||||
subset_name=data_name,
|
||||
).to_hf_dataset()
|
||||
else:
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"):
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
Reference in New Issue
Block a user