Former-commit-id: 4e1af5a5d39d9e2f374c1372e2d67120c63fea09
This commit is contained in:
hiyouga
2023-12-09 20:53:18 +08:00
parent 2e6ed731cf
commit f3ffa8310f
4 changed files with 18 additions and 13 deletions

View File

@@ -24,27 +24,27 @@ def get_dataset(
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from == "hf_hub":
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_files = None
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_files = None
elif dataset_attr.load_from == "file":
data_path, data_name = None, None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
@@ -56,6 +56,7 @@ def get_dataset(
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,