[misc] update format (#7277)

This commit is contained in:
hoshi-hiyouga
2025-03-13 02:53:08 +08:00
committed by GitHub
parent 4b9d8da5a4
commit 650a9a9057
62 changed files with 384 additions and 288 deletions

View File

@@ -1,4 +1,4 @@
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -92,7 +91,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features:
@@ -205,7 +204,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -221,7 +220,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""Data collator for pairwise data."""
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
r"""Pad batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
@@ -247,7 +246,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""Data collator for KTO data."""
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = []
kl_features = []
kto_tags = []