optimize data loading logic
Former-commit-id: 58f669b384582ac90e85de835f1f44f7003f9ec0
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
||||
|
||||
from llmtuner.data.utils import checksum
|
||||
from llmtuner.extras.constants import FILEEXT2TYPE
|
||||
@@ -22,6 +22,13 @@ def get_dataset(
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
||||
|
||||
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.cache_path)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user