@@ -105,7 +105,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
-            fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
+            )
             if self.tokenizer.padding_side == "right":
                 features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                 features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
@@ -116,6 +119,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

             batch_images = fake_images
             batch_imglens[0] = 1
+            batch_input_ids[0] = features[0]["input_ids"]

         mm_inputs = self.template.mm_plugin.get_mm_inputs(
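For context, these hunks change how the collator builds the dummy multimodal example used when a batch contains no real images: the placeholder prompt is now encoded with the collator's own tokenizer, passed through the template's mm_plugin.process_token_ids hook, attached to the first feature with its tokens masked out, and the updated input_ids are propagated into batch_input_ids before get_mm_inputs runs. The sketch below illustrates only the attach-and-mask step in isolation; the helper name pad_with_fake_tokens, the IGNORE_INDEX value of -100, and the toy feature dict are assumptions for illustration, not LLaMA-Factory's API.

# Minimal, self-contained sketch (not the project's code): attaching placeholder
# token ids to one example while excluding them from attention and loss.
from typing import Dict, List

IGNORE_INDEX = -100  # assumed label-ignore value, following common Transformers practice


def pad_with_fake_tokens(
    feature: Dict[str, List[int]], fake_input_ids: List[int], padding_side: str = "right"
) -> Dict[str, List[int]]:
    """Append (or prepend, for left padding) fake token ids and mask them out."""
    zeros = [0] * len(fake_input_ids)               # attention_mask: never attended to
    ignores = [IGNORE_INDEX] * len(fake_input_ids)  # labels: never contribute to loss
    if padding_side == "right":
        feature["input_ids"] = feature["input_ids"] + fake_input_ids
        feature["attention_mask"] = feature["attention_mask"] + zeros
        feature["labels"] = feature["labels"] + ignores
    else:
        feature["input_ids"] = fake_input_ids + feature["input_ids"]
        feature["attention_mask"] = zeros + feature["attention_mask"]
        feature["labels"] = ignores + feature["labels"]
    return feature


# Toy usage: a two-token text-only example gains three masked placeholder tokens.
example = {"input_ids": [11, 12], "attention_mask": [1, 1], "labels": [11, 12]}
print(pad_with_fake_tokens(example, fake_input_ids=[7, 8, 9]))
# -> {'input_ids': [11, 12, 7, 8, 9], 'attention_mask': [1, 1, 0, 0, 0],
#     'labels': [11, 12, -100, -100, -100]}

Because the fake tokens carry attention_mask = 0 and labels = IGNORE_INDEX, they keep the image-processing path alive for text-only batches without affecting the model's predictions or the training loss.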