refactor mm training

Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
hiyouga
2024-08-30 02:14:31 +08:00
parent 77c2c7076b
commit c62a6ca59d
29 changed files with 499 additions and 312 deletions

View File

@@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
from .collator import (
CustomDataCollatorForSeq2Seq,
KTODataCollatorWithPadding,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"CustomDataCollatorForSeq2Seq",
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",

View File

@@ -62,15 +62,11 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
Data collator for multimodal models such as Qwen2-VL; merges visual features into the padded batch.
"""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None
if "image_grid_thw" in features[0]:
@@ -83,23 +79,18 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
]
if image_grid_thw_list:
image_grid_thw = torch.cat(image_grid_thw_list, 0)
image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
# Handle the case where the list is empty (no real images in the batch).
image_grid_thw = None
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, 0)
else:
# Handle the case where the list is empty (no real images in the batch).
pixel_values = None
features = [
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
for feature in features
]
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if image_grid_thw is not None:
features["image_grid_thw"] = image_grid_thw
features["pixel_values"] = pixel_values
@@ -108,7 +99,25 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
return features
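
The block_diag_attn path expands the packed 2-D attention mask (where the value k > 0 marks the k-th packed sequence and 0 marks padding) into a 4-D block-diagonal causal mask. A minimal, self-contained sketch of that semantics (illustrative, not the repo's exact prepare_4d_attention_mask, whose body sits above this hunk):

import torch

def block_diag_causal_mask(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # value k > 0 marks the k-th packed sequence, 0 marks padding
    bsz, seq_len = mask_with_indices.size()
    expanded = mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    same_seq = expanded == expanded.transpose(-1, -2)   # query and key from the same packed sequence
    nonpad = expanded != 0                              # key position is not padding
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = same_seq & nonpad & causal
    mask = torch.full((bsz, 1, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    mask[allowed] = 0.0                                 # 0 = attend, dtype-min = blocked
    return mask

# two packed sequences and a padding slot in one row
print(block_diag_causal_mask(torch.tensor([[1, 1, 2, 2, 0]]), torch.float32)[0, 0])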
@dataclass
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
@@ -128,9 +137,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
}
if "pixel_values" in feature:
if "pixel_values" in feature: # image data are same for chosen and rejected
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
@@ -140,7 +152,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
@@ -163,6 +175,9 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]

View File

@@ -0,0 +1,271 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
import torch
from PIL.Image import Image as ImageObject
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor":
r"""
Processes visual inputs (currently only a single image is supported).
Returns:
pixel_values: tensor with shape (B, C, H, W)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor([image], return_tensors="pt")["pixel_values"]
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
token_type_ids: shape (1, seq_len)
"""
image_seq_length = getattr(processor, "image_seq_length")
return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
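
Worked example with assumed numbers (image_seq_length = 4, input_len = 7): the image prefix is marked 0 so it can be masked out of the loss, the text suffix 1:

image_seq_length, input_len = 4, 7
token_type_ids = [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
assert token_type_ids == [[0, 0, 0, 0, 1, 1, 1]]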
def get_qwen2vl_image_inputs(
images: Sequence["ImageObject"], processor: "ProcessorMixin"
) -> Dict[str, "torch.Tensor"]:
r"""
Processes qwen2-vl visual inputs. Supports multiple images.
Returns:
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, height, width
It holds that num_patches == image_grid_thw.prod(dim=-1).sum()
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
image_inputs = image_processor(images=images, return_tensors="pt")
else:
image = Image.new("RGB", (56, 56), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
image_inputs["image_grid_thw"][0][0] = 0 # fake image
return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
class BasePlugin:
def __init__(self, image_token: str) -> None:
self.image_token = image_token
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
return {}
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
return
class LlavaPlugin(BasePlugin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_count = 0
new_messages = []
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_count += 1
if image_count > 1:
raise ValueError("Llava model only accepts one image per sample.")
content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)
new_messages.append({"role": message["role"], "content": content})
return new_messages
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
return {"pixel_values": get_pixel_values(images, processor)}
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
class PaliGemmaPlugin(BasePlugin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_count = 0
new_messages = []
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_count += 1
if image_count > 1:
raise ValueError("PaliGemma model only accepts one image per sample.")
content = content.replace(IMAGE_PLACEHOLDER, "", 1)
new_messages.append({"role": message["role"], "content": content})
return new_messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_seq_length: int = getattr(image_processor, "image_seq_length")
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seq_length + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seq_length + labels
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
mm_inputs = {"pixel_values": get_pixel_values(images, processor)}
for feature_name, feature_length in feature_seqlens.items():
mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor)
return mm_inputs
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
for feature_name in feature_seqlens.keys():
model_inputs[feature_name].append(mm_inputs[feature_name][0])
class Qwen2vlPlugin(BasePlugin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if len(images) > 0:
image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
index = 0
new_messages = []
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[index].prod() // merge_length)
),
1,
)
index += 1
new_messages.append({"role": message["role"], "content": content})
return new_messages
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
return get_qwen2vl_image_inputs(images, processor)
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
model_inputs["pixel_values"].append(mm_inputs["pixel_values"])
model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"])
PLUGINS = {
"llava": LlavaPlugin,
"paligemma": PaliGemmaPlugin,
"qwen2_vl": Qwen2vlPlugin,
}
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
if name not in PLUGINS:
raise ValueError("{} not found.".format(name))
return PLUGINS[name](image_token)
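
A hedged end-to-end sketch of how the preprocessors drive a plugin (it mirrors the call sites later in this commit; variables come from the surrounding preprocessing loop):

plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
messages = plugin.process_messages(messages, images, processor)        # expand <image> placeholders
input_ids, labels = plugin.process_token_ids(input_ids, labels, tokenizer, processor)
plugin.process_model_inputs(                                           # append pixel_values etc.
    model_inputs=model_inputs,
    images=images,
    feature_seqlens={"token_type_ids": len(input_ids)},
    processor=processor,
)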

View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
@@ -40,9 +41,6 @@ def _encode_feedback_example(
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -62,10 +60,8 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
@@ -91,28 +87,15 @@ def preprocess_feedback_dataset(
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i],
prompt=prompt,
response=examples["response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
@@ -129,11 +112,15 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
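
The KL term needs mismatched prompt-response pairs; reversing the response list pairs prompt i with response n-1-i, which is unrelated to it as long as the batch is shuffled. A tiny illustration with made-up data:

responses = ["r0", "r1", "r2", "r3"]
kl_responses = responses[::-1]    # ["r3", "r2", "r1", "r0"]; prompt 0 now sees r3, etc.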

View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
@@ -39,9 +40,6 @@ def _encode_pairwise_example(
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
@@ -51,10 +49,7 @@ def _encode_pairwise_example(
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
# assume that the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len]
@@ -77,27 +72,15 @@ def preprocess_pairwise_dataset(
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"] = []
model_inputs["rejected_token_type_ids"] = []
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["prompt"][i],
prompt=prompt,
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
@@ -112,15 +95,15 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"].append(
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
)
model_inputs["rejected_token_type_ids"].append(
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={
"chosen_token_type_ids": len(chosen_input_ids),
"rejected_token_type_ids": len(rejected_input_ids),
},
processor=processor,
)
return model_inputs

View File

@@ -13,20 +13,7 @@
# limitations under the License.
import bisect
from typing import TYPE_CHECKING, List, Sequence, Tuple
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from typing import List, Sequence, Tuple
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
@@ -61,37 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. (currently only supports a single image)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
r"""
Gets paligemma token type ids for computing loss.
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. Supports multiple images.
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
image_inputs = image_processor(images=images, return_tensors="pt")
else:
image = Image.new("RGB", (56, 56), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
image_inputs["image_grid_thw"][0][0] = 0
return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.

View File

@@ -17,17 +17,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import (
get_paligemma_token_type_ids,
get_pixel_values,
get_qwen2vl_image_inputs,
greedy_knapsack,
infer_seqlen,
)
from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image as ImageObject
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@@ -43,41 +36,15 @@ def _encode_supervised_example(
system: Optional[str],
tools: Optional[str],
template: "Template",
images: Sequence["ImageObject"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
if processor is not None and "image_grid_thw" in processor.model_input_names: # qwen2_vl models
image_processor = getattr(processor, "image_processor")
merge_length = image_processor.merge_size**2
if len(images) > 0:
image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
index = 0
for message in prompt:
content = message["content"]
while "<|image_pad|>" in content:
content = content.replace(
"<|image_pad|>",
template.vision_start_token
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
+ template.vision_end_token,
1,
)
index += 1
message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
elif processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
messages = prompt + response
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = 1 if template.efficient_eos else 0
@@ -125,28 +92,21 @@ def preprocess_supervised_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
model_inputs["image_grid_thw"] = []
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
prompt=prompt,
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -157,15 +117,12 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor)
model_inputs["pixel_values"].append(image_inputs["pixel_values"])
model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"])
else:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={"token_type_ids": len(input_ids)},
processor=processor,
)
return model_inputs
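
Swapping the hand-built dict for defaultdict(list) removes the per-model key bookkeeping: pixel_values, token_type_ids, or image_grid_thw now exist only if a plugin actually appends them. One caveat worth knowing when consuming the result:

from collections import defaultdict

model_inputs = defaultdict(list)
model_inputs["input_ids"].append([1, 2, 3])   # keys materialize on first write
_ = model_inputs["image_grid_thw"]            # note: a mere read also creates an empty key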
@@ -175,7 +132,7 @@ def preprocess_packed_supervised_dataset(
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
@@ -209,7 +166,7 @@ def preprocess_packed_supervised_dataset(
batch_labels.append(labels)
valid_num += 1
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []

View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ..data_utils import Role
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
@@ -39,9 +40,6 @@ def _encode_unsupervised_example(
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if len(response) == 1:
messages = prompt + response
else:
@@ -51,10 +49,7 @@ def _encode_unsupervised_example(
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
@@ -69,19 +64,15 @@ def preprocess_unsupervised_dataset(
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_unsupervised_example(
prompt=examples["prompt"][i],
prompt=prompt,
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
@@ -93,10 +84,12 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={"token_type_ids": len(input_ids)},
processor=processor,
)
return model_inputs

View File

@@ -15,9 +15,11 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import BasePlugin, get_mm_plugin
if TYPE_CHECKING:
@@ -41,11 +43,9 @@ class Template:
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
vision_start_token: str
vision_end_token: str
efficient_eos: bool
replace_eos: bool
mm_plugin: "BasePlugin"
def encode_oneturn(
self,
@@ -207,11 +207,9 @@ def _register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Sequence[str] = [],
image_token: str = "<image>",
vision_start_token: str = "<|vision_start|>",
vision_end_token: str = "<|vision_end|>",
efficient_eos: bool = False,
replace_eos: bool = False,
mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
) -> None:
r"""
Registers a chat template.
@@ -258,11 +256,9 @@ def _register_template(
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
vision_start_token=vision_start_token,
vision_end_token=vision_end_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
mm_plugin=mm_plugin,
)
@@ -722,6 +718,17 @@ _register_template(
)
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -766,6 +773,19 @@ _register_template(
)
_register_template(
name="paligemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -790,17 +810,15 @@ _register_template(
_register_template(
name="qwen2vl",
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
image_token="<|image_pad|>",
vision_start_token="<|vision_start|>",
vision_end_token="<|vision_end|>",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
)
@@ -915,6 +933,7 @@ _register_template(
),
stop_words=["###"],
efficient_eos=True,
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
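
With the plugin in place, the per-template image_token / vision_start_token / vision_end_token fields disappear; image handling is reached through the template instead. A short sketch of the new access path:

template = TEMPLATES["qwen2_vl"]
messages = template.mm_plugin.process_messages(messages, images, processor)   # was: manual <|image_pad|> expansion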