Support multi-turn training like FastChat

Former-commit-id: 629cafb1a09924e82d7ea1f9fba318d3f5593196
This commit is contained in:
hiyouga
2023-06-14 22:27:39 +08:00
parent 6f655e3916
commit aa1bb8a9a2
5 changed files with 166 additions and 108 deletions

View File

@@ -421,18 +421,17 @@ def preprocess_data(
prompt_template = Template(data_args.prompt_template)
# support question with a single answer or multiple answers
def format_example(examples):
def get_dialog(examples):
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += "\n" + examples["query"][i]
prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
yield prompt, answer
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
yield dialog
def preprocess_pretrain_dataset(examples):
# build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
text_ids = tokenizer(examples["prompt"])["input_ids"]
# build grouped texts with format `X1 X2 X3 ...` (without [BOS] and [EOS])
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
concatenated_ids = list(chain(*text_ids))
total_length = len(concatenated_ids)
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
@@ -446,28 +445,29 @@ def preprocess_data(
}
def preprocess_supervised_dataset(examples):
# build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
# build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
# for input with history, we build multiple input-label pairs just like:
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
for dialog in get_dialog(examples):
input_ids, labels = [], []
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
for i in range(len(dialog) // 2):
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
return model_inputs
def preprocess_unsupervised_dataset(examples):
# build inputs with format `X <s>` and labels with format `Y <s>`
# build inputs with format `X [BOS]` and labels with format `Y [BOS]`
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
@@ -484,9 +484,11 @@ def preprocess_data(
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `X <s> Y1 </s>` and `X <s> Y2 </s>`
# build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
model_inputs = {"accept_ids": [], "reject_ids": []}
for prompt, answer in format_example(examples):
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)