update data readme (zh)

Former-commit-id: b32fb3a984c681732b82f6544d6c05a98c34cf4c
This commit is contained in:
hiyouga
2023-11-02 23:42:49 +08:00
parent b77c745b1a
commit 4bb643e685
6 changed files with 105 additions and 24 deletions

View File

@@ -69,7 +69,7 @@ def get_dataset(
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# convert dataset from sharegpt format to alpaca format
outputs = {"prompt": [], "query": [], "response": [], "history": []}
for msg_list in examples[dataset_attr.prompt]:
for msg_list in examples[dataset_attr.messages]:
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
if len(msg_list) == 0:
continue
@@ -78,15 +78,15 @@ def get_dataset(
user_role, assistant_role = None, None
for idx in range(0, len(msg_list), 2):
if user_role is None and assistant_role is None:
user_role = msg_list[idx][dataset_attr.query]
assistant_role = msg_list[idx + 1][dataset_attr.query]
user_role = msg_list[idx][dataset_attr.role]
assistant_role = msg_list[idx + 1][dataset_attr.role]
else:
if (
msg_list[idx][dataset_attr.query] != user_role
or msg_list[idx+1][dataset_attr.query] != assistant_role
):
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
msg_pairs.append((msg_list[idx][dataset_attr.response], msg_list[idx + 1][dataset_attr.response]))
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
if len(msg_pairs) != 0:
outputs["prompt"].append(msg_pairs[-1][0])

View File

@@ -19,6 +19,9 @@ class DatasetAttr:
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
role: Optional[str] = "from"
content: Optional[str] = "value"
def __repr__(self) -> str:
return self.dataset_name
@@ -155,6 +158,9 @@ class DataArguments:
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)