support val set in streaming mode
Former-commit-id: faed15b58ed00b1e09bb091e7eee48f5ef7c508b
This commit is contained in:
@@ -104,9 +104,9 @@ def preprocess_dataset(
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(accept_ids) > data_args.max_target_length:
|
||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
||||
accept_ids = accept_ids[:data_args.max_target_length]
|
||||
if len(reject_ids) > data_args.max_target_length:
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
reject_ids = reject_ids[:data_args.max_target_length]
|
||||
|
||||
accept_ids = source_ids + accept_ids
|
||||
reject_ids = source_ids + reject_ids
|
||||
@@ -166,8 +166,5 @@ def preprocess_dataset(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
|
||||
|
||||
print_function(next(iter(dataset)))
|
||||
return dataset
|
||||
|
||||
@@ -1,15 +1,29 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import TrainingArguments
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
|
||||
if do_train:
|
||||
if dev_ratio > 1e-6: # Split the dataset
|
||||
dataset = dataset.train_test_split(test_size=dev_ratio)
|
||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments"
|
||||
) -> Dict[str, "Dataset"]:
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.streaming:
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=data_args.val_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
return {"eval_dataset": dataset}
|
||||
|
||||
Reference in New Issue
Block a user