enable cutoff len

Former-commit-id: e9513d300c338dfcae98eee7d057bfd00da2da0e
This commit is contained in:
hiyouga
2024-01-18 12:25:42 +08:00
parent d8affd3967
commit e4a424cb6a
6 changed files with 46 additions and 14 deletions

View File

@@ -58,7 +58,7 @@ def preprocess_supervised_dataset(
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)):
if data_args.train_on_prompt:
source_mask = source_ids
@@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset(
messages = examples["prompt"][i] + examples["response"][i]
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tool"][i]
tokenizer, messages, examples["system"][i], examples["tools"][i]
)):
if data_args.train_on_prompt:
source_mask = source_ids
@@ -141,7 +141,7 @@ def preprocess_unsupervised_dataset(
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
if template.efficient_eos:
@@ -170,10 +170,10 @@ def preprocess_pairwise_dataset(
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
if template.efficient_eos: