fix mixed mm inputs and rlhf-v

Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
hiyouga
2024-09-01 20:52:47 +08:00
parent 1d8e9c7897
commit 7e4c5d4bb3
20 changed files with 306 additions and 277 deletions

View File

@@ -13,8 +13,8 @@
# limitations under the License.
from .collator import (
CustomDataCollatorForSeq2Seq,
KTODataCollatorWithPadding,
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
)
@@ -24,8 +24,8 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"CustomDataCollatorForSeq2Seq",
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",

View File

@@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator for custom models (like Qwen2-VL).
Data collator that supports VLMs.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None # TODO: better handle various VLMs
if "image_grid_thw" in features[0]:
image_grid_thw_list = [
torch.Tensor(feature["image_grid_thw"]).long()
for feature in features
if feature["image_grid_thw"][0][0] > 0
]
pixel_values_list = [
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
]
if image_grid_thw_list:
image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
image_grid_thw = None
pixel_values = None
if "token_type_ids" in features[0].keys():
for feature in features:
feature["token_type_ids"] = feature["token_type_ids"][0]
features = [
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
for feature in features
]
extra_features = {}
if "pixel_values" in features[0].keys():
pixel_values = []
for feature in features:
if feature["pixel_values"] is None:
pixel_values.append(torch.zeros(0, dtype=torch.float))
else:
pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
features = super().__call__(features)
if image_grid_thw is not None:
features["image_grid_thw"] = image_grid_thw
features["pixel_values"] = pixel_values
extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
if extra_features["pixel_values"].numel() == 0:
extra_features["pixel_values"] = None
if "image_grid_thw" in features[0].keys():
image_grid_thw = []
for feature in features:
if feature["image_grid_thw"] is None:
image_grid_thw.append(torch.zeros(0, dtype=torch.long))
else:
image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
extra_features["image_grid_thw"] = torch.cat(pixel_values, dim=0)
if extra_features["image_grid_thw"].numel() == 0:
extra_features["image_grid_thw"] = None
features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update({key: value for key, value in extra_features.items() if value is not None})
return features
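Note: a minimal self-contained sketch (not the collator's actual code) of how the mixed-batch handling above behaves. Samples whose visual features were set to None by the fake-image path contribute zero-length tensors, so torch.cat simply skips them, and a batch with no real images drops the key entirely via the numel() check.
import torch

# Hedged sketch with toy shapes: one sample with a real image, one text-only sample.
features = [
    {"pixel_values": torch.ones(4, 8)},   # 4 patches of a real image
    {"pixel_values": None},               # text-only sample (fake image turned into None)
]
pixel_values = []
for feature in features:
    if feature["pixel_values"] is None:
        pixel_values.append(torch.zeros(0, dtype=torch.float))  # 1-D empty tensor, skipped by cat
    else:
        pixel_values.append(feature["pixel_values"].float())
merged = torch.cat(pixel_values, dim=0)
print(merged.shape)  # torch.Size([4, 8]); with no real images, numel() == 0 and the key is dropped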
@dataclass
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
@@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
@@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""

View File

@@ -16,16 +16,16 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
from .tool_utils import get_tool_utils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default", "glm4"]] = None
tool_format: Optional[str] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@@ -81,12 +81,7 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self.slots = DefaultToolUtils.get_function_slots() + self.slots
elif self.tool_format == "glm4":
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
@@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self._tool_formatter = DefaultToolUtils.tool_formatter
self._tool_extractor = DefaultToolUtils.tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = GLM4ToolUtils.tool_formatter
self._tool_extractor = GLM4ToolUtils.tool_extractor
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.tool_utils = get_tool_utils(self.tool_format)
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
return self._tool_extractor(content)
return self.tool_utils.tool_extractor(content)

View File

@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from PIL.Image import Image
@@ -27,32 +28,33 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
image_inputs = image_processor(images=images, return_tensors="pt")
else:
image = Image.new("RGB", (56, 56), (255, 255, 255))
else: # add NoneType for fake images
image = Image.new("RGB", (64, 64), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
if "image_grid_thw" in image_inputs: # fake image for qwen2-vl
image_inputs["image_grid_thw"][0][0] = 0
image_inputs = {key: None for key in image_inputs.keys()}
return image_inputs
def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
def _get_paligemma_token_type_ids(
images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
token_type_ids: shape (1, seq_len)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_seq_length: int = getattr(image_processor, "image_seq_length")
return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
return [[0] * image_seqlen + [1] * (input_len - image_seqlen)]
class BasePlugin:
@@ -74,6 +76,7 @@ class BasePlugin:
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageObject"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
@@ -93,18 +96,6 @@ class BasePlugin:
"""
return {}
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
r"""
Appends multimodal inputs to model inputs for VLMs.
"""
return
class LlavaPlugin(BasePlugin):
def process_messages(
@@ -113,21 +104,21 @@ class LlavaPlugin(BasePlugin):
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_count = 0
new_messages = []
num_images = 0
image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_count += 1
if image_count > 1:
raise ValueError("Llava model only accepts one image per sample.")
num_images += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace("{{image}}", self.image_token)
new_messages.append({"role": message["role"], "content": content})
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
return new_messages
if len(images) != num_images:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
def get_mm_inputs(
self,
@@ -137,17 +128,6 @@ class LlavaPlugin(BasePlugin):
) -> Dict[str, Any]:
return _get_mm_inputs(images, processor)
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
for key, value in mm_inputs.items():
model_inputs[key].append(value[0])
class PaliGemmaPlugin(BasePlugin):
def process_messages(
@@ -156,34 +136,35 @@ class PaliGemmaPlugin(BasePlugin):
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_count = 0
new_messages = []
num_images = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_count += 1
if image_count > 1:
raise ValueError("PaliGemma model only accepts one image per sample.")
num_images += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace(IMAGE_PLACEHOLDER, "", 1)
message["content"] = content.replace("{{image}}", "")
new_messages.append({"role": message["role"], "content": content})
if len(images) != num_images:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return new_messages
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageObject"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_seq_length: int = getattr(image_processor, "image_seq_length")
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seq_length + input_ids
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seq_length + labels
labels = [IGNORE_INDEX] * image_seqlen + labels
return input_ids, labels
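A small worked example of the arithmetic above, with hypothetical numbers (two images and an image_seqlen of 3 on the processor): the image tokens are prepended to input_ids and masked out of the labels, and the same count drives the token_type_ids produced by _get_paligemma_token_type_ids.
IGNORE_INDEX = -100                                   # assumption: the usual ignore value

num_images, seqlen_per_image = 2, 3                   # hypothetical values
image_seqlen = num_images * seqlen_per_image          # 6 image positions
image_token_id = 999                                  # toy id for the image token
input_ids, labels = [101, 102, 103], [101, 102, 103]  # toy text token ids

input_ids = [image_token_id] * image_seqlen + input_ids
labels = [IGNORE_INDEX] * image_seqlen + labels
token_type_ids = [[0] * image_seqlen + [1] * (len(input_ids) - image_seqlen)]
# token_type_ids[0] == [0, 0, 0, 0, 0, 0, 1, 1, 1]: zeros cover the image positions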
@@ -195,21 +176,10 @@ class PaliGemmaPlugin(BasePlugin):
) -> Dict[str, Any]:
mm_inputs = _get_mm_inputs(images, processor)
for feature_name, feature_length in feature_seqlens.items():
mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor)
mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor)
return mm_inputs
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
for key, value in mm_inputs.items():
model_inputs[key].append(value[0])
class Qwen2vlPlugin(BasePlugin):
def process_messages(
@@ -223,23 +193,26 @@ class Qwen2vlPlugin(BasePlugin):
if len(images) > 0:
image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
index = 0
new_messages = []
num_images = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[index].prod() // merge_length)
self.image_token * (image_grid_thw[num_images].prod() // merge_length)
),
1,
)
index += 1
num_images += 1
new_messages.append({"role": message["role"], "content": content})
message["content"] = content
return new_messages
if len(images) != num_images:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
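The placeholder expansion above scales with each image's patch grid. With hypothetical numbers (one frame, a 16x16 patch grid, and a 2x2 spatial merge, so merge_length = 4; the real values come from the processor and the image token may differ), one IMAGE_PLACEHOLDER expands to 64 image tokens between the vision markers:
import torch

# Hedged arithmetic sketch with made-up sizes; only the formula mirrors the code above.
image_grid_thw = torch.tensor([[1, 16, 16]])                  # (time, height, width) in patches
merge_length = 2 ** 2                                         # 2x2 patches merged per image token
num_tokens = int(image_grid_thw[0].prod()) // merge_length    # 1 * 16 * 16 // 4 = 64
content = "<|vision_start|>" + "<|image_pad|>" * num_tokens + "<|vision_end|>"
print(num_tokens)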
def get_mm_inputs(
self,
@@ -249,17 +222,6 @@ class Qwen2vlPlugin(BasePlugin):
) -> Dict[str, Any]:
return _get_mm_inputs(images, processor)
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
for key, value in mm_inputs.items():
model_inputs[key].append(value) # support multi-image
PLUGINS = {
"base": BasePlugin,
@@ -270,7 +232,8 @@ PLUGINS = {
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
if name not in PLUGINS:
raise ValueError("{} not found.".format(name))
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name))
return PLUGINS[name](image_token)
return plugin_class(image_token)

View File

@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
else:

View File

@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@@ -36,11 +37,12 @@ def _encode_feedback_example(
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -53,6 +55,8 @@ def _encode_feedback_example(
else:
kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
@@ -60,8 +64,8 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
@@ -74,8 +78,15 @@ def _encode_feedback_example(
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images,
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs
def preprocess_feedback_dataset(
@@ -93,13 +104,13 @@ def preprocess_feedback_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=prompt,
input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -112,15 +123,8 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num

View File

@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@@ -35,13 +36,14 @@ def _encode_pairwise_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
@@ -49,7 +51,7 @@ def _encode_pairwise_example(
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
# consider the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len]
@@ -60,8 +62,15 @@ def _encode_pairwise_example(
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images,
feature_seqlens={
"chosen_token_type_ids": len(chosen_input_ids),
"rejected_token_type_ids": len(rejected_input_ids),
},
processor=processor,
)
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs
def preprocess_pairwise_dataset(
@@ -78,12 +87,12 @@ def preprocess_pairwise_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=prompt,
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -95,15 +104,8 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={
"chosen_token_type_ids": len(chosen_input_ids),
"rejected_token_type_ids": len(rejected_input_ids),
},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
return model_inputs

View File

@@ -21,6 +21,7 @@ from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@@ -35,19 +36,18 @@ def _encode_supervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
messages = prompt + response
input_ids, labels = [], []
input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
) -> Tuple[List[int], List[int], Dict[str, Any]]:
messages = template.mm_plugin.process_messages(prompt + response, images, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = 1 if template.efficient_eos else 0
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
@@ -83,7 +83,10 @@ def _encode_supervised_example(
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
def preprocess_supervised_dataset(
@@ -101,12 +104,12 @@ def preprocess_supervised_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_supervised_example(
prompt=prompt,
input_ids, labels, extra_inputs = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -117,12 +120,8 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={"token_type_ids": len(input_ids)},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
return model_inputs
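A toy sketch (not the repository's code) of the per-example flow introduced here: the encoder now returns the multimodal features alongside the token ids, and the caller appends each one column-wise, so every key in model_inputs ends up with one entry per kept example.
from collections import defaultdict

# Stand-ins for what _encode_supervised_example returns per example; the values are toys.
encoded_examples = [
    ([1, 2, 3], [-100, 2, 3], {"pixel_values": "image_features_0"}),
    ([4, 5], [-100, 5], {"pixel_values": None}),              # text-only example
]

model_inputs = defaultdict(list)
for input_ids, labels, extra_inputs in encoded_examples:
    model_inputs["input_ids"].append(input_ids)
    model_inputs["attention_mask"].append([1] * len(input_ids))
    model_inputs["labels"].append(labels)
    for key, value in extra_inputs.items():
        model_inputs[key].append(value)

print(len(model_inputs["pixel_values"]))  # 2: one entry per example, None for text-only rows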
@@ -131,10 +130,15 @@ def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
if processor is not None:
raise NotImplementedError("`packing` has not been implemented for multimodal datasets.")
valid_num = 0
batch_input_ids, batch_labels = [], []
lengths = []
@@ -149,6 +153,7 @@ def preprocess_packed_supervised_dataset(
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=None,

View File

@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@@ -35,25 +36,30 @@ def _encode_unsupervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int]]:
) -> Tuple[List[int], List[int], Dict[str, Any]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = template.mm_plugin.process_messages(messages, images, processor)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
def preprocess_unsupervised_dataset(
@@ -70,12 +76,12 @@ def preprocess_unsupervised_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_unsupervised_example(
prompt=prompt,
input_ids, labels, extra_inputs = _encode_unsupervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -84,12 +90,8 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={"token_type_ids": len(input_ids)},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
return model_inputs

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from .data_utils import Role
@@ -347,6 +349,11 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
tool_format: Optional[str] = None,
) -> Template:
if name == "qwen2_vl":
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
if name is None:
template = TEMPLATES["empty"] # placeholder
else:
@@ -357,8 +364,8 @@ def get_template_and_fix_tokenizer(
if tool_format is not None:
logger.info("Using tool format: {}.".format(tool_format))
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_tools = ToolFormatter(tool_format=tool_format)
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
template.format_tools = ToolFormatter(tool_format=tool_format)
stop_words = template.stop_words
if template.replace_eos:

View File

@@ -138,3 +138,17 @@ class GLM4ToolUtils(ToolUtils):
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
}
def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None)
if tool_utils is None:
raise ValueError("Tool utils `{}` not found.".format(name))
return tool_utils
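Illustrative use of the new registry (a sketch only; the import path below is an assumption and not part of this diff, while the names come from the TOOLS dict above):
from llamafactory.data.tool_utils import get_tool_utils  # assumed module path

utils = get_tool_utils("glm4")        # shared GLM4ToolUtils instance from the TOOLS registry
slots = utils.get_function_slots()    # what FunctionFormatter.__post_init__ prepends to its slots

try:
    get_tool_utils("does-not-exist")
except ValueError as err:
    print(err)                        # Tool utils `does-not-exist` not found.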