simplify readme

Former-commit-id: 0da6ec2d516326fe9c7583ba71cd1778eb838178
This commit is contained in:
hiyouga
2024-04-02 20:07:43 +08:00
parent 117b67ea30
commit b12176d818
24 changed files with 244 additions and 890 deletions

View File

@@ -6,6 +6,7 @@ from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import is_path_available
from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
@@ -122,11 +123,12 @@ def get_dataset(
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
# Load tokenized dataset
if data_args.tokenized_path is not None:
if not is_path_available(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
@@ -158,10 +160,13 @@ def get_dataset(
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
exit(0)
if training_args.should_log:
try: