[misc] upgrade format to py39 (#7256)

hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions
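The bulk of the change is mechanical: with Python 3.9 as the minimum version, PEP 585 builtin generics (`dict`, `list`, `tuple`, `set`, `type`) replace their `typing` counterparts, `Sequence` moves to `collections.abc`, and docstrings adopt a one-line imperative summary. A representative before/after sketch (hypothetical function, mirroring the edits below):

```python
# before: pre-3.9 typing aliases and a multi-line docstring summary
from typing import Dict, List, Optional, Sequence, Tuple

def encode(messages: Sequence[Dict[str, str]]) -> Tuple[List[int], Optional[List[int]]]:
    """
    Encodes messages to token ids.
    """

# after: py39 builtin generics, abc containers, one-line imperative docstring
from collections.abc import Sequence
from typing import Optional

def encode(messages: Sequence[dict[str, str]]) -> tuple[list[int], Optional[list[int]]]:
    """Encode messages to token ids."""
```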

View File

@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"TEMPLATES",
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"SFTDataCollatorWith4DAttentionMask",
"Template",
"get_dataset",
"get_template_and_fix_tokenizer",
"split_dataset",
]

View File

@@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np
import torch
@@ -38,9 +39,10 @@ if TYPE_CHECKING:
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
r"""Expand 2d attention mask to 4d attention mask.
Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
handle packed sequences and transform the mask to lower triangular form to prevent future peeking.
e.g.
```python
@@ -78,8 +80,7 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
r"""Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
"""
@@ -91,7 +92,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features:
@@ -166,7 +167,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
rope_index_kwargs = {
@@ -198,15 +199,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
r"""Data collator for 4d attention mask."""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -220,13 +219,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
r"""Data collator for pairwise data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
r"""
Pads batched data to the longest sequence in the batch.
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
r"""Pad batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
@@ -249,11 +245,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
r"""Data collator for KTO data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = []
kl_features = []
kto_tags = []
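For reference, the block-diagonal mask construction that the `prepare_4d_attention_mask` docstring above describes can be sketched as follows. This is a minimal reconstruction from the docstring (the full example in the real docstring is elided by the diff), so the exact body may differ:

```python
import torch

def prepare_4d_attention_mask(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask_2d holds 0 for padding and 1, 2, 3, ... as indices of packed sequences,
    # e.g. [[1, 1, 2, 2, 0]] packs two sequences of length 2 plus one pad token
    bsz, seq_len = mask_2d.size()
    min_dtype = torch.finfo(dtype).min
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    non_padding = (expanded != 0).long()
    # tokens may attend only within their own packed sequence (block-diagonal)
    mask_4d = torch.eq(expanded, expanded.transpose(-1, -2)).long() * non_padding
    # lower-triangular: prevent attending to future positions
    mask_4d = mask_4d * torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
    # additive form: 0 where attention is allowed, a large negative value elsewhere
    return torch.where(mask_4d != 0, torch.tensor(0, dtype=dtype), torch.tensor(min_dtype, dtype=dtype))
```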

View File

@@ -14,8 +14,9 @@
import os
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from ..extras import logging
from .data_utils import Role
@@ -36,10 +37,8 @@ class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
r"""
Optionally concatenates media path to media dir when loading from local disk.
"""
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
r"""Optionally concatenate media path to media dir when loading from local disk."""
if not isinstance(medias, list):
medias = [medias] if medias is not None else []
elif len(medias) == 0:
@@ -57,16 +56,14 @@ class DatasetConverter:
return medias
@abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
r"""
Converts a single example in the dataset to the standard format.
"""
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
r"""Convert a single example in the dataset to the standard format."""
...
@dataclass
class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]:
@@ -116,7 +113,7 @@ class AlpacaDatasetConverter(DatasetConverter):
@dataclass
class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -216,10 +213,8 @@ DATASET_CONVERTERS = {
}
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
r"""
Register a new dataset converter.
"""
def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
r"""Register a new dataset converter."""
if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.")
@@ -227,9 +222,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r"""
Gets a dataset converter.
"""
r"""Get a dataset converter."""
if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.")
@@ -242,17 +235,17 @@ def align_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
_audios: [],
"""
r"""Align the dataset to a specific format.
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "..."
_images: []
_videos: []
_audios: []
"""
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
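Concretely, a converted single-turn example in the aligned format (hypothetical content) looks like:

```python
aligned_example = {
    "_prompt": [{"role": "user", "content": "What is 2 + 2?"}],
    "_response": [{"role": "assistant", "content": "2 + 2 = 4."}],
    "_system": "",
    "_tools": "",
    "_images": [],
    "_videos": [],
    "_audios": [],
}
```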

View File

@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
@@ -29,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
@unique
@@ -43,15 +44,13 @@ class Role(str, Enum):
class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
r"""Merge multiple datasets to a unified dataset."""
if len(all_datasets) == 1:
return all_datasets[0]
@@ -78,14 +77,13 @@ def merge_dataset(
def split_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set and validation set.
r"""Split the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
Support both map datasets and iterable datasets.
"""
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
@@ -120,10 +118,8 @@ def split_dataset(
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r"""
Converts dataset or dataset dict to dataset module.
"""
dataset_module: "DatasetModule" = {}
r"""Convert dataset or dataset dict to dataset module."""
dataset_module: DatasetModule = {}
if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"]

View File

@@ -16,7 +16,7 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import Optional, Union
from typing_extensions import override
@@ -31,14 +31,11 @@ class Formatter(ABC):
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
r"""Forms a list of slots according to the inputs to encode."""
...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extract a list of tuples from the response message if using tools.
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
"""
@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
if thought:
content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = []
functions: list[FunctionCall] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)
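The `extract` contract returns the raw string when no tool call is present, or a list of `FunctionCall` tuples otherwise. A minimal, hypothetical extractor for a JSON tool-call format (not the repo's actual `tool_extractor`) could look like:

```python
import json
from typing import NamedTuple, Union

class FunctionCall(NamedTuple):  # assumed shape: (name, json-encoded arguments)
    name: str
    arguments: str

def extract(content: str) -> Union[str, list[FunctionCall]]:
    try:
        tool_calls = json.loads(content)
    except json.JSONDecodeError:
        return content  # plain-text response, no tool call

    if not isinstance(tool_calls, list):
        tool_calls = [tool_calls]  # single function call

    return [FunctionCall(call["name"], json.dumps(call["arguments"])) for call in tool_calls]
```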

View File

@@ -13,7 +13,8 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk
@@ -54,9 +55,7 @@ def _load_single_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
r"""Load a single dataset and aligns it to the standard format."""
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
@@ -164,10 +163,8 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
r"""
Returns the merged datasets in the standard format.
"""
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
return None
@@ -192,9 +189,7 @@ def _get_dataset_processor(
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> "DatasetProcessor":
r"""
Returns the corresponding dataset processor.
"""
r"""Return the corresponding dataset processor."""
if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate:
@@ -236,9 +231,7 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None:
return None
@@ -284,9 +277,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
r"""Get the train dataset and optionally gets the evaluation dataset."""
# Load tokenized dataset if path exists
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):

View File

@@ -1,10 +1,11 @@
import inspect
import math
import re
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
from typing import TYPE_CHECKING, Optional, TypedDict, Union
import numpy as np
import torch
@@ -58,12 +59,12 @@ if TYPE_CHECKING:
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
) -> list[list[int]]:
r"""Get paligemma token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
@@ -87,11 +88,9 @@ class MMPluginMixin:
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> None:
r"""
Validates if this model accepts the input modalities.
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
r"""Validate if this model accepts the input modalities."""
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
if len(images) != 0 and self.image_token is None:
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
@@ -119,9 +118,7 @@ class MMPluginMixin:
def _preprocess_image(
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
) -> "ImageObject":
r"""
Pre-processes a single image.
"""
r"""Pre-process a single image."""
if (image.width * image.height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
@@ -139,10 +136,8 @@ class MMPluginMixin:
def _get_video_sample_indices(
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
) -> List[int]:
r"""
Computes video sample indices according to fps.
"""
) -> list[int]:
r"""Compute video sample indices according to fps."""
total_frames = video_stream.frames
if total_frames == 0: # infinite video
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
@@ -151,10 +146,8 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading and pre-processing.
"""
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
r"""Regularize images to avoid error. Including reading and pre-processing."""
results = []
for image in images:
if isinstance(image, str):
@@ -174,16 +167,14 @@ class MMPluginMixin:
return results
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
frames: list[ImageObject] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
@@ -194,10 +185,8 @@ class MMPluginMixin:
return results
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
r"""
Regularizes audios to avoid error. Including reading and resampling.
"""
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
r"""Regularizes audios to avoid error. Including reading and resampling."""
results = []
for audio in audios:
if isinstance(audio, str):
@@ -216,9 +205,8 @@ class MMPluginMixin:
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
@@ -229,9 +217,9 @@ class MMPluginMixin:
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
@@ -278,31 +266,27 @@ class MMPluginMixin:
class BasePlugin(MMPluginMixin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
) -> list[dict[str, str]]:
r"""Pre-processes input messages before tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
input_ids: list[int],
labels: Optional[list[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
) -> tuple[list[int], Optional[list[int]]]:
r"""Pre-processes token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return input_ids, labels
@@ -314,20 +298,21 @@ class BasePlugin(MMPluginMixin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
) -> dict[str, Union[list[int], "torch.Tensor"]]:
r"""Build batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
audios: a list of audio inputs, shape (num_audios,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
audlens: number of audios in each sample, shape (batch_size,)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(processor, images, videos, audios)
return {}
@@ -338,12 +323,12 @@ class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
@@ -370,9 +355,9 @@ class LlavaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@@ -382,12 +367,12 @@ class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@@ -426,9 +411,9 @@ class LlavaNextPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@@ -438,12 +423,12 @@ class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
@@ -502,9 +487,9 @@ class LlavaNextVideoPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@@ -514,16 +499,16 @@ class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
audio_inputs = {}
if len(images) != 0 and len(videos) != 0:
@@ -619,9 +604,9 @@ class MiniCPMVPlugin(BasePlugin):
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@@ -691,9 +676,9 @@ class MiniCPMVPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
# image bound
image_bounds_list = []
@@ -756,12 +741,12 @@ class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@@ -782,10 +767,9 @@ class MllamaPlugin(BasePlugin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
imglens: List[int],
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
imglens: list[int],
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
@@ -794,8 +778,9 @@ class MllamaPlugin(BasePlugin):
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) > 0:
images = self._regularize_images(
@@ -821,9 +806,9 @@ class MllamaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs:
@@ -850,12 +835,12 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@@ -875,14 +860,14 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
input_ids: list[int],
labels: Optional[list[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
) -> tuple[list[int], Optional[list[int]]]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
@@ -902,9 +887,9 @@ class PaliGemmaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -917,12 +902,12 @@ class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
@@ -968,9 +953,9 @@ class PixtralPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_sizes", None)
@@ -982,12 +967,12 @@ class Qwen2AudioPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
@@ -1028,9 +1013,9 @@ class Qwen2AudioPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@@ -1057,13 +1042,13 @@ class Qwen2VLPlugin(BasePlugin):
@override
def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]:
) -> tuple[list[list["ImageObject"]], list[float]]:
results, fps_per_video = [], []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
frames: list[ImageObject] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
@@ -1088,8 +1073,8 @@ class Qwen2VLPlugin(BasePlugin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@@ -1115,16 +1100,16 @@ class Qwen2VLPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
@@ -1176,13 +1161,13 @@ class Qwen2VLPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
@@ -1194,12 +1179,12 @@ class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
@@ -1255,9 +1240,9 @@ class VideoLlavaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@@ -1277,10 +1262,8 @@ PLUGINS = {
}
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
r"""
Registers a multimodal plugin.
"""
def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
r"""Register a multimodal plugin."""
if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.")
@@ -1293,9 +1276,7 @@ def get_mm_plugin(
video_token: Optional[str] = None,
audio_token: Optional[str] = None,
) -> "BasePlugin":
r"""
Gets plugin for multimodal inputs.
"""
r"""Get plugin for multimodal inputs."""
if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.")

View File

@@ -14,8 +14,9 @@
import json
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Literal, Optional
from transformers.utils import cached_file
@@ -25,9 +26,7 @@ from ..extras.misc import use_modelscope, use_openmind
@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
r"""Dataset attributes."""
# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
@@ -68,10 +67,10 @@ class DatasetAttr:
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None:
def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
@@ -92,10 +91,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r"""
Gets the attributes of the datasets.
"""
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []
@@ -116,7 +113,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None
dataset_list: List["DatasetAttr"] = []
dataset_list: list[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():

View File

@@ -9,9 +9,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]

View File

@@ -13,7 +13,8 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -30,15 +31,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
kl_response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -82,7 +83,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
@@ -121,7 +122,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@@ -13,7 +13,8 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
@@ -68,7 +69,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
@@ -99,7 +100,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))

View File

@@ -17,14 +17,14 @@
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from typing import Any
from .processor_utils import DatasetProcessor
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
return result
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@@ -14,8 +14,9 @@
import bisect
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
@@ -27,9 +28,7 @@ if TYPE_CHECKING:
@dataclass
class DatasetProcessor(ABC):
r"""
A class for data processors.
"""
r"""A class for data processors."""
template: "Template"
tokenizer: "PreTrainedTokenizer"
@@ -37,32 +36,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments"
@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r"""Build model inputs from the examples."""
...
@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Print a data example to stdout.
"""
def print_data_example(self, example: dict[str, list[int]]) -> None:
r"""Print a data example to stdout."""
...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
r"""Find the index of largest number that fits into the knapsack with the given capacity."""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
@@ -83,10 +74,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks
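Putting the two helpers together, a self-contained sketch of the packing strategy (reconstructed from the signatures and fragments above, so details may differ from the repo):

```python
import bisect

def search_for_fit(numbers: list[int], capacity: int) -> int:
    """Find the index of the largest number that fits, or -1 if none fits."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)

def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
    numbers.sort()  # ascending order is required for binary search
    knapsacks = []
    while numbers:
        current, remaining = [], capacity
        while True:
            index = search_for_fit(numbers, remaining)
            if index == -1:
                break  # nothing else fits into this knapsack
            remaining -= numbers[index]
            current.append(numbers.pop(index))
        knapsacks.append(current)
    return knapsacks

print(greedy_knapsack([5, 3, 7, 2, 6], capacity=10))  # [[7, 3], [6, 2], [5]]
```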
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r"""Compute the real sequence length after truncation by the cutoff_len."""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -32,14 +33,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
@@ -85,7 +86,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
@@ -114,7 +115,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@@ -124,7 +125,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
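The label-masking convention described in the comments above, on toy ids (hypothetical tokenizer output):

```python
IGNORE_INDEX = -100
prompt_ids = [1, 15, 16]    # <bos> X
response_ids = [27, 28, 2]  # Y <eos>

input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
# input_ids: [1, 15, 16, 27, 28, 2]
# labels:    [-100, -100, -100, 27, 28, 2]  -> loss is computed on Y <eos> only
```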

View File

@@ -13,7 +13,8 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ..data_utils import Role
@@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
if len(response) == 1:
messages = prompt + response
else:
@@ -56,7 +57,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
@@ -84,7 +85,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))

View File

@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override
@@ -46,8 +47,8 @@ class Template:
format_tools: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
thought_words: Tuple[str, str]
stop_words: list[str]
thought_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
@@ -56,13 +57,11 @@ class Template:
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
@@ -74,36 +73,28 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract tool message."""
return self.format_tools.extract(content)
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
r"""
Returns stop token ids.
"""
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Return stop token ids."""
stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
return list(stop_token_ids)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
token_ids = []
for elem in elements:
if isinstance(elem, str):
@@ -124,14 +115,14 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: query resp
Turn t: query resp.
"""
system = system or self.default_system
encoded_messages = []
@@ -161,9 +152,7 @@ class Template:
@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""
Adds or replaces eos token to the tokenizer.
"""
r"""Add or replace eos token to the tokenizer."""
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@@ -176,9 +165,7 @@ class Template:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Adds eos token and pad token to the tokenizer.
"""
r"""Add eos token and pad token to the tokenizer."""
stop_words = self.stop_words
if self.replace_eos:
if not stop_words:
@@ -204,16 +191,12 @@ class Template:
@staticmethod
def _jinja_escape(content: str) -> str:
r"""
Escape single quotes in content.
"""
r"""Escape single quotes in content."""
return content.replace("'", r"\'")
@staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r"""
Converts slots to jinja template.
"""
r"""Convert slots to jinja template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@@ -235,9 +218,7 @@ class Template:
return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
r"""Return the jinja template."""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
@@ -265,9 +246,7 @@ class Template:
return jinja_template
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Replaces the jinja template in the tokenizer.
"""
r"""Replace the jinja template in the tokenizer."""
if tokenizer.chat_template is None or self.replace_jinja_template:
try:
tokenizer.chat_template = self._get_jinja_template(tokenizer)
@@ -278,9 +257,7 @@ class Template:
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
r"""Convert slots to ollama template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@@ -302,9 +279,7 @@ class Template:
return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
r"""Return the ollama template."""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
@@ -316,8 +291,7 @@ class Template:
)
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
r"""Return the ollama modelfile.
TODO: support function calling.
"""
@@ -340,10 +314,10 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: str,
tools: str,
) -> List[List[int]]:
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
@@ -402,7 +376,7 @@ class Llama2Template(Template):
return jinja_template
TEMPLATES: Dict[str, "Template"] = {}
TEMPLATES: dict[str, "Template"] = {}
def register_template(
@@ -416,15 +390,14 @@ def register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None,
thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template,
template_class: type["Template"] = Template,
) -> None:
r"""
Registers a chat template.
r"""Register a chat template.
To add the following chat template:
```
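# A hypothetical registration call, sketched only from the keyword
# parameters visible in this hunk (the template name and slot strings are
# invented; the docstring's own worked example is elided by the diff):
register_template(
    name="custom",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|eos|>"],
    replace_eos=True,
)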
@@ -472,9 +445,7 @@ def register_template(
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r"""
Extracts a chat template from the tokenizer.
"""
r"""Extract a chat template from the tokenizer."""
def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0
@@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
r"""Get chat template and fixes the tokenizer."""
if data_args.template is None:
if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
@@ -1149,7 +1118,8 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"你是一个经过良好训练的AI助手你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
),
View File
@@ -17,7 +17,7 @@ import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Union
from typing_extensions import override
@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
@dataclass
class ToolUtils(ABC):
"""
Base class for tool utilities.
"""
"""Base class for tool utilities."""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
def tool_formatter(tools: list[dict[str, Any]]) -> str:
r"""Generate the system message describing all the available tools."""
...
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
def function_formatter(functions: list["FunctionCall"]) -> str:
r"""Generate the assistant message including all the tool calls."""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract all the function calls from the assistant message.
It should be an inverse function of `function_formatter`.
"""
@@ -92,13 +85,11 @@ class ToolUtils(ABC):
class DefaultToolUtils(ToolUtils):
r"""
Default tool using template.
"""
r"""Default tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
action_match: list[tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content
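As a sanity check on the pair above: `function_formatter` writes `Action:`/`Action Input:` lines and this regex recovers them, giving the inverse relationship the base class requires. A runnable illustration of the regex half, with an invented message (the real extractor additionally re-serializes the JSON arguments):

```python
import re

regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
content = 'Action: get_weather\nAction Input: {"city": "Paris"}\nAction: get_time\nAction Input: {"tz": "UTC"}'
print(re.findall(regex, content))
# [('get_weather', '{"city": "Paris"}'), ('get_time', '{"tz": "UTC"}')]
```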
@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
r"""
GLM-4 tool using template.
"""
r"""GLM-4 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
if "\n" not in content:
return content
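The newline check above reflects the convention this class assumes for GLM-4: function name on the first line, JSON arguments on the rest. An invented example of a message that splits cleanly:

```python
content = 'get_weather\n{"city": "Paris"}'
name, arguments = content.split("\n", maxsplit=1)
print(name, arguments)  # get_weather {"city": "Paris"}
```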
@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils):
r"""
Llama 3.x tool using template with `tools_in_user_message=False`.
r"""Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
except json.JSONDecodeError:
@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils):
r"""
Mistral v0.3 tool using template.
"""
r"""Mistral v0.3 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tools = json.loads(content.strip())
except json.JSONDecodeError:
@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
class QwenToolUtils(ToolUtils):
r"""
Qwen 2.5 tool using template.
"""
r"""Qwen 2.5 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content)
tool_match: list[str] = re.findall(regex, content)
if not tool_match:
return content
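And the matching extraction on the Qwen side, with an invented message; the real extractor goes on to JSON-parse each captured block:

```python
import re

regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
content = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
print(re.findall(regex, content))
# ['{"name": "get_weather", "arguments": {"city": "Paris"}}']
```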