[breaking change] refactor data pipeline (#6901)
* refactor data
* rename file

Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b
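The refactor replaces the (preprocess_func, print_function) pair returned by
get_preprocess_and_print_func with stage-specific processor classes that carry
both behaviors as methods. Below is a minimal Python sketch of the interface the
diff implies; the base-class shape is an assumption for illustration, not the
actual LlamaFactory implementation.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DatasetProcessor:  # hypothetical minimal protocol inferred from the diff
    template: Any
    tokenizer: Any
    processor: Optional[Any]
    data_args: Any

    def preprocess_dataset(self, examples: dict) -> dict:
        # Batched map function: consumes raw columns, emits tokenized columns.
        raise NotImplementedError

    def print_data_example(self, example: dict) -> None:
        # Pretty-prints a single tokenized sample for logging.
        raise NotImplementedError

Each stage-specific subclass named in the diff (PretrainDatasetProcessor,
SupervisedDatasetProcessor, PairwiseDatasetProcessor, FeedbackDatasetProcessor,
UnsupervisedDatasetProcessor) overrides these two methods.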
@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
-from .aligner import align_dataset
+from .converter import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
-from .preprocess import get_preprocess_and_print_func
+from .processor import (
+    FeedbackDatasetProcessor,
+    PackedSupervisedDatasetProcessor,
+    PairwiseDatasetProcessor,
+    PretrainDatasetProcessor,
+    SupervisedDatasetProcessor,
+    UnsupervisedDatasetProcessor,
+)


 if TYPE_CHECKING:
@@ -35,6 +42,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, ModelArguments
     from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .processor import DatasetProcessor
     from .template import Template


@@ -158,7 +166,7 @@ def _get_merged_dataset(
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
     r"""
-    Gets the merged datasets in the standard format.
+    Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None
@@ -173,6 +181,48 @@ def _get_merged_dataset(
     return merge_dataset(datasets, data_args, seed=training_args.seed)


+def _get_dataset_processor(
+    data_args: "DataArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
+) -> "DatasetProcessor":
+    r"""
+    Returns the corresponding dataset processor.
+    """
+    if stage == "pt":
+        dataset_processor_class = PretrainDatasetProcessor
+    elif stage == "sft" and not do_generate:
+        if data_args.packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
+                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+                def __init__(self, data, **kwargs):
+                    return TypedSequence.__init__(
+                        self,
+                        data,
+                        type=kwargs.pop("type", None),
+                        try_type=kwargs.pop("try_type", None),
+                        optimized_int_type=kwargs.pop("optimized_int_type", None),
+                    )
+
+                OptimizedTypedSequence.__init__ = __init__
+
+            dataset_processor_class = PackedSupervisedDatasetProcessor
+        else:
+            dataset_processor_class = SupervisedDatasetProcessor
+    elif stage == "rm":
+        dataset_processor_class = PairwiseDatasetProcessor
+    elif stage == "kto":
+        dataset_processor_class = FeedbackDatasetProcessor
+    else:
+        dataset_processor_class = UnsupervisedDatasetProcessor
+
+    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
+
+
 def _get_preprocessed_dataset(
     dataset: Optional[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
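A note on the neat_packing branch above: the datasets library's
OptimizedTypedSequence downcasts well-known columns by name, so an
"attention_mask" column is stored with a narrow optimized type (bool or int8,
depending on the datasets version). Neat packing reuses the attention mask to
hold per-sequence indices (1, 2, 3, ...), which that narrow type cannot
represent, so the diff monkey-patches __init__ to drop the column-name-based
optimization and keep the wider inferred integer type. A small Python sketch of
the stock behavior the patch works around, assuming only that the datasets
library is installed:

from datasets import Dataset

# Stock behavior: the Arrow writer recognizes the name "attention_mask" and
# stores it with a narrow optimized type instead of the inferred int64.
ds = Dataset.from_dict({"attention_mask": [[1, 1, 0]]})
print(ds.features["attention_mask"])

# A column with any other name skips that optimization and keeps the inferred
# width, which is what the patched __init__ achieves for every column.
ds = Dataset.from_dict({"packed_mask": [[1, 1, 2, 2, 3]]})
print(ds.features["packed_mask"])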
@@ -189,7 +239,7 @@ def _get_preprocessed_dataset(
     if dataset is None:
         return None

-    preprocess_func, print_function = get_preprocess_and_print_func(
+    dataset_processor = _get_dataset_processor(
         data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
     )
     column_names = list(next(iter(dataset)).keys())
@@ -202,7 +252,7 @@ def _get_preprocessed_dataset(
     )

     dataset = dataset.map(
-        preprocess_func,
+        dataset_processor.preprocess_dataset,
         batched=True,
         batch_size=data_args.preprocessing_batch_size,
         remove_columns=column_names,
@@ -212,7 +262,7 @@ def _get_preprocessed_dataset(
     if training_args.should_log:
         try:
             print("eval example:" if is_eval else "training example:")
-            print_function(next(iter(dataset)))
+            dataset_processor.print_data_example(next(iter(dataset)))
         except StopIteration:
             if stage == "pt":
                 raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
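Taken together, the call pattern after this refactor is roughly the following;
template, tokenizer, and the *_args objects stand in for the real LlamaFactory
objects, so treat this as a sketch of the control flow rather than runnable
user code:

dataset_processor = _get_dataset_processor(
    data_args, stage, template, tokenizer, processor,
    do_generate=(training_args.predict_with_generate and is_eval),
)
dataset = dataset.map(
    dataset_processor.preprocess_dataset,
    batched=True,
    batch_size=data_args.preprocessing_batch_size,
    remove_columns=list(next(iter(dataset)).keys()),
)
dataset_processor.print_data_example(next(iter(dataset)))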