support qwen2vl vllm infer
Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a
@@ -24,7 +24,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import DataCollatorForLanguageModeling
 
-from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
+from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
@@ -71,7 +71,9 @@ def calculate_lr(
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
-        data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+        data_collator = MultiModalDataCollatorForSeq2Seq(
+            template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
+        )
     else:
         raise NotImplementedError(f"Stage does not supported: {stage}.")
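In calculate_lr, this collator feeds the DataLoader used to count trainable tokens per batch. A minimal sketch of that pattern; the variable names dataset_module and the batch size are illustrative, not the script's exact code:

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Assumed: dataset_module comes from get_dataset(...) and data_collator is the
# MultiModalDataCollatorForSeq2Seq built above.
dataloader = DataLoader(
    dataset_module["train_dataset"], batch_size=8, collate_fn=data_collator, pin_memory=True
)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
    # Tokens whose label is not IGNORE_INDEX contribute to the loss.
    valid_tokens += (batch["labels"] != IGNORE_INDEX).sum().item()
    total_tokens += torch.numel(batch["labels"])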
@@ -16,16 +16,25 @@ import json
 
 import fire
 from transformers import Seq2SeqTrainingArguments
-from vllm import LLM, SamplingParams
-from vllm.lora.request import LoRARequest
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.extras.misc import get_device_count
+from llamafactory.extras.packages import is_pillow_available, is_vllm_available
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
 
+if is_pillow_available():
+    from PIL import Image
+    from PIL.Image import Image as ImageObject
+
+
+if is_vllm_available():
+    from vllm import LLM, SamplingParams
+    from vllm.lora.request import LoRARequest
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
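With this change the heavy optional imports (vllm, PIL) only happen when the packages are installed, so merely importing the script no longer fails on machines without them. A sketch of how such availability checks are commonly implemented; the helper below is an assumption for illustration, not necessarily the exact code behind llamafactory.extras.packages:

import importlib.util

def _is_package_available(name: str) -> bool:
    # The package counts as available if it can be located on the import path.
    return importlib.util.find_spec(name) is not None

def is_vllm_available() -> bool:
    return _is_package_available("vllm")

def is_pillow_available() -> bool:
    return _is_package_available("PIL")  # Pillow installs under the PIL import name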
@@ -64,15 +73,29 @@ def vllm_infer(
         )
     )
 
-    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
+    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    template = get_template_and_fix_tokenizer(tokenizer, data_args)
-    dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
+    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
+    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
 
     inputs, prompts, labels = [], [], []
-    for sample in dataset:
-        inputs.append({"prompt_token_ids": sample["input_ids"]})
+    for sample in dataset_module["train_dataset"]:
+        if sample["images"]:
+            multi_modal_data = {"image": []}
+            for image in sample["images"]:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")
+
+                multi_modal_data["image"].append(image)
+        else:
+            multi_modal_data = None
+
+        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
         prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
         labels.append(
             tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
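Each entry appended to inputs is a token-id prompt dict that vLLM accepts directly, with multi_modal_data carrying the decoded PIL images for multimodal models such as Qwen2-VL. A hedged sketch of the downstream generate call, assuming an llm engine has already been built (see the engine_args hunk below) and using placeholder sampling values:

from vllm import SamplingParams

# Sketch only: consume the inputs/prompts/labels lists built in the loop above.
sampling_params = SamplingParams(temperature=0.95, top_p=0.7, max_tokens=1024)  # placeholder values
results = llm.generate(inputs, sampling_params)  # llm assumed to be an already-constructed vllm.LLM
for prompt, label, result in zip(prompts, labels, results):
    print({"prompt": prompt, "predict": result.outputs[0].text, "label": label})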
@@ -100,6 +123,9 @@ def vllm_infer(
         "disable_log_stats": True,
         "enable_lora": model_args.adapter_name_or_path is not None,
     }
+    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
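The per-prompt image/video cap is only added when the template has a real multimodal plugin, and any user-supplied vllm_config dict is merged last so it can override the defaults. A minimal sketch of how the assembled arguments might be turned into an engine, with placeholder values for everything not shown in the diff:

from vllm import LLM

engine_args = {
    "model": "Qwen/Qwen2-VL-7B-Instruct",  # placeholder checkpoint
    "disable_log_stats": True,
    "enable_lora": False,
    "limit_mm_per_prompt": {"image": 4, "video": 2},  # only set for multimodal templates
}
vllm_config = {"gpu_memory_utilization": 0.9, "max_model_len": 4096}  # example user overrides
if isinstance(vllm_config, dict):
    engine_args.update(vllm_config)  # user-provided keys win over the defaults above

llm = LLM(**engine_args)  # assumed construction; the script's exact call may differ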