fix packing for eager/sdpa attn
Former-commit-id: 735a033ceb7f2da6da71d138ea091d8a665411a9
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
from typing import Any, Dict, Literal, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
@@ -62,13 +62,31 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
|
||||
return attention_mask_4d
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
"""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
@@ -100,7 +118,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
|
||||
Reference in New Issue
Block a user