fix #2282 and update tool prompt

Former-commit-id: 1c412f803866bde32b76f7c26c7b464b6b3651f3
This commit is contained in:
hiyouga
2024-01-22 22:27:30 +08:00
parent 1fe1ca1c8b
commit 75be329994
5 changed files with 25 additions and 20 deletions

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from .utils import Role
if TYPE_CHECKING:
@@ -51,7 +52,7 @@ def preprocess_supervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
@@ -93,7 +94,7 @@ def preprocess_packed_supervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
@@ -137,10 +138,14 @@ def preprocess_unsupervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
if len(examples["prompt"][i]) % 2 != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
@@ -164,7 +169,7 @@ def preprocess_pairwise_dataset(
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]