[breaking change] refactor data pipeline (#6901)

* refactor data

* rename file

Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b
hoshi-hiyouga authored on 2025-02-13 00:39:20 +08:00, committed by GitHub
parent 80b89978d9
commit 46203856fc
27 changed files with 1145 additions and 1132 deletions


@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
-from .aligner import align_dataset
+from .converter import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
-from .preprocess import get_preprocess_and_print_func
+from .processor import (
+    FeedbackDatasetProcessor,
+    PackedSupervisedDatasetProcessor,
+    PairwiseDatasetProcessor,
+    PretrainDatasetProcessor,
+    SupervisedDatasetProcessor,
+    UnsupervisedDatasetProcessor,
+)
 
 
 if TYPE_CHECKING:
@@ -35,6 +42,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, ModelArguments
     from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .processor import DatasetProcessor
     from .template import Template
@@ -158,7 +166,7 @@ def _get_merged_dataset(
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
     r"""
-    Gets the merged datasets in the standard format.
+    Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None
@@ -173,6 +181,48 @@ def _get_merged_dataset(
     return merge_dataset(datasets, data_args, seed=training_args.seed)
 
 
+def _get_dataset_processor(
+    data_args: "DataArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
+) -> "DatasetProcessor":
+    r"""
+    Returns the corresponding dataset processor.
+    """
+    if stage == "pt":
+        dataset_processor_class = PretrainDatasetProcessor
+    elif stage == "sft" and not do_generate:
+        if data_args.packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
+                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+                def __init__(self, data, **kwargs):
+                    return TypedSequence.__init__(
+                        self,
+                        data,
+                        type=kwargs.pop("type", None),
+                        try_type=kwargs.pop("try_type", None),
+                        optimized_int_type=kwargs.pop("optimized_int_type", None),
+                    )
+
+                OptimizedTypedSequence.__init__ = __init__
+
+            dataset_processor_class = PackedSupervisedDatasetProcessor
+        else:
+            dataset_processor_class = SupervisedDatasetProcessor
+    elif stage == "rm":
+        dataset_processor_class = PairwiseDatasetProcessor
+    elif stage == "kto":
+        dataset_processor_class = FeedbackDatasetProcessor
+    else:
+        dataset_processor_class = UnsupervisedDatasetProcessor
+
+    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
+
+
 def _get_preprocessed_dataset(
     dataset: Optional[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
@@ -189,7 +239,7 @@ def _get_preprocessed_dataset(
     if dataset is None:
         return None
 
-    preprocess_func, print_function = get_preprocess_and_print_func(
+    dataset_processor = _get_dataset_processor(
         data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
     )
     column_names = list(next(iter(dataset)).keys())
@@ -202,7 +252,7 @@ def _get_preprocessed_dataset(
         )
 
     dataset = dataset.map(
-        preprocess_func,
+        dataset_processor.preprocess_dataset,
         batched=True,
         batch_size=data_args.preprocessing_batch_size,
         remove_columns=column_names,
@@ -212,7 +262,7 @@ def _get_preprocessed_dataset(
     if training_args.should_log:
         try:
             print("eval example:" if is_eval else "training example:")
-            print_function(next(iter(dataset)))
+            dataset_processor.print_data_example(next(iter(dataset)))
         except StopIteration:
             if stage == "pt":
                 raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
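
Taken together, these hunks replace the old `preprocess_func` / `print_function` pair from `get_preprocess_and_print_func` with a single processor object. A minimal sketch of the resulting flow inside `_get_preprocessed_dataset`, for illustration only (it assumes the names introduced above plus a `dataset`, `data_args`, `template`, and `tokenizer` already in scope; it is not code from this commit):

dataset_processor = _get_dataset_processor(
    data_args, "sft", template, tokenizer, processor=None, do_generate=False
)
column_names = list(next(iter(dataset)).keys())
dataset = dataset.map(
    dataset_processor.preprocess_dataset,  # replaces the removed preprocess_func
    batched=True,
    batch_size=data_args.preprocessing_batch_size,
    remove_columns=column_names,
)
dataset_processor.print_data_example(next(iter(dataset)))  # replaces the removed print_function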