improve aligner

Former-commit-id: cc7296b92e10c24967fc753393275b71d300683f
This commit is contained in:
hiyouga
2024-02-10 16:39:19 +08:00
parent a41fa6e730
commit 1955a8ea5a
10 changed files with 80 additions and 64 deletions

View File

@@ -49,40 +49,32 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
dataset_attr.function_tag: Role.FUNCTION,
dataset_attr.system_tag: Role.SYSTEM,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if len(messages) <= 1:
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
prompt = []
response = []
n_sys = 0
aligned_messages = []
for turn_idx, message in enumerate(messages):
if dataset_attr.system_tag and message[dataset_attr.role_tag] == dataset_attr.system_tag:
outputs["system"].append(message[dataset_attr.content_tag])
n_sys = 1
continue
if (turn_idx - n_sys) % 2 == 0:
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
else:
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
if message[dataset_attr.role_tag] not in accept_tags:
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
prompt.append(
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if len(prompt) % 2 == 1:
# Last message was neither from assistant nor function
prompt.pop(-1)
last_message = prompt.pop(-1)
response.append(last_message)
outputs["prompt"].append(prompt)
outputs["response"].append(response)
if n_sys == 0:
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
return outputs
@@ -93,8 +85,8 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}]
response: [{"role": "assistant", "content": "..."}]
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
"""