Support multi-turn training like FastChat
Former-commit-id: 629cafb1a09924e82d7ea1f9fba318d3f5593196
This commit is contained in:
@@ -421,18 +421,17 @@ def preprocess_data(
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
|
||||
# support question with a single answer or multiple answers
|
||||
def format_example(examples):
|
||||
def get_dialog(examples):
|
||||
for i in range(len(examples["prompt"])):
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
if examples["query"][i]:
|
||||
query += "\n" + examples["query"][i]
|
||||
prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
|
||||
yield prompt, answer
|
||||
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
|
||||
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
|
||||
yield dialog
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
# build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
|
||||
text_ids = tokenizer(examples["prompt"])["input_ids"]
|
||||
# build grouped texts with format `X1 X2 X3 ...` (without [BOS] and [EOS])
|
||||
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
|
||||
concatenated_ids = list(chain(*text_ids))
|
||||
total_length = len(concatenated_ids)
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
@@ -446,28 +445,29 @@ def preprocess_data(
|
||||
}
|
||||
|
||||
def preprocess_supervised_dataset(examples):
|
||||
# build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
|
||||
# build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
for prompt, answer in format_example(examples):
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
||||
for dialog in get_dialog(examples):
|
||||
input_ids, labels = [], []
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
for i in range(len(dialog) // 2):
|
||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||
input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
|
||||
labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
|
||||
labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
|
||||
model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples):
|
||||
# build inputs with format `X <s>` and labels with format `Y <s>`
|
||||
# build inputs with format `X [BOS]` and labels with format `Y [BOS]`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
for prompt, answer in format_example(examples):
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
||||
|
||||
@@ -484,9 +484,11 @@ def preprocess_data(
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `X <s> Y1 </s>` and `X <s> Y2 </s>`
|
||||
# build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for prompt, answer in format_example(examples):
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
||||
|
||||
Reference in New Issue
Block a user