fix inference, add prompt template

Former-commit-id: 3940e50c71472b210bbc1b01248bf85a191c4065
This commit is contained in:
hiyouga
2023-06-07 10:52:35 +08:00
parent 12094c1db5
commit 3da427a665
6 changed files with 44 additions and 81 deletions

View File

@@ -37,6 +37,11 @@ from .config import (
FinetuningArguments
)
from .template import (
prompt_template_alpaca,
prompt_template_ziya
)
from .other import (
get_logger,
load_trainable_params,
@@ -224,6 +229,7 @@ def load_pretrained(
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
print_trainable_params(model)
@@ -395,39 +401,19 @@ def preprocess_data(
column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
# support question with a single answer or multiple answers
def format_example_alpaca(examples):
def format_example(examples):
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += "\n" + examples["query"][i]
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\n" + prefix
if examples["history"][i]:
for old_query, response in examples["history"][i]:
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant: ".format(query)
yield prompt, answer
def format_example_ziya(examples):
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += "\n" + examples["query"][i]
prompt = ""
if examples["history"][i]:
for old_query, response in examples["history"][i]:
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
prompt += "<human>: {}\n<bot>:".format(query)
prompt = prompt_template(query, examples["history"][i])
prompt = prefix + prompt
yield prompt, answer
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
def preprocess_pretrain_dataset(examples):
# build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
text_ids = tokenizer(examples["prompt"])["input_ids"]