1. add custom eval dataset support

2. merge the dataset loading and dataset splitting functions


Former-commit-id: 963d97ba07e7efa3a4544c4d077283d9e112b3ad
Author: codingma
Date:   2024-07-05 15:52:10 +08:00
Parent: 9a1a5f9778
Commit: 5f2bd04799

15 changed files with 93 additions and 42 deletions
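
Only one of the changed files (the dataset loader) is shown below. For context, a minimal sketch of the run arguments the new code paths read; the dataset names and the bare dict are placeholders, and only the `dataset`, `eval_dataset`, and `eval_tokenized_path` keys correspond to arguments referenced in this diff:

# Hypothetical argument sketch; "my_train_set" / "my_eval_set" are placeholder names.
# The loader changes below consult two new data arguments:
#   data_args.eval_dataset        -> dataset(s) used only for evaluation
#   data_args.eval_tokenized_path -> where to save / load the pre-tokenized eval split
args = {
    "stage": "sft",
    "dataset": "my_train_set",                        # training data, unchanged behavior
    "eval_dataset": "my_eval_set",                    # new: custom eval dataset
    # "eval_tokenized_path": "saves/eval_tokenized",  # new: cached tokenized eval data
}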


@@ -15,7 +15,7 @@
 import inspect
 import os
 import sys
-from typing import TYPE_CHECKING, Literal, Optional, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union, Dict
 
 import numpy as np
 from datasets import load_dataset, load_from_disk
@@ -24,10 +24,10 @@ from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
-from .data_utils import merge_dataset
+from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer
+from .template import get_template_and_fix_tokenizer, Template
 
 
 if TYPE_CHECKING:
@@ -91,7 +91,7 @@ def load_single_dataset(
             subset_name=data_name,
             data_dir=data_dir,
             data_files=data_files,
-            split=data_args.split,
+            split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
             use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
@@ -111,7 +111,7 @@ def load_single_dataset(
             name=data_name,
             data_dir=data_dir,
             data_files=data_files,
-            split=data_args.split,
+            split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
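
Both hub branches above now take `split` from the per-dataset attribute instead of the single global `data_args.split`, so each registered dataset can pin its own split. A small illustrative sketch, assuming the usual registry-style entry (the field names other than "split" are placeholders):

# Assumed registry entry shape; the point is only that "split" now travels with
# the dataset definition and ends up in dataset_attr.split.
dataset_info_entry = {
    "hf_hub_url": "org/my_eval_set",  # placeholder dataset location
    "split": "validation",            # picked up per dataset instead of globally
}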
@@ -140,20 +140,17 @@ def load_single_dataset(
     return align_dataset(dataset, dataset_attr, data_args, training_args)
 
 
-def get_dataset(
+def load_and_preprocess(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
+    template: "Template",
     processor: Optional["ProcessorMixin"] = None,
+    is_eval: bool = False
 ) -> Union["Dataset", "IterableDataset"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
-    # Load tokenized dataset
-    if data_args.tokenized_path is not None:
+    if not is_eval and data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning("Loading dataset from disk will ignore other data arguments.")
             dataset = load_from_disk(data_args.tokenized_path)
@@ -165,9 +162,21 @@ def get_dataset(
         if data_args.streaming:
             raise ValueError("Turn off `streaming` when saving dataset to disk.")
 
+    if is_eval and data_args.eval_tokenized_path is not None:
+        if has_tokenized_data(data_args.eval_tokenized_path):
+            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            dataset = load_from_disk(data_args.eval_tokenized_path)
+            logger.info("Loaded tokenized dataset from {}.".format(data_args.eval_tokenized_path))
+            if data_args.streaming:
+                dataset = dataset.to_iterable_dataset()
+            return dataset
+
+        if data_args.streaming:
+            raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
-        for dataset_attr in get_dataset_list(data_args):
+        for dataset_attr in get_dataset_list(data_args, data_args.eval_dataset if is_eval else data_args.dataset):
             if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
                 raise ValueError("The dataset is not applicable in the current training stage.")
@@ -190,13 +199,20 @@ def get_dataset(
         dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
 
-        if data_args.tokenized_path is not None:
+        if not is_eval and data_args.tokenized_path is not None:
             if training_args.should_save:
                 dataset.save_to_disk(data_args.tokenized_path)
                 logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
                 logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
 
             sys.exit(0)
 
+        if is_eval and data_args.eval_tokenized_path is not None:
+            if training_args.should_save:
+                dataset.save_to_disk(data_args.eval_tokenized_path)
+                logger.info("Tokenized dataset saved at {}.".format(data_args.eval_tokenized_path))
+                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.eval_tokenized_path))
+
+            sys.exit(0)
+
         if training_args.should_log:
             try:
@@ -208,3 +224,24 @@ def get_dataset(
                 raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
 
         return dataset
+
+
+def get_dataset(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"] = None
+) -> Dict[str, "Dataset"]:
+    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    train_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor)
+
+    if data_args.eval_dataset or data_args.eval_tokenized_path:
+        eval_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor, True)
+        return {"train_dataset": train_dataset, "eval_dataset": eval_dataset}
+    else:
+        return split_dataset(train_dataset, data_args, training_args)
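
Downstream, the mapping returned by the new `get_dataset` can be expanded straight into the trainer, whether the eval split came from a custom `eval_dataset`, an `eval_tokenized_path`, or the `split_dataset` fallback. A minimal caller sketch (assumed consumer code, not part of this diff; `model`, `tokenizer`, and the argument objects are presumed to be in scope, and the trainer class is a stand-in):

from transformers import Seq2SeqTrainer  # stand-in; the project wires its own trainer subclass

# get_dataset() now returns {"train_dataset": ...} and, when an eval split
# exists, an "eval_dataset" entry as well.
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft",
                             tokenizer=tokenizer, processor=processor)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    **dataset_module,  # unpacks train_dataset / eval_dataset keyword arguments
)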