1. add custom eval dataset support

2. merge load dataset and split dataset function


Former-commit-id: 963d97ba07e7efa3a4544c4d077283d9e112b3ad
This commit is contained in:
codingma
2024-07-05 15:52:10 +08:00
parent 9a1a5f9778
commit 5f2bd04799
15 changed files with 93 additions and 42 deletions

View File

@@ -40,6 +40,7 @@ class DatasetAttr:
subset: Optional[str] = None
folder: Optional[str] = None
num_samples: Optional[int] = None
split: Optional[str] = "train"
# common columns
system: Optional[str] = None
tools: Optional[str] = None
@@ -71,9 +72,9 @@ class DatasetAttr:
setattr(self, key, obj.get(key, default))
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
if data_args.dataset is not None:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List["DatasetAttr"]:
if dataset is not None:
dataset_names = [ds.strip() for ds in dataset.split(",")]
else:
dataset_names = []
@@ -122,6 +123,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("num_samples", dataset_info[name])
if "split" in dataset_info[name]:
dataset_attr.set_attr("split", dataset_info[name])
if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]