follow #5115
Former-commit-id: 7d917e03e2df570139bae18227d9c7303a12de2a
This commit is contained in:
@@ -53,8 +53,11 @@ def _encode_supervised_example(
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history)
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
||||
total_length = 1 if template.efficient_eos else 0
|
||||
if mask_history:
|
||||
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
@@ -66,20 +69,23 @@ def _encode_supervised_example(
|
||||
|
||||
if train_on_prompt:
|
||||
source_label = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
elif template.efficient_eos:
|
||||
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
|
||||
else:
|
||||
source_label = [IGNORE_INDEX] * source_len
|
||||
|
||||
if mask_history:
|
||||
target_label = target_ids if turn_idx==0 else [IGNORE_INDEX] * target_len
|
||||
if mask_history and turn_idx != 0: # train on the last turn only
|
||||
target_label = [IGNORE_INDEX] * target_len
|
||||
else:
|
||||
target_label = target_ids
|
||||
|
||||
if mask_history: # reversed sequences
|
||||
input_ids = source_ids + target_ids + input_ids
|
||||
labels = source_label + target_label + labels
|
||||
else:
|
||||
target_label = target_ids
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_label + target_label
|
||||
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
Reference in New Issue
Block a user