Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)

* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

Former-commit-id: 7f3c64e853a7cdd49d02bf85e237611941ac7fa8
This commit is contained in:
Zhangchi Feng
2025-01-14 17:34:58 +08:00
committed by GitHub
parent d0da6f40b0
commit f7857c83e1
4 changed files with 25 additions and 1 deletions

View File

@@ -168,6 +168,8 @@ class HuggingfaceEngine(BaseEngine):
for key, value in mm_inputs.items():
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
value = torch.stack([torch.stack(per_value) for per_value in value])
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
@@ -176,6 +178,11 @@ class HuggingfaceEngine(BaseEngine):
gen_kwargs[key] = value.to(model.device)
if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
gen_kwargs["input_ids"] = inputs
del gen_kwargs["image_sizes"]
gen_kwargs["tokenizer"] = tokenizer
return gen_kwargs, prompt_length
@staticmethod
@@ -207,6 +214,9 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
if isinstance(generate_output, tuple):
generate_output = generate_output[1][0] # for minicpm_o
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(
response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True