1. add custom eval dataset support
2. merge load dataset and split dataset function

Former-commit-id: 963d97ba07e7efa3a4544c4d077283d9e112b3ad
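In short: before this change an eval split could only be carved out of the training data via `val_size`; now a dedicated eval dataset (or a pre-tokenized eval cache) can be named explicitly. A minimal sketch of the selection logic, using made-up config values in place of the real `DataArguments` object:

# Illustration only -- plain dicts stand in for DataArguments; the dataset names are invented.
train_only = {"dataset": "alpaca_en", "val_size": 0.1}            # old behaviour: split the train data
with_eval = {"dataset": "alpaca_en", "eval_dataset": "my_eval"}   # new: dedicated eval dataset


def wants_dedicated_eval(cfg: dict) -> bool:
    """Mirrors the check added to get_dataset(): eval_dataset or eval_tokenized_path is set."""
    return bool(cfg.get("eval_dataset") or cfg.get("eval_tokenized_path"))


print(wants_dedicated_eval(train_only))  # False -> fall back to split_dataset() / val_size
print(wants_dedicated_eval(with_eval))   # True  -> load the eval dataset separately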
@@ -15,7 +15,7 @@
 import inspect
 import os
 import sys
-from typing import TYPE_CHECKING, Literal, Optional, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union, Dict

 import numpy as np
 from datasets import load_dataset, load_from_disk
@@ -24,10 +24,10 @@ from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
-from .data_utils import merge_dataset
+from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer
+from .template import get_template_and_fix_tokenizer, Template


 if TYPE_CHECKING:
@@ -91,7 +91,7 @@ def load_single_dataset(
             subset_name=data_name,
             data_dir=data_dir,
             data_files=data_files,
-            split=data_args.split,
+            split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
             use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
@@ -111,7 +111,7 @@ def load_single_dataset(
             name=data_name,
             data_dir=data_dir,
             data_files=data_files,
-            split=data_args.split,
+            split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
@@ -140,20 +140,17 @@ def load_single_dataset(
     return align_dataset(dataset, dataset_attr, data_args, training_args)


-def get_dataset(
+def load_and_preprocess(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
+    template: "Template",
     processor: Optional["ProcessorMixin"] = None,
+    is_eval: bool = False
 ) -> Union["Dataset", "IterableDataset"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
-    # Load tokenized dataset
-    if data_args.tokenized_path is not None:
+    if not is_eval and data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning("Loading dataset from disk will ignore other data arguments.")
             dataset = load_from_disk(data_args.tokenized_path)
@@ -165,9 +162,21 @@ def get_dataset(
         if data_args.streaming:
             raise ValueError("Turn off `streaming` when saving dataset to disk.")

+    if is_eval and data_args.eval_tokenized_path is not None:
+        if has_tokenized_data(data_args.eval_tokenized_path):
+            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            dataset = load_from_disk(data_args.eval_tokenized_path)
+            logger.info("Loaded tokenized dataset from {}.".format(data_args.eval_tokenized_path))
+            if data_args.streaming:
+                dataset = dataset.to_iterable_dataset()
+            return dataset
+
+        if data_args.streaming:
+            raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
-        for dataset_attr in get_dataset_list(data_args):
+        for dataset_attr in get_dataset_list(data_args, data_args.eval_dataset if is_eval else data_args.dataset):
             if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
                 raise ValueError("The dataset is not applicable in the current training stage.")

@@ -190,13 +199,20 @@ def get_dataset(

         dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

-        if data_args.tokenized_path is not None:
+        if not is_eval and data_args.tokenized_path is not None:
             if training_args.should_save:
                 dataset.save_to_disk(data_args.tokenized_path)
                 logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
                 logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))

             sys.exit(0)
+        if is_eval and data_args.eval_tokenized_path is not None:
+            if training_args.should_save:
+                dataset.save_to_disk(data_args.eval_tokenized_path)
+                logger.info("Tokenized dataset saved at {}.".format(data_args.eval_tokenized_path))
+                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.eval_tokenized_path))
+
+            sys.exit(0)

         if training_args.should_log:
             try:
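A sketch of the caching rule that the two hunks above introduce (my paraphrase, not code from the repo; the paths are made up): each split only consults the tokenized cache that belongs to it -- `tokenized_path` for training data, `eval_tokenized_path` for eval data.

from typing import Optional


def cache_path_for(is_eval: bool, tokenized_path: Optional[str],
                   eval_tokenized_path: Optional[str]) -> Optional[str]:
    # Mirrors the gating in load_and_preprocess(): train and eval caches never mix.
    if not is_eval and tokenized_path is not None:
        return tokenized_path
    if is_eval and eval_tokenized_path is not None:
        return eval_tokenized_path
    return None


print(cache_path_for(False, "saves/tokenized_train", "saves/tokenized_eval"))  # saves/tokenized_train
print(cache_path_for(True, "saves/tokenized_train", None))                     # None -> tokenize eval data from scratch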
@@ -208,3 +224,24 @@ def get_dataset(
                 raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

     return dataset
+
+
+def get_dataset(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"] = None
+) -> Dict[str, "Dataset"]:
+    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    train_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor)
+
+    if data_args.eval_dataset or data_args.eval_tokenized_path:
+        eval_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor, True)
+        return {"train_dataset": train_dataset, "eval_dataset": eval_dataset}
+    else:
+        return split_dataset(train_dataset, data_args, training_args)
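The wrapper above changes the return type from a single dataset to a dict. A toy sketch of that contract (mine; plain lists stand in for the real `Dataset` objects):

from typing import Dict, List, Optional


def fake_split(train: List[int], val_size: float) -> Dict[str, List[int]]:
    """Stands in for split_dataset(): carve an eval split out of the training data."""
    n_eval = int(len(train) * val_size)
    return {"train_dataset": train[n_eval:], "eval_dataset": train[:n_eval]}


def fake_get_dataset(train: List[int], eval_set: Optional[List[int]] = None,
                     val_size: float = 0.1) -> Dict[str, List[int]]:
    if eval_set is not None:                # mirrors `if data_args.eval_dataset or data_args.eval_tokenized_path`
        return {"train_dataset": train, "eval_dataset": eval_set}
    return fake_split(train, val_size)      # mirrors `split_dataset(train_dataset, data_args, training_args)`


print(fake_get_dataset(list(range(10))))              # auto-split from the training data
print(fake_get_dataset(list(range(10)), [100, 101]))  # dedicated eval set passes through untouched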
@@ -40,6 +40,7 @@ class DatasetAttr:
     subset: Optional[str] = None
     folder: Optional[str] = None
     num_samples: Optional[int] = None
+    split: Optional[str] = "train"
     # common columns
     system: Optional[str] = None
     tools: Optional[str] = None
@@ -71,9 +72,9 @@ class DatasetAttr:
         setattr(self, key, obj.get(key, default))


-def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
-    if data_args.dataset is not None:
-        dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
+def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List["DatasetAttr"]:
+    if dataset is not None:
+        dataset_names = [ds.strip() for ds in dataset.split(",")]
     else:
         dataset_names = []

@@ -122,6 +123,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
         dataset_attr.set_attr("subset", dataset_info[name])
         dataset_attr.set_attr("folder", dataset_info[name])
         dataset_attr.set_attr("num_samples", dataset_info[name])
+        if "split" in dataset_info[name]:
+            dataset_attr.set_attr("split", dataset_info[name])

         if "columns" in dataset_info[name]:
             column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
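With the new `split` attribute, an entry in `data/dataset_info.json` can point at a non-default split. A hedged example (the dataset key and repo id below are invented; only the `split` key is what this hunk adds support for):

import json

# Hypothetical dataset_info.json entry; "hf_hub_url" is an existing key of this format,
# "split" is the new one read via dataset_attr.set_attr("split", dataset_info[name]).
dataset_info = {
    "my_eval_set": {
        "hf_hub_url": "someone/some-eval-dataset",
        "split": "validation",
    }
}
print(json.dumps(dataset_info, indent=2))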
@@ -33,6 +33,11 @@ class DataArguments:
         default=None,
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
     )
+    eval_dataset: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of provided dataset(s) to use for eval during training. "
+                          "Use commas to separate multiple datasets."},
+    )
     dataset_dir: str = field(
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
@@ -105,6 +110,10 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
     )
+    eval_tokenized_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save or load the tokenized eval datasets."},
+    )

     def __post_init__(self):
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
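Since `DataArguments` is parsed with `HfArgumentParser`, the two new fields should surface directly as command-line options. A rough, stripped-down sketch (the mini dataclass and the argument values below are mine, not the real class):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class MiniDataArguments:
    dataset: Optional[str] = field(default=None)
    eval_dataset: Optional[str] = field(default=None)
    eval_tokenized_path: Optional[str] = field(default=None)


parser = HfArgumentParser(MiniDataArguments)
(args,) = parser.parse_args_into_dataclasses(args=["--dataset", "alpaca_en", "--eval_dataset", "my_eval_set"])
print(args.eval_dataset)  # my_eval_set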
@@ -41,7 +41,7 @@ def run_dpo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     data_collator = PairwiseDataCollatorWithPadding(
@@ -71,7 +71,7 @@ def run_dpo(
         data_collator=data_collator,
         callbacks=callbacks,
         **tokenizer_module,
-        **split_dataset(dataset, data_args, training_args),
+        **dataset_module,
     )

     # Training
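Why `**dataset_module` can replace `**split_dataset(...)` here (my illustration): the dict returned by `get_dataset` uses the keys `train_dataset` and `eval_dataset`, which match the keyword parameters the Trainer constructors accept, so the workflows splat it straight into the call.

def toy_trainer(model: str, train_dataset=None, eval_dataset=None, **kwargs):
    # Stand-in for a Trainer constructor that takes train_dataset/eval_dataset keywords.
    return {"model": model, "train": train_dataset, "eval": eval_dataset}


dataset_module = {"train_dataset": ["sample 1", "sample 2"], "eval_dataset": ["sample 3"]}
print(toy_trainer("dummy-model", **dataset_module))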
@@ -41,7 +41,7 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     data_collator = KTODataCollatorWithPadding(
@@ -68,7 +68,7 @@ def run_kto(
         data_collator=data_collator,
         callbacks=callbacks,
         **tokenizer_module,
-        **split_dataset(dataset, data_args, training_args),
+        **dataset_module,
     )

     # Training
@@ -43,7 +43,7 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
+    dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
@@ -63,7 +63,7 @@ def run_ppo(
         model=model,
         reward_model=reward_model,
         ref_model=ref_model,
-        dataset=dataset,
+        dataset=dataset_module["train_dataset"],
         data_collator=data_collator,
         **tokenizer_module,
     )
@@ -42,7 +42,7 @@ def run_pt(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

@@ -54,7 +54,7 @@ def run_pt(
         data_collator=data_collator,
         callbacks=callbacks,
         **tokenizer_module,
-        **split_dataset(dataset, data_args, training_args),
+        **dataset_module,
     )

     # Training
@@ -41,7 +41,7 @@ def run_rm(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

@@ -57,7 +57,7 @@ def run_rm(
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
         **tokenizer_module,
-        **split_dataset(dataset, data_args, training_args),
+        **dataset_module,
     )

     # Training
@@ -81,7 +81,7 @@ def run_rm(

     # Predict
     if training_args.do_predict:
-        predict_results = trainer.predict(dataset, metric_key_prefix="predict")
+        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict")
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(predict_results)
@@ -43,7 +43,7 @@ def run_sft(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     if training_args.predict_with_generate:
@@ -76,7 +76,7 @@ def run_sft(
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
         preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
         **tokenizer_module,
-        **split_dataset(dataset, data_args, training_args),
+        **dataset_module,
     )

     # Keyword arguments for `model.generate`
@@ -105,12 +105,12 @@ def run_sft(

     # Predict
     if training_args.do_predict:
-        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
+        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
         if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
             predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset, predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)

     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)