[da'ta] fix minicpmv plugin (#6890)
* fix template name * tiny fix * support minicpm-o-2.6 * support inference of minicpmv * update readme * support dpo of minicpmv * update init audio * update init audio * [model]fix image process in minicpmo * fix no mm inputs Former-commit-id: cdd19ccd8cec460606b4545e886e932c1c5c5fe1
This commit is contained in:
@@ -106,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_audlens.append(len(audios))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
|
||||
fake_input_ids = None
|
||||
fake_input_ids = []
|
||||
if (
|
||||
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
@@ -115,10 +115,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
fake_messages = self.template.mm_plugin.process_messages(
|
||||
fake_messages, fake_images, [], [], self.processor
|
||||
)
|
||||
fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
|
||||
fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
|
||||
_fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
_fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
|
||||
_fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
|
||||
)
|
||||
fake_input_ids.extend(_fake_input_ids)
|
||||
batch_images = fake_images
|
||||
batch_imglens[0] = 1
|
||||
|
||||
@@ -130,14 +131,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
fake_messages = self.template.mm_plugin.process_messages(
|
||||
fake_messages, [], [], fake_audios, self.processor
|
||||
)
|
||||
fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
|
||||
fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
|
||||
_fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
_fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
|
||||
_fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
|
||||
)
|
||||
fake_input_ids.extend(_fake_input_ids)
|
||||
batch_audios = fake_audios
|
||||
batch_audlens[0] = 1
|
||||
|
||||
if fake_input_ids is not None:
|
||||
if len(fake_input_ids) != 0:
|
||||
if self.tokenizer.padding_side == "right":
|
||||
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
|
||||
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
|
||||
|
||||
@@ -645,6 +645,12 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
chunk_input=True,
|
||||
sampling_rate=16000,
|
||||
)
|
||||
audio_feature_lens = [
|
||||
torch.tensor(audio_feature_len)
|
||||
if not isinstance(audio_feature_len, torch.Tensor)
|
||||
else audio_feature_len
|
||||
for audio_feature_len in audio_feature_lens
|
||||
]
|
||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||
if kwargs.get("ret_phs", False):
|
||||
mm_inputs.update({"audio_phs": audio_phs})
|
||||
|
||||
@@ -982,6 +982,17 @@ _register_template(
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
_register_template(
|
||||
name="minicpm_o",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user