support val set in streaming mode

Former-commit-id: faed15b58ed00b1e09bb091e7eee48f5ef7c508b
This commit is contained in:
hiyouga
2023-08-09 23:00:26 +08:00
parent 972bfa700a
commit 467d571206
10 changed files with 58 additions and 50 deletions

View File

@@ -104,9 +104,9 @@ def preprocess_dataset(
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(accept_ids) > data_args.max_target_length:
accept_ids = accept_ids[:data_args.max_target_length - 1]
accept_ids = accept_ids[:data_args.max_target_length]
if len(reject_ids) > data_args.max_target_length:
reject_ids = reject_ids[:data_args.max_target_length - 1]
reject_ids = reject_ids[:data_args.max_target_length]
accept_ids = source_ids + accept_ids
reject_ids = source_ids + reject_ids
@@ -166,8 +166,5 @@ def preprocess_dataset(
**kwargs
)
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
print_function(next(iter(dataset)))
return dataset

View File

@@ -1,15 +1,29 @@
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Union
if TYPE_CHECKING:
from datasets import Dataset
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from llmtuner.hparams import DataArguments
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
if do_train:
if dev_ratio > 1e-6: # Split the dataset
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",
training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
dataset = dataset.train_test_split(test_size=data_args.val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}