refactor dataset_attr, add eos in pt, fix #757

Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
This commit is contained in:
hiyouga
2023-09-01 19:00:45 +08:00
parent 93be211f80
commit e5b72c6a77
19 changed files with 108 additions and 126 deletions

View File

@@ -11,17 +11,15 @@ class DatasetAttr:
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None
stage: Optional[str] = None
ranking: Optional[bool] = False
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
def __post_init__(self):
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass
class DataArguments:
@@ -114,21 +112,14 @@ class DataArguments:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr(
"hf_hub",
dataset_name=dataset_info[name]["hf_hub_url"],
stage=dataset_info[name].get("stage", None))
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr(
"script",
dataset_name=dataset_info[name]["script_url"],
stage=dataset_info[name].get("stage", None))
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None),
stage=dataset_info[name].get("stage", None)
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
if "columns" in dataset_info[name]:
@@ -137,5 +128,6 @@ class DataArguments:
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)