add llava and instructblip
Former-commit-id: 142fb6f4541a1acfefe66ff2574dabde53b00c06
scripts/make_mllm_instruct.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import json
import os.path

import fire
from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets

"""usage
python3 scripts/make_mllm_instruct.py \
    --json_path data/llava_instruct_example.json \
    --image_path data/images \
    --output_path data/mllm_example_dataset
"""


def make_one_json(json_path, image_path) -> Dataset:
    """Convert one LLaVA-style instruct json file into a `datasets.Dataset`."""
    with open(json_path) as f:
        raw_data_ls = json.load(f)
    data_ls = []
    for data in raw_data_ls:
        for j, message in enumerate(data['messages']):
            # Wrap the plain-text content into the multimodal content format.
            text = message['content']
            message['content'] = [{'index': None, 'text': text, 'type': 'text'}]
            if j == 0:
                # The image placeholder is attached to the first message.
                message['content'].append({'index': 0, 'text': None, 'type': 'image'})
        image = data['image']
        if image_path:
            image = os.path.join(image_path, data['image'])
        data['images'] = [image]
        del data['image']
        data_ls.append(data)

    def gen():
        yield from data_ls

    features = Features({'messages': [{'content': [
        {'index': Value(dtype='int64', id=None), 'text': Value(dtype='string', id=None),
         'type': Value(dtype='string', id=None)}], 'role': Value(dtype='string', id=None)}],
        'images': Sequence(feature=Image(decode=True, id=None), length=-1, id=None)})
    dataset = Dataset.from_generator(gen, features=features)
    return dataset


yaml_content = """---
dataset_info:
  features:
  - name: messages
    list:
    - name: content
      list:
      - name: index
        dtype: int64
      - name: text
        dtype: string
      - name: type
        dtype: string
    - name: role
      dtype: string
  - name: images
    sequence: image
configs:
- config_name: default
  data_files:
  - split: train
    path: data/train-*
  - split: test
    path: data/test-*
---"""


def main(
    json_path: str,
    image_path: str,
    output_path: str,
):
    # --json_path may list several space-separated json files.
    json_path_list = json_path.split()
    dataset_list = []
    for json_path in json_path_list:
        dataset = make_one_json(json_path, image_path)
        dataset_list.append(dataset)
    dataset = concatenate_datasets(dataset_list)
    print(dataset[0])
    data_path = os.path.join(output_path, "data")
    os.makedirs(data_path, exist_ok=True)
    # The same data is written to both splits; this is only an example dataset.
    parquet_path = os.path.join(data_path, "train-0.parquet")
    dataset.to_parquet(parquet_path)
    parquet_path = os.path.join(data_path, "test-0.parquet")
    dataset.to_parquet(parquet_path)
    # The README's yaml header lets `load_dataset` discover the parquet splits.
    readme_path = os.path.join(output_path, "README.md")
    with open(readme_path, 'w') as f:
        f.write(yaml_content)


if __name__ == '__main__':
    fire.Fire(main)
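For reference, a minimal sketch of the input that make_one_json expects, inferred from the field accesses above; the file name and message text are placeholders, not part of this commit:

# Hypothetical input for --json_path, inferred from make_one_json: a list of
# records, each with plain-text "messages" and a single "image" file name.
example_input = [
    {
        "messages": [
            {"role": "user", "content": "What is in the photo?"},
            {"role": "assistant", "content": "A dog playing on the grass."},
        ],
        "image": "1.jpg",  # joined with --image_path when it is given
    }
]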
scripts/test_mllm.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import os.path

import fire
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer

"""usage
python3 scripts/test_mllm.py \
    --base_model_path llava-hf/llava-1.5-7b-hf \
    --lora_model_path saves/llava-1.5-7b/lora/sft \
    --model_path saves/llava-1.5-7b/lora/merged \
    --dataset_name data/mllm_example_dataset \
    --do_merge 1
"""


def get_processor(model_path):
    # Vicuna-style chat template that understands the multimodal content
    # format: text items are rendered verbatim, image items as <image>.
    CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    tokenizer.chat_template = CHAT_TEMPLATE
    processor = AutoProcessor.from_pretrained(model_path)
    processor.tokenizer = tokenizer
    return processor


def apply_lora(base_model_path, model_path, lora_path):
    print(f"Loading the base model from {base_model_path}")
    base_model = AutoModelForVision2Seq.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda",
    )
    processor = get_processor(base_model_path)
    tokenizer = processor.tokenizer

    print(f"Loading the LoRA adapter from {lora_path}")
    lora_model = PeftModel.from_pretrained(
        base_model,
        lora_path,
        torch_dtype=torch.float16,
    )

    print("Merging the LoRA weights into the base model")
    model = lora_model.merge_and_unload()

    print(f"Saving the merged model to {model_path}")
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)
    processor.image_processor.save_pretrained(model_path)


def main(
    model_path: str,
    dataset_name: str,
    base_model_path: str = "",
    lora_model_path: str = "",
    do_merge: bool = False,
):
    # Merge the LoRA adapter first if the merged checkpoint does not exist
    # yet, or if merging is forced with --do_merge.
    if not os.path.exists(model_path) or do_merge:
        apply_lora(base_model_path, model_path, lora_model_path)

    model = AutoModelForVision2Seq.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda"
    )
    processor = get_processor(model_path)

    raw_datasets = load_dataset(dataset_name)
    train_dataset = raw_datasets['train']
    examples = train_dataset.select(range(3))
    texts = []
    images = []
    for example in examples:
        # Keep only the first (user) message as the prompt; add_generation_prompt
        # appends the "ASSISTANT: " prefix so the model answers the question.
        messages = example["messages"][:1]
        text = processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        texts.append(text)
        images.append(example["images"][0])
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True).to("cuda")
    output = model.generate(**batch, max_new_tokens=100)
    res = processor.batch_decode(output, skip_special_tokens=True)
    print(res)


if __name__ == '__main__':
    fire.Fire(main)
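Once merged, the checkpoint can also be queried directly without the dataset. A minimal sketch, assuming the merged model exists at the path from the usage string above; the image path, question, and hand-written prompt (which mirrors the chat template's USER/ASSISTANT format) are assumptions:

# Hypothetical single-example inference against the merged checkpoint;
# the image path and prompt text are placeholders, not part of this commit.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model = AutoModelForVision2Seq.from_pretrained(
    "saves/llava-1.5-7b/lora/merged", torch_dtype=torch.float16, device_map="cuda"
)
processor = AutoProcessor.from_pretrained("saves/llava-1.5-7b/lora/merged")
prompt = "USER: <image>What is in the photo? ASSISTANT:"
image = Image.open("data/images/1.jpg")  # placeholder path
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output, skip_special_tokens=True)[0])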