fix flashattn + packing

Former-commit-id: 4adc6ce4abc718c25f39b316bfc3352d0d01ed1e
Author: hiyouga
Date: 2024-07-21 17:07:45 +08:00
Commit: a770afbff2 (parent adff3e5050)
4 changed files with 25 additions and 20 deletions

@@ -38,7 +38,9 @@ def _encode_supervised_example(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
+    train_on_prompt: bool,
+    mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
         prompt[0]["content"] = template.image_token + prompt[0]["content"]
@@ -54,22 +56,22 @@ def _encode_supervised_example(
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
     total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= data_args.cutoff_len:
+        if total_length >= cutoff_len:
             break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
         source_ids = source_ids[:source_len]
         target_ids = target_ids[:target_len]
         total_length += source_len + target_len

-        if data_args.train_on_prompt:
+        if train_on_prompt:
             source_label = source_ids
         elif turn_idx != 0 and template.efficient_eos:
             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
             source_label = [IGNORE_INDEX] * source_len

-        if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
+        if mask_history and turn_idx != len(encoded_pairs) - 1:
             target_label = [IGNORE_INDEX] * target_len
         else:
             target_label = target_ids
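
The loop above leans on infer_seqlen to split the remaining token budget between prompt and response; its body is outside this diff. Below is a minimal sketch of such a helper, consistent with how it is called here — the proportional-truncation strategy is an assumption, not necessarily what the repository implements:

    from typing import Tuple

    def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
        # If both sides already fit within the budget, keep them untouched.
        if source_len + target_len <= cutoff_len:
            return source_len, target_len
        # Otherwise split the budget proportionally, keeping at least one
        # target token so the turn still contributes a training label.
        new_target_len = max(1, cutoff_len * target_len // (source_len + target_len))
        new_source_len = cutoff_len - new_target_len
        return new_source_len, new_target_len

Whatever the real policy, the returned pair must sum to at most cutoff_len - total_length, which is what keeps total_length from overshooting the budget across turns.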
@@ -112,7 +114,9 @@ def preprocess_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            data_args=data_args,
+            cutoff_len=data_args.cutoff_len,
+            train_on_prompt=data_args.train_on_prompt,
+            mask_history=data_args.mask_history,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -150,7 +154,9 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=None,
-            data_args=data_args,
+            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            train_on_prompt=data_args.train_on_prompt,
+            mask_history=data_args.mask_history,
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
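
Passing data_args.cutoff_len - 1 here is the core of the packing fix: every example is encoded against a budget one token short of the real limit, so a padding token can later be appended to each pack without overflowing it. A worked example of the arithmetic, with made-up values (cutoff_len = 8 and the sample lengths are illustrative only):

    cutoff_len = 8
    lengths = [3, 4, 5, 2]  # per-example lengths, each <= cutoff_len - 1
    # greedy_knapsack(lengths, cutoff_len - 1) might return [[4, 3], [5, 2]]
    for pack in [[4, 3], [5, 2]]:
        assert sum(pack) <= cutoff_len - 1  # 7 tokens used, one slot left
        # the reserved slot guarantees room for the pad/EOS token that
        # terminates the pack when it is padded up to cutoff_len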
@@ -163,7 +169,7 @@ def preprocess_packed_supervised_dataset(
             valid_num += 1

     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         for i, length in enumerate(knapsack):
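
greedy_knapsack groups example lengths into packs that each fit the reduced budget; its implementation also lives outside this diff. A plausible greedy sketch, reconstructed from the call site alone (the bisect-based largest-fit strategy is an assumption):

    import bisect
    from typing import List

    def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
        numbers = sorted(numbers)  # ascending, so bisect can find the largest fit
        knapsacks: List[List[int]] = []
        while numbers:
            remaining = capacity
            knapsack: List[int] = []
            while numbers:
                # index of the largest remaining length that still fits this pack
                index = bisect.bisect_right(numbers, remaining) - 1
                if index < 0:
                    break
                remaining -= numbers[index]
                knapsack.append(numbers.pop(index))
            if not knapsack:  # oversized sample; upstream code has already dropped these
                knapsack.append(numbers.pop(0))
            knapsacks.append(knapsack)
        return knapsacks

With capacity set to data_args.cutoff_len - 1, each resulting pack leaves exactly the one position reserved above for the padding token, so flash attention sees packs that are always terminated and never exceed cutoff_len.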