[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -15,8 +15,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -38,9 +39,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||
r"""
|
||||
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
||||
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
|
||||
r"""Expand 2d attention mask to 4d attention mask.
|
||||
|
||||
Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
||||
handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
|
||||
|
||||
e.g.
|
||||
```python
|
||||
@@ -78,8 +80,7 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
|
||||
|
||||
@dataclass
|
||||
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator that supports VLMs.
|
||||
r"""Data collator that supports VLMs.
|
||||
|
||||
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
|
||||
"""
|
||||
@@ -91,7 +92,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if self.template is None:
|
||||
raise ValueError("Template is required for MultiModalDataCollator.")
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
for feature in features:
|
||||
@@ -166,7 +167,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
for i, feature in enumerate(features):
|
||||
feature["token_type_ids"] = token_type_ids[i]
|
||||
|
||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
|
||||
rope_index_kwargs = {
|
||||
@@ -198,15 +199,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
@dataclass
|
||||
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
"""
|
||||
r"""Data collator for 4d attention mask."""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
@@ -220,13 +219,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
r"""Data collator for pairwise data."""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
r"""Pad batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
@@ -249,11 +245,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
r"""Data collator for KTO data."""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
|
||||
Reference in New Issue
Block a user