[breaking change] refactor data pipeline (#6901)
* refactor data * rename file Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b
This commit is contained in:
@@ -45,7 +45,7 @@ class DatasetAttr:
|
||||
images: Optional[str] = None
|
||||
videos: Optional[str] = None
|
||||
audios: Optional[str] = None
|
||||
# rlhf columns
|
||||
# dpo columns
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
@@ -71,6 +71,26 @@ class DatasetAttr:
|
||||
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
def join(self, attr: Dict[str, Any]) -> None:
|
||||
self.set_attr("formatting", attr, default="alpaca")
|
||||
self.set_attr("ranking", attr, default=False)
|
||||
self.set_attr("subset", attr)
|
||||
self.set_attr("split", attr, default="train")
|
||||
self.set_attr("folder", attr)
|
||||
self.set_attr("num_samples", attr)
|
||||
|
||||
if "columns" in attr:
|
||||
column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
|
||||
column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
|
||||
for column_name in column_names:
|
||||
self.set_attr(column_name, attr["columns"])
|
||||
|
||||
if "tags" in attr:
|
||||
tag_names = ["role_tag", "content_tag"]
|
||||
tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
|
||||
for tag in tag_names:
|
||||
self.set_attr(tag, attr["tags"])
|
||||
|
||||
|
||||
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
||||
r"""
|
||||
@@ -128,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
|
||||
else:
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("split", dataset_info[name], default="train")
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "tools", "images", "videos", "audios", "chosen", "rejected", "kto_tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names.extend(["messages"])
|
||||
|
||||
for column_name in column_names:
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
|
||||
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
||||
tag_names = (
|
||||
"role_tag",
|
||||
"content_tag",
|
||||
"user_tag",
|
||||
"assistant_tag",
|
||||
"observation_tag",
|
||||
"function_tag",
|
||||
"system_tag",
|
||||
)
|
||||
for tag in tag_names:
|
||||
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
|
||||
|
||||
dataset_attr.join(dataset_info[name])
|
||||
dataset_list.append(dataset_attr)
|
||||
|
||||
return dataset_list
|
||||
|
||||
Reference in New Issue
Block a user