Former-commit-id: 3acd151a0f8efdd230c0b0980550795d204a69f7
This commit is contained in:
fzc8578
2025-01-10 21:25:32 +08:00
parent 40382f1387
commit 0aa7ac210f
2 changed files with 10 additions and 22 deletions

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDi
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
@@ -350,26 +349,6 @@ class CpmOPlugin(BasePlugin):
return mm_inputs
def trim_and_pad(self, seq, padding_value=0):
return pad_sequence(seq, batch_first=True, padding_value=padding_value)
def pad_data(self, features):
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
features["input_ids"] = self.trim_and_pad(
features["input_ids"],
)
features["position_ids"] = self.trim_and_pad(
features["position_ids"],
)
features["labels"] = self.trim_and_pad(
features["labels"],
padding_value=-100,
)
features["attention_mask"] = self.trim_and_pad(
features["attention_mask"],
)
return features
@override
def get_mm_inputs(
self,