Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)
* fix template name
* tiny fix
* support minicpm-o-2.6
* support inference of minicpmv

Former-commit-id: 7f3c64e853a7cdd49d02bf85e237611941ac7fa8
@@ -168,6 +168,8 @@ class HuggingfaceEngine(BaseEngine):
         for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
+            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+                value = torch.stack([torch.stack(per_value) for per_value in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
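The two added lines cover MiniCPM-V, whose processor returns pixel values as a nested list (one inner list of tensors per sample) rather than the flat list Pixtral produces, so the values must be stacked twice. A minimal sketch of both stacking rules, with invented shapes:

    import torch

    # Pixtral-style: a flat list of same-sized tensors, stacked once.
    flat = [torch.zeros(3, 224, 224) for _ in range(2)]
    print(torch.stack(flat).shape)  # torch.Size([2, 3, 224, 224])

    # MiniCPM-V-style: one inner list of tensors per sample, stacked twice.
    nested = [[torch.zeros(3, 224, 224)] for _ in range(2)]
    print(torch.stack([torch.stack(per_value) for per_value in nested]).shape)
    # torch.Size([2, 1, 3, 224, 224])

As in the existing Pixtral branch, this assumes every tensor in a group has the same size; torch.stack raises a RuntimeError otherwise.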
@@ -176,6 +178,11 @@ class HuggingfaceEngine(BaseEngine):
             gen_kwargs[key] = value.to(model.device)
 
+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            del gen_kwargs["image_sizes"]
+            gen_kwargs["tokenizer"] = tokenizer
+
         return gen_kwargs, prompt_length
 
     @staticmethod
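This new block reroutes the generation arguments for the two MiniCPM model types, which take the prompt ids directly from gen_kwargs, do not accept the image_sizes key, and expect the tokenizer to be passed in. A self-contained illustration, where every value is a stand-in for state built earlier in _process_args:

    import torch

    inputs = torch.tensor([[1, 2, 3]])    # fake prompt token ids
    tokenizer = None                      # stands in for the loaded HF tokenizer
    model_type = "minicpmv"               # would be getattr(model.config, "model_type", None)
    gen_kwargs = {"image_sizes": [(448, 448)], "max_new_tokens": 8}

    if model_type in ["minicpmv", "minicpmo"]:
        gen_kwargs["input_ids"] = inputs        # MiniCPM reads the prompt from kwargs
        del gen_kwargs["image_sizes"]           # its generate() does not accept this key
        gen_kwargs["tokenizer"] = tokenizer     # and it wants the tokenizer itself

    print(sorted(gen_kwargs))  # ['input_ids', 'max_new_tokens', 'tokenizer']

Note that del raises a KeyError when image_sizes is absent, so this branch assumes the multimodal inputs always populate that key.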
@@ -207,6 +214,9 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # for minicpm_o
+
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
             response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
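MiniCPM-o's generate() returns a tuple rather than a plain tensor of ids, hence the isinstance check before the prompt tokens are sliced off. A hedged, self-contained sketch of that post-processing, with fabricated ids and a made-up tuple layout matching the [1][0] indexing above:

    import torch

    prompt_length = 3
    # MiniCPM-o-like return value: the id tensor sits at [1][0].
    generate_output = (None, [torch.tensor([[1, 2, 3, 7, 8]])])

    if isinstance(generate_output, tuple):
        generate_output = generate_output[1][0]    # unwrap to the LongTensor of ids

    response_ids = generate_output[:, prompt_length:]  # strip the echoed prompt
    print(response_ids)  # tensor([[7, 8]])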