fix packing for eager/sdpa attn

Former-commit-id: 735a033ceb7f2da6da71d138ea091d8a665411a9
Author: hiyouga
Date:   2024-07-04 01:52:43 +08:00
parent a90c6306f8
commit 3d219b91b9
9 changed files with 51 additions and 20 deletions


@@ -16,7 +16,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Literal, Sequence
 
 import torch
 from transformers import DataCollatorForSeq2Seq
@@ -62,13 +62,31 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     return attention_mask_4d
 
 
+@dataclass
+class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for 4d attention mask.
+    """
+
+    block_diag_attn: bool = False
+    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
+    compute_dtype: "torch.dtype" = torch.float32
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        features = super().__call__(features)
+        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
+            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+
+        return features
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         r"""
         Pads batched data to the longest sequence in the batch.
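
A minimal usage sketch of the new collator (illustration only, not part of this diff): the tokenizer checkpoint and the packed example below are assumptions chosen for demonstration. Setting block_diag_attn=True together with an attn_implementation other than "flash_attention_2" is what triggers the 4D expansion added above.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # any tokenizer works; checkpoint is illustrative

collator = SFTDataCollatorWith4DAttentionMask(
    tokenizer=tokenizer,
    block_diag_attn=True,          # neat packing enabled
    attn_implementation="sdpa",    # "eager" or "sdpa" -> expand the index mask to 4D
    compute_dtype=torch.bfloat16,  # masked positions are filled with this dtype's minimum
)

# One packed example: tokens 0-2 belong to packed sequence 1, tokens 3-4 to sequence 2.
features = [{"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 2, 2], "labels": [1, 2, 3, 4, 5]}]
batch = collator(features)
# batch["attention_mask"] now has shape (1, 1, 5, 5) instead of (1, 5).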
@@ -100,7 +118,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
     Data collator for KTO data.
     """
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         target_features = []
         kl_features = []
         kto_tags = []
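
For context on what prepare_4d_attention_mask produces on the eager and sdpa paths, here is a self-contained sketch of the block-diagonal expansion (an illustration of the technique under the stated assumptions, not a verbatim copy of the function): each query token may attend only to earlier tokens that carry the same packing index, and all other positions receive the minimum value of the compute dtype.

import torch

def block_diag_4d_mask(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Expand a (bsz, seq_len) packing-index mask into a (bsz, 1, seq_len, seq_len) causal block-diagonal mask."""
    bsz, seq_len = mask_with_indices.size()
    min_value = torch.finfo(dtype).min
    # keys[b, 0, i, j] is the packing index of key token j; queries[b, 0, i, j] is that of query token i
    keys = mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    queries = keys.transpose(-1, -2)
    same_segment = (keys == queries) & (keys != 0)  # same packed sequence and not padding
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=keys.device))
    allowed = same_segment & causal                 # block-diagonal AND causal
    zero = torch.zeros((), dtype=dtype, device=keys.device)
    masked = torch.full((), min_value, dtype=dtype, device=keys.device)
    return torch.where(allowed, zero, masked)       # 0 where attention is allowed, dtype-min elsewhere

# Two sequences of length 2 packed into one row, plus one padding position.
print(block_diag_4d_mask(torch.tensor([[1, 1, 2, 2, 0]]), torch.float32)[0, 0])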