fix system prompt
Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
@@ -92,14 +92,13 @@ def get_dataset(
             if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                 dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
 
-        if dataset_attr.source_prefix: # add prefix
+        if dataset_attr.system_prompt: # add system prompt
             if data_args.streaming:
                 features = dataset.features
-                features["prefix"] = Value(dtype="string", id=None)
-                dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
+                features["system"] = Value(dtype="string", id=None)
+                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
             else:
-                prefix_data = [dataset_attr.source_prefix] * len(dataset)
-                dataset = dataset.add_column("prefix", prefix_data)
+                dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
 
         all_datasets.append(dataset)
 
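Not part of the commit, only an illustration of the two branches above: with Hugging Face `datasets`, a constant "system" column is attached with `add_column` on a regular `Dataset` (which has a known length) and with `map` plus an explicit `features` schema on a streamed `IterableDataset`. A minimal sketch with toy data and a hypothetical prompt string:

    from datasets import Dataset, Features, Value

    system_prompt = "You are a helpful assistant."  # stand-in for dataset_attr.system_prompt

    # Non-streaming: a regular Dataset knows its length, so add_column works directly.
    dataset = Dataset.from_dict({"prompt": ["hi", "bye"], "response": ["hello", "goodbye"]})
    dataset = dataset.add_column("system", [system_prompt] * len(dataset))
    print(dataset[0]["system"])

    # Streaming: an IterableDataset has no length, so the constant column is added
    # lazily via map, with the new feature declared explicitly.
    streamed = dataset.remove_columns("system").to_iterable_dataset()
    features = Features({
        "prompt": Value("string"),
        "response": Value("string"),
        "system": Value("string"),
    })
    streamed = streamed.map(lambda _: {"system": system_prompt}, features=features)
    print(next(iter(streamed))["system"])

Declaring `features` in the streaming case keeps the schema complete even though the new column is only produced lazily, per example.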
@@ -27,8 +27,8 @@ def preprocess_dataset(
         query, response = examples["prompt"][i], examples["response"][i]
         query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
         history = examples["history"][i] if "history" in examples else None
-        prefix = examples["prefix"][i] if "prefix" in examples else None
-        yield query, response, history, prefix
+        system = examples["system"][i] if "system" in examples else None
+        yield query, response, history, system
 
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
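For orientation (not part of the diff): `construct_example` walks the batched columns that `datasets.map(batched=True)` hands to the preprocessing functions and now yields `system` as the fourth tuple element instead of `prefix`. A rough standalone rendering of that generator with a made-up batch; the name `construct_example_sketch` is mine, not the repository's:

    from typing import Any, Dict, Generator, List, Tuple

    def construct_example_sketch(
        examples: Dict[str, List[Any]]
    ) -> Generator[Tuple[str, Any, Any, Any], None, None]:
        # Mirrors the yield in the hunk above: one (query, response, history, system)
        # tuple per row of the batched examples dict; optional columns fall back to None.
        for i in range(len(examples["prompt"])):
            query, response = examples["prompt"][i], examples["response"][i]
            if "query" in examples and examples["query"][i]:
                query = query + "\n" + examples["query"][i]
            history = examples["history"][i] if "history" in examples else None
            system = examples["system"][i] if "system" in examples else None
            yield query, response, history, system

    batch = {
        "prompt": ["What is 2 + 2?"],
        "query": [""],
        "response": ["4"],
        "system": ["You are a math tutor."],
    }
    for query, response, history, system in construct_example_sketch(batch):
        print(system, "|", query, "->", response)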
@@ -56,10 +56,10 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         max_length = data_args.max_source_length + data_args.max_target_length
 
-        for query, response, history, prefix in construct_example(examples):
+        for query, response, history, system in construct_example(examples):
             input_ids, labels = [], []
 
-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
+            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
                 if len(target_ids) > data_args.max_target_length:
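An illustration rather than repository code: each turn returned by `encode_multiturn` is truncated independently to `max_source_length` / `max_target_length` before being appended, which the hypothetical helper below mimics with plain integer token ids; masking prompt tokens with -100 in the labels is the usual supervised convention and is not shown in this hunk.

    from typing import List, Tuple

    def clip_turns(
        turns: List[Tuple[List[int], List[int]]],
        max_source_length: int,
        max_target_length: int,
    ) -> Tuple[List[int], List[int]]:
        # Each (source_ids, target_ids) pair is truncated on its own, as in the
        # hunk above, then appended to the running input_ids / labels sequences.
        input_ids: List[int] = []
        labels: List[int] = []
        for source_ids, target_ids in turns:
            source_ids = source_ids[:max_source_length]
            target_ids = target_ids[:max_target_length]
            # Prompt tokens are excluded from the loss via the -100 ignore index.
            input_ids += source_ids + target_ids
            labels += [-100] * len(source_ids) + target_ids
        return input_ids, labels

    ids, lbls = clip_turns([([1, 2, 3, 4], [5, 6, 7])], max_source_length=3, max_target_length=2)
    print(ids)   # [1, 2, 3, 5, 6]
    print(lbls)  # [-100, -100, -100, 5, 6]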
@@ -78,11 +78,11 @@ def preprocess_dataset(
         return model_inputs
 
     def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X` and labels with format `<bos> Y`
+        # build inputs with format `<bos> X` and labels with format `Y <eos>`
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 
-        for query, response, history, prefix in construct_example(examples):
-            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
+        for query, response, history, system in construct_example(examples):
+            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
 
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
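Aside (not from the commit): the corrected comment says unsupervised inputs carry only the prompt (`<bos> X`) while labels carry only the reference answer (`Y <eos>`), i.e. the two sequences are kept separate instead of concatenated. A toy sketch with made-up special token ids:

    BOS, EOS = 1, 2  # hypothetical special token ids

    def encode_unsupervised(prompt_ids, response_ids):
        # Inputs hold the prompt only (`<bos> X`); labels hold the reference
        # answer only (`Y <eos>`), matching the corrected comment above.
        input_ids = [BOS] + prompt_ids
        labels = response_ids + [EOS]
        return input_ids, labels

    print(encode_unsupervised([10, 11], [20, 21]))
    # ([1, 10, 11], [20, 21, 2])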
@@ -98,9 +98,9 @@ def preprocess_dataset(
     def preprocess_pairwise_dataset(examples):
         # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
         model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
-        for query, response, history, prefix in construct_example(examples):
-            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
-            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
+        for query, response, history, system in construct_example(examples):
+            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
+            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
 
             if len(prompt_ids) > data_args.max_source_length:
                 prompt_ids = prompt_ids[:data_args.max_source_length]
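Finally, as context rather than code from the repository: in the pairwise hunk `response` holds two candidates, index 0 the preferred answer and index 1 the rejected one, and each example contributes one prompt sequence plus the two completion sequences. The resulting structure looks roughly like this (`build_pairwise_inputs` is a hypothetical helper):

    from typing import Dict, List

    def build_pairwise_inputs(
        encoded_pairs: List[Dict[str, List[int]]]
    ) -> Dict[str, List[List[int]]]:
        # Same shape as the model_inputs dict in the hunk above: per example, one
        # prompt id sequence and two completion id sequences (chosen vs. rejected),
        # ready for a reward-model data collator to pad and batch.
        model_inputs: Dict[str, List[List[int]]] = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
        for pair in encoded_pairs:
            model_inputs["prompt_ids"].append(pair["prompt"])
            model_inputs["chosen_ids"].append(pair["chosen"])
            model_inputs["rejected_ids"].append(pair["rejected"])
        return model_inputs

    example = [{"prompt": [1, 2, 3], "chosen": [4, 5], "rejected": [6]}]
    print(build_pairwise_inputs(example))
    # {'prompt_ids': [[1, 2, 3]], 'chosen_ids': [[4, 5]], 'rejected_ids': [[6]]}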