Former-commit-id: 4715e5c5b8040b21e5f401f7e969b9fd2757d520
This commit is contained in:
hiyouga
2024-07-18 22:06:12 +08:00
parent 86e009b504
commit 4c1513a845
7 changed files with 56 additions and 36 deletions

View File

@@ -63,18 +63,19 @@ def _encode_supervised_example(
total_length += source_len + target_len
if data_args.train_on_prompt:
source_mask = source_ids
source_label = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_mask = [IGNORE_INDEX] * source_len
source_label = [IGNORE_INDEX] * source_len
if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
input_ids += source_ids + target_ids
if data_args.train_last_turn_only and turn_idx != len(encoded_pairs) - 1:
labels += source_mask + [IGNORE_INDEX] * len(target_ids)
else:
labels += source_mask + target_ids
labels += source_label + target_label
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]