merge data part into the text stream

Former-commit-id: 80537d580119d9d5a06ab236a5284aaae2f83b5b
Author: BUAADreamer
Date: 2024-04-25 19:58:47 +08:00
Parent: 3c792174db
Commit: d1d08d066a
5 changed files with 18 additions and 35 deletions

scripts/test_mllm.py

@@ -6,22 +6,23 @@ from datasets import load_dataset
 from peft import PeftModel
 from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor
 import shutil
+from PIL import Image

 """usage
 python3 scripts/test_mllm.py \
 --base_model_path llava-hf/llava-1.5-7b-hf \
 --lora_model_path saves/llava-1.5-7b/lora/sft \
 --model_path saves/llava-1.5-7b/lora/merged \
---dataset_name data/mllm_example_dataset \
+--dataset_name data/llava_instruct_example.json \
 --do_merge 1
 """


 def get_processor(model_path):
-    CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
-    processor = AutoProcessor.from_pretrained(model_path)
+    CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }} ASSISTANT: {% else %}{{ message['content'] }}{% endif %} {% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
     tokenizer.chat_template = CHAT_TEMPLATE
+    processor = AutoProcessor.from_pretrained(model_path)
     processor.tokenizer = tokenizer
     return processor

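Note on the template change: the old template iterated over structured content parts and emitted the <image> placeholder itself, while the new one treats message['content'] as a plain string, so the <image> token is expected to arrive already merged into the text. A minimal sketch (not part of this commit) of how the new template renders; the model id and messages below are illustrative:

    from transformers import AutoTokenizer

    # The template string is the one added in the hunk above.
    CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }} ASSISTANT: {% else %}{{ message['content'] }}{% endif %} {% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""

    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)
    tokenizer.chat_template = CHAT_TEMPLATE

    messages = [
        {"role": "user", "content": "<image>What is shown in this picture?"},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ]
    # Prints roughly: "A chat between a curious user and an artificial intelligence
    # assistant. ... USER: <image>What is shown in this picture? ASSISTANT: A cat
    # sitting on a windowsill.</s>"
    print(tokenizer.apply_chat_template(messages, tokenize=False))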
@@ -69,7 +70,7 @@ def main(
         device_map="cuda",
     )
     processor = get_processor(model_path)
-    raw_datasets = load_dataset(dataset_name)
+    raw_datasets = load_dataset("json", data_files=dataset_name)
    train_dataset = raw_datasets["train"]
    examples = train_dataset.select(range(3))
    texts = []
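Since dataset_name is now a path to a JSON file rather than a saved dataset directory, it goes through the json builder of datasets, which exposes the records under the train split. A hedged sketch (not from the commit) of a record in the shape the loop below consumes; the messages and images field names follow the code, the concrete values are made up:

    import json

    from datasets import load_dataset

    # One illustrative record: a conversation plus the local image path it references.
    record = {
        "messages": [
            {"role": "user", "content": "<image>Describe this image."},
            {"role": "assistant", "content": "A dog running across a lawn."},
        ],
        "images": ["data/images/0.jpg"],
    }
    with open("data/llava_instruct_example.json", "w") as f:
        json.dump([record], f)

    raw_datasets = load_dataset("json", data_files="data/llava_instruct_example.json")
    print(raw_datasets["train"][0]["images"])  # ['data/images/0.jpg']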
@@ -80,11 +81,18 @@ def main(
             messages, tokenize=False, add_generation_prompt=False
         )
         texts.append(text)
-        images.append(example["images"][0])
-    batch = processor(texts, images, return_tensors="pt", padding=True).to("cuda")
+        images.append(Image.open(example["images"][0]))
+    batch = processor(text=texts, images=images, return_tensors="pt", padding=True).to(
+        "cuda"
+    )
     output = model.generate(**batch, max_new_tokens=100)
-    res = processor.batch_decode(output, skip_special_tokens=True)
-    print(res)
+    res_list = processor.batch_decode(output, skip_special_tokens=True)
+    for i, prompt in enumerate(texts):
+        res = res_list[i]
+        print(f"#{i}")
+        print(f"prompt:{prompt}")
+        print(f"response:{res[len(prompt):].strip()}")
+        print()


 if __name__ == "__main__":
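Two details in the last hunk are easy to miss: the processor now receives PIL images opened with Image.open instead of raw file paths, and the decoded outputs are split back into prompt and response, because generate on a decoder-only model returns the prompt tokens followed by the completion. A hypothetical helper (not in the commit) expressing the same prompt-stripping idea with a guard, since skip_special_tokens=True can make the decoded prefix differ slightly from the raw prompt:

    def strip_prompt(decoded: str, prompt: str) -> str:
        # batch_decode returns prompt + completion for decoder-only generation,
        # so drop the prompt prefix and keep only the generated response.
        if decoded.startswith(prompt):
            return decoded[len(prompt):].strip()
        # Fallback: skipping special tokens can perturb the decoded prefix.
        return decoded.strip()

    print(strip_prompt("USER: hi ASSISTANT: Hello there!", "USER: hi ASSISTANT: "))  # Hello there!

An alternative would be to slice the generated ids at the input length (output[:, batch["input_ids"].shape[1]:]) before decoding, which sidesteps text mismatches entirely.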