[misc] update format (#7277)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
|
||||
# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the OpenAccess AI Collective's axolotl library.
|
||||
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
|
||||
@@ -15,7 +15,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
@@ -92,7 +91,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if self.template is None:
|
||||
raise ValueError("Template is required for MultiModalDataCollator.")
|
||||
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
for feature in features:
|
||||
@@ -205,7 +204,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
@@ -221,7 +220,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""Data collator for pairwise data."""
|
||||
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
r"""Pad batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
@@ -247,7 +246,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""Data collator for KTO data."""
|
||||
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
@@ -37,7 +36,7 @@ class DatasetConverter:
|
||||
dataset_attr: "DatasetAttr"
|
||||
data_args: "DataArguments"
|
||||
|
||||
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
|
||||
def _find_medias(self, medias: Union[Any, list[Any]]) -> Optional[list[Any]]:
|
||||
r"""Optionally concatenate media path to media dir when loading from local disk."""
|
||||
if not isinstance(medias, list):
|
||||
medias = [medias] if medias is not None else []
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict, Union
|
||||
|
||||
@@ -30,7 +29,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
|
||||
SLOTS = list[Union[str, set[str], dict[str, str]]]
|
||||
|
||||
|
||||
@unique
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -157,7 +156,7 @@ def _load_single_dataset(
|
||||
|
||||
|
||||
def _get_merged_dataset(
|
||||
dataset_names: Optional[Sequence[str]],
|
||||
dataset_names: Optional[list[str]],
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
import inspect
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
@@ -83,9 +82,7 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(
|
||||
imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
|
||||
) -> list[list[int]]:
|
||||
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
|
||||
r"""Get paligemma token type ids for computing loss.
|
||||
|
||||
It is slightly different with the original token type ids where the prompt part is 0.
|
||||
@@ -120,7 +117,7 @@ def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcess
|
||||
return batch_token_type_ids
|
||||
|
||||
|
||||
def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
|
||||
def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
|
||||
r"""Make nested list of images."""
|
||||
batch_images = []
|
||||
for imglen in imglens:
|
||||
@@ -140,9 +137,9 @@ class MMPluginMixin:
|
||||
def _validate_input(
|
||||
self,
|
||||
processor: Optional["MMProcessor"],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
) -> None:
|
||||
r"""Validate if this model accepts the input modalities."""
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
@@ -202,7 +199,7 @@ class MMPluginMixin:
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
|
||||
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
|
||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]:
|
||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||
results = []
|
||||
for image in images:
|
||||
@@ -223,7 +220,7 @@ class MMPluginMixin:
|
||||
|
||||
return results
|
||||
|
||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
|
||||
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
||||
results = []
|
||||
for video in videos:
|
||||
@@ -241,7 +238,7 @@ class MMPluginMixin:
|
||||
|
||||
return results
|
||||
|
||||
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
|
||||
def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
|
||||
r"""Regularizes audios to avoid error. Including reading and resampling."""
|
||||
results = []
|
||||
for audio in audios:
|
||||
@@ -257,9 +254,9 @@ class MMPluginMixin:
|
||||
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
imglens: Optional[list[int]] = None,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
@@ -335,10 +332,10 @@ class MMPluginMixin:
|
||||
class BasePlugin(MMPluginMixin):
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
r"""Pre-process input messages before tokenization for VLMs."""
|
||||
@@ -349,9 +346,9 @@ class BasePlugin(MMPluginMixin):
|
||||
self,
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
@@ -361,13 +358,13 @@ class BasePlugin(MMPluginMixin):
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
r"""Build batched multimodal inputs for VLMs.
|
||||
@@ -392,10 +389,10 @@ class Gemma3Plugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -420,13 +417,13 @@ class Gemma3Plugin(BasePlugin):
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -441,10 +438,10 @@ class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -481,10 +478,10 @@ class LlavaNextPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -523,10 +520,10 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -586,10 +583,10 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -686,9 +683,9 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
**kwargs,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
@@ -757,13 +754,13 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -828,10 +825,10 @@ class MllamaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -850,13 +847,13 @@ class MllamaPlugin(BasePlugin):
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -885,10 +882,10 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -912,9 +909,9 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
self,
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
@@ -931,13 +928,13 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -952,10 +949,10 @@ class PixtralPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -995,13 +992,13 @@ class PixtralPlugin(BasePlugin):
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -1015,10 +1012,10 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -1056,13 +1053,13 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -1090,7 +1087,7 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
@override
|
||||
def _regularize_videos(
|
||||
self, videos: Sequence["VideoInput"], **kwargs
|
||||
self, videos: list["VideoInput"], **kwargs
|
||||
) -> tuple[list[list["ImageObject"]], list[float]]:
|
||||
results, fps_per_video = [], []
|
||||
for video in videos:
|
||||
@@ -1118,9 +1115,9 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
@@ -1149,10 +1146,10 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -1204,13 +1201,13 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[list[int]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
@@ -1229,10 +1226,10 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
@@ -91,7 +90,7 @@ class DatasetAttr:
|
||||
self.set_attr(tag, attr["tags"])
|
||||
|
||||
|
||||
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
|
||||
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
|
||||
r"""Get the attributes of the datasets."""
|
||||
if dataset_names is None:
|
||||
dataset_names = []
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .feedback import FeedbackDatasetProcessor
|
||||
from .pairwise import PairwiseDatasetProcessor
|
||||
from .pretrain import PretrainDatasetProcessor
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
@@ -31,14 +30,14 @@ logger = logging.get_logger(__name__)
|
||||
class FeedbackDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
kl_response: Sequence[dict[str, str]],
|
||||
prompt: list[dict[str, str]],
|
||||
response: list[dict[str, str]],
|
||||
kl_response: list[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
) -> tuple[list[int], list[int], list[int], list[int], bool]:
|
||||
if response[0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
@@ -31,13 +30,13 @@ logger = logging.get_logger(__name__)
|
||||
class PairwiseDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
prompt: list[dict[str, str]],
|
||||
response: list[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
) -> tuple[list[int], list[int], list[int], list[int]]:
|
||||
chosen_messages = self.template.mm_plugin.process_messages(
|
||||
prompt + [response[0]], images, videos, audios, self.processor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import bisect
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
@@ -46,7 +45,7 @@ class DatasetProcessor(ABC):
|
||||
...
|
||||
|
||||
|
||||
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
|
||||
def search_for_fit(numbers: list[int], capacity: int) -> int:
|
||||
r"""Find the index of largest number that fits into the knapsack with the given capacity."""
|
||||
index = bisect.bisect(numbers, capacity)
|
||||
return -1 if index == 0 else (index - 1)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
@@ -33,13 +32,13 @@ logger = logging.get_logger(__name__)
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
prompt: list[dict[str, str]],
|
||||
response: list[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
) -> tuple[list[int], list[int]]:
|
||||
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
|
||||
input_ids, labels = self.template.mm_plugin.process_token_ids(
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
@@ -31,13 +30,13 @@ logger = logging.get_logger(__name__)
|
||||
class UnsupervisedDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
prompt: list[dict[str, str]],
|
||||
response: list[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
) -> tuple[list[int], list[int]]:
|
||||
if len(response) == 1:
|
||||
messages = prompt + response
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
@@ -57,7 +56,7 @@ class Template:
|
||||
def encode_oneturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[dict[str, str]],
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
@@ -73,7 +72,7 @@ class Template:
|
||||
def encode_multiturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[dict[str, str]],
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
@@ -115,7 +114,7 @@ class Template:
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[dict[str, str]],
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
) -> list[list[int]]:
|
||||
@@ -316,7 +315,7 @@ class Llama2Template(Template):
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[dict[str, str]],
|
||||
messages: list[dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
) -> list[list[int]]:
|
||||
@@ -391,7 +390,7 @@ def register_template(
|
||||
format_tools: Optional["Formatter"] = None,
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: Optional[Sequence[str]] = None,
|
||||
stop_words: Optional[list[str]] = None,
|
||||
thought_words: Optional[tuple[str, str]] = None,
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
|
||||
Reference in New Issue
Block a user