Merge pull request #4224 from chuan298/main
Implement efficient packing without cross-contamination attention Former-commit-id: ac382cc9fe4ec483658fd54f07f9a123788ce1b1
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the OpenAccess AI Collective's axolotl library.
|
||||
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,6 +22,46 @@ import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||
r"""
|
||||
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
||||
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
|
||||
|
||||
e.g.
|
||||
```
|
||||
[[1, 1, 2, 2, 2, 0]]
|
||||
```
|
||||
->
|
||||
```
|
||||
[
|
||||
[
|
||||
[
|
||||
[o, x, x, x, x, x],
|
||||
[o, o, x, x, x, x],
|
||||
[x, x, o, x, x, x],
|
||||
[x, x, o, o, x, x],
|
||||
[x, x, o, o, o, x],
|
||||
[x, x, o, x, x, x],
|
||||
]
|
||||
]
|
||||
]
|
||||
```
|
||||
where `o` equals to `0.0`, `x` equals to `min_dtype`.
|
||||
"""
|
||||
bsz, seq_len = attention_mask_with_indices.size()
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
padding_mask = torch.where(expanded_mask != 0, 1, 0)
|
||||
# Create a block-diagonal mask.
|
||||
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
|
||||
# Use the lower triangular mask to zero out the upper triangular part
|
||||
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
|
||||
# Invert the attention mask.
|
||||
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
|
||||
return attention_mask_4d
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
|
||||
@@ -160,22 +160,30 @@ def preprocess_packed_supervised_dataset(
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_labels = [], []
|
||||
for length in knapsack:
|
||||
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
packed_labels += batch_labels[index]
|
||||
if data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
else:
|
||||
packed_attention_masks += [1] * len(batch_input_ids[index])
|
||||
|
||||
if len(packed_input_ids) < data_args.cutoff_len:
|
||||
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
||||
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
||||
packed_labels += [IGNORE_INDEX] * pad_length
|
||||
if data_args.neat_packing:
|
||||
packed_attention_masks += [0] * pad_length
|
||||
else:
|
||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||
|
||||
if len(packed_input_ids) != data_args.cutoff_len:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
|
||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
Reference in New Issue
Block a user