From 12c51655cebff8e60bff48d092916c6e96348852 Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Thu, 25 Apr 2024 00:22:43 +0800 Subject: [PATCH] add llava and instructblip Former-commit-id: 142fb6f4541a1acfefe66ff2574dabde53b00c06 --- data/mllm_example_dataset/README.md | 25 +++++ data/mllm_example_dataset/data/test-0.parquet | Bin 0 -> 4580 bytes .../mllm_example_dataset/data/train-0.parquet | Bin 0 -> 4580 bytes examples/mllm/sft_instructblip.sh | 16 ++- examples/mllm/{sft_blip2.sh => sft_llava.sh} | 17 ++-- scripts/make_mllm_instruct.py | 95 ++++++++++++++++++ scripts/test_mllm.py | 84 ++++++++++++++++ src/llmtuner/data/loader.py | 3 +- src/llmtuner/data/preprocess.py | 2 +- src/llmtuner/hparams/data_args.py | 4 - src/llmtuner/hparams/model_args.py | 4 - src/llmtuner/model/adapter.py | 22 ++-- src/llmtuner/model/loader.py | 3 +- src/llmtuner/train/sftmm/collator.py | 82 ++++----------- src/llmtuner/train/sftmm/trainer.py | 95 +----------------- src/llmtuner/train/sftmm/workflow.py | 35 +++---- 16 files changed, 273 insertions(+), 214 deletions(-) create mode 100644 data/mllm_example_dataset/README.md create mode 100644 data/mllm_example_dataset/data/test-0.parquet create mode 100644 data/mllm_example_dataset/data/train-0.parquet rename examples/mllm/{sft_blip2.sh => sft_llava.sh} (58%) create mode 100644 scripts/make_mllm_instruct.py create mode 100644 scripts/test_mllm.py diff --git a/data/mllm_example_dataset/README.md b/data/mllm_example_dataset/README.md new file mode 100644 index 00000000..d5c8c0e6 --- /dev/null +++ b/data/mllm_example_dataset/README.md @@ -0,0 +1,25 @@ +--- +dataset_info: + features: + - name: messages + list: + - name: content + list: + - name: index + dtype: int64 + - name: text + dtype: string + - name: type + dtype: string + - name: role + dtype: string + - name: images + sequence: image +configs: +- config_name: default + data_files: + - split: train + path: data/train-* + - split: test + path: data/test-* +--- \ No newline at end of file diff --git a/data/mllm_example_dataset/data/test-0.parquet b/data/mllm_example_dataset/data/test-0.parquet new file mode 100644 index 0000000000000000000000000000000000000000..42c20b192497168523c3d39447cdae4495085b84 GIT binary patch literal 4580 zcmdTIO>84qc_wRnlaN&LHsc*{qws2Fr_nZRJ5HP?tA%tXA^uq>&A)#Tg6ElM&t%6l zW@a46Sr)C3s!AL=RTT%s38AX0RKx*L#Q`A>sHgT&sS*OM5OAnaFCYOb-+N;_KV-Ko zDiw^9H}8G_|8J5_d3#m}2xG!{K^PFYD;zgC!F3;j6XHT@MwkS2NC-`cgFVd2E+S4} z00p|Mr68=q5`I~@OgC>ge)D7d89u~G5BUip5KVKTvFOoYgSalP6E|QQu6LQ3q(Ln2 zvT9o%yw3oGhNY1sVIVRYd6$oTz)OqL;FFjXodG{NVSs~W3|<@u=Z8bWuLlc)$UOfj z@aNnpz>B(#utSiilg{!C0Cr*X902ZMiy&-iDB}?C_%22@$8I16cZg%5^_FF*dVHJ- zz#a-K6G*cStG+xvnhjRwRdBBAU=JY3ws>P6xwfhj2h|K>YQv*1Yle$-vMhmsEP2jF zySm-LF32@a{>IujI8F>Ma-E^YaJf)-?3Sf1CxBETDsH(1>HNlsYZ}_kMpPvX5(fPocD;ve`lr& zf>9H9OdD%_PE-v{v(5kZ4?8-sj&-|*W*1Pya9zW;kYS;G3*wwgrcyJIgO-b`1T@gN zK}H~4j_aY7(Vyz7acGzZwvb(ejM%Lvnxouq6RX!(#^;h~9e4QkvAHu4)iVC5Rj~iDS@$#AeTYHxA`usg7>WF8Sb8`j{q@qsr)&I$Efy1`0_{8)E#Vgf z;2<@Gy2y$TqA3YCPDPWOW|oRUh6b{yv;BC`rk%XZ2bDqT!z=$`6go}9qVNMg@+L1m zxFWqBo>(}i^g=&w2=SgD!Y1_ty3QGbH-@@(J#2a43phIi)5M>bR4gSDhx#Ny9)9r> zticbTqTYxBKKR37>D!UMH`$9vV!*LYYPcBr5g+*(FTE;A?~F6U&uh5Qc$j+yFpnmI z&cAEI zCI&>0-9F_=VBpdeC;jr7$NLg!VoztTyg3m0)`0Z-HR%^o8rw4p9;x#pm!p3lOLJn# zRdv$9LV!yTi2c06JsT?A22;}kk=<}`83Dd``T5knEAW$uLvV{-AAoM5mm?>Pz@GKb zb*@0~@h$+0{tSQ?Qx^c5yZQYjRRLd`pZTt28o92ZNGLiHHUD>R_y4a!`B@LN|MJNB z6z9ih-@cKbKOG54gOS&+z{gy4M*N)Ks@LJ(u3?pkZw=gw8CK1X-9q)uVQOO&BDdBRcWjpRJP<}%1b)=X0=f{%pKTu*f%Q0wWL17zNC-8)&Nj5N{w3etSwT_2 
z-=?Jqabcv4{3O!y7dR0$u><4OyQwytH?iZ`ZFEQ+_UH0!I-ZQDq9%Opo%`Wl8HT_5 I;r~1T1e%nlz5oCK literal 0 HcmV?d00001 diff --git a/data/mllm_example_dataset/data/train-0.parquet b/data/mllm_example_dataset/data/train-0.parquet new file mode 100644 index 0000000000000000000000000000000000000000..42c20b192497168523c3d39447cdae4495085b84 GIT binary patch literal 4580 zcmdTIO>84qc_wRnlaN&LHsc*{qws2Fr_nZRJ5HP?tA%tXA^uq>&A)#Tg6ElM&t%6l zW@a46Sr)C3s!AL=RTT%s38AX0RKx*L#Q`A>sHgT&sS*OM5OAnaFCYOb-+N;_KV-Ko zDiw^9H}8G_|8J5_d3#m}2xG!{K^PFYD;zgC!F3;j6XHT@MwkS2NC-`cgFVd2E+S4} z00p|Mr68=q5`I~@OgC>ge)D7d89u~G5BUip5KVKTvFOoYgSalP6E|QQu6LQ3q(Ln2 zvT9o%yw3oGhNY1sVIVRYd6$oTz)OqL;FFjXodG{NVSs~W3|<@u=Z8bWuLlc)$UOfj z@aNnpz>B(#utSiilg{!C0Cr*X902ZMiy&-iDB}?C_%22@$8I16cZg%5^_FF*dVHJ- zz#a-K6G*cStG+xvnhjRwRdBBAU=JY3ws>P6xwfhj2h|K>YQv*1Yle$-vMhmsEP2jF zySm-LF32@a{>IujI8F>Ma-E^YaJf)-?3Sf1CxBETDsH(1>HNlsYZ}_kMpPvX5(fPocD;ve`lr& zf>9H9OdD%_PE-v{v(5kZ4?8-sj&-|*W*1Pya9zW;kYS;G3*wwgrcyJIgO-b`1T@gN zK}H~4j_aY7(Vyz7acGzZwvb(ejM%Lvnxouq6RX!(#^;h~9e4QkvAHu4)iVC5Rj~iDS@$#AeTYHxA`usg7>WF8Sb8`j{q@qsr)&I$Efy1`0_{8)E#Vgf z;2<@Gy2y$TqA3YCPDPWOW|oRUh6b{yv;BC`rk%XZ2bDqT!z=$`6go}9qVNMg@+L1m zxFWqBo>(}i^g=&w2=SgD!Y1_ty3QGbH-@@(J#2a43phIi)5M>bR4gSDhx#Ny9)9r> zticbTqTYxBKKR37>D!UMH`$9vV!*LYYPcBr5g+*(FTE;A?~F6U&uh5Qc$j+yFpnmI z&cAEI zCI&>0-9F_=VBpdeC;jr7$NLg!VoztTyg3m0)`0Z-HR%^o8rw4p9;x#pm!p3lOLJn# zRdv$9LV!yTi2c06JsT?A22;}kk=<}`83Dd``T5knEAW$uLvV{-AAoM5mm?>Pz@GKb zb*@0~@h$+0{tSQ?Qx^c5yZQYjRRLd`pZTt28o92ZNGLiHHUD>R_y4a!`B@LN|MJNB z6z9ih-@cKbKOG54gOS&+z{gy4M*N)Ks@LJ(u3?pkZw=gw8CK1X-9q)uVQOO&BDdBRcWjpRJP<}%1b)=X0=f{%pKTu*f%Q0wWL17zNC-8)&Nj5N{w3etSwT_2 z-=?Jqabcv4{3O!y7dR0$u><4OyQwytH?iZ`ZFEQ+_UH0!I-ZQDq9%Opo%`Wl8HT_5 I;r~1T1e%nlz5oCK literal 0 HcmV?d00001 diff --git a/examples/mllm/sft_instructblip.sh b/examples/mllm/sft_instructblip.sh index 92478500..b3923655 100644 --- a/examples/mllm/sft_instructblip.sh +++ b/examples/mllm/sft_instructblip.sh @@ -3,20 +3,20 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage sft_mm \ --do_train \ - --model_name_or_path /home/LAB/fengzc/LLM/checkpoints/Salesforce/instructblip-vicuna-7b \ - --dataset llava_instruct_100 \ + --model_name_or_path Salesforce/instructblip-vicuna-7b \ + --dataset mllm_instruct_example \ --dataset_dir data \ --template default \ --finetuning_type lora \ - --lora_target q_proj,k_proj \ + --lora_target all \ --output_dir saves/instructblip-vicuna-7b/lora/sft \ --overwrite_cache \ --overwrite_output_dir \ --cutoff_len 1024 \ --preprocessing_num_workers 16 \ - --per_device_train_batch_size 4 \ + --per_device_train_batch_size 3 \ --per_device_eval_batch_size 1 \ - --gradient_accumulation_steps 8 \ + --gradient_accumulation_steps 1 \ --lr_scheduler_type cosine \ --logging_steps 1 \ --warmup_steps 20 \ @@ -25,10 +25,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --evaluation_strategy steps \ --load_best_model_at_end \ --learning_rate 1e-5 \ - --num_train_epochs 3.0 \ + --num_train_epochs 50 \ --max_samples 3000 \ --val_size 0.1 \ --plot_loss \ - --quantization_bit 8 \ - --image_path /home/LAB/fengzc/LLM/checkpoints/liuhaotian/LLaVA-Instruct-150K/images/coco/train2017 \ - --use_qformer \ No newline at end of file + --bf16 \ No newline at end of file diff --git a/examples/mllm/sft_blip2.sh b/examples/mllm/sft_llava.sh similarity index 58% rename from examples/mllm/sft_blip2.sh rename to examples/mllm/sft_llava.sh index ac0a3f11..c1fce693 100644 --- a/examples/mllm/sft_blip2.sh +++ b/examples/mllm/sft_llava.sh @@ -3,20 +3,20 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage 
sft_mm \ --do_train \ - --model_name_or_path /home/LAB/fengzc/LLM/checkpoints/Salesforce/blip2-opt-2.7b \ - --dataset llava_instruct_100 \ + --model_name_or_path llava-hf/llava-1.5-7b-hf \ + --dataset mllm_instruct_example \ --dataset_dir data \ --template default \ --finetuning_type lora \ - --lora_target q_proj,k_proj \ - --output_dir saves/blip2-opt-2.7b/lora/sft \ + --lora_target all \ + --output_dir saves/llava-1.5-7b/lora/sft \ --overwrite_cache \ --overwrite_output_dir \ --cutoff_len 1024 \ --preprocessing_num_workers 16 \ - --per_device_train_batch_size 4 \ + --per_device_train_batch_size 3 \ --per_device_eval_batch_size 1 \ - --gradient_accumulation_steps 8 \ + --gradient_accumulation_steps 1 \ --lr_scheduler_type cosine \ --logging_steps 1 \ --warmup_steps 20 \ @@ -25,9 +25,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --evaluation_strategy steps \ --load_best_model_at_end \ --learning_rate 5e-5 \ - --num_train_epochs 3.0 \ + --num_train_epochs 100 \ --max_samples 3000 \ --val_size 0.1 \ --plot_loss \ - --quantization_bit 8 \ - --image_path /home/LAB/fengzc/LLM/checkpoints/liuhaotian/LLaVA-Instruct-150K/images/coco/train2017 \ No newline at end of file + --bf16 \ No newline at end of file diff --git a/scripts/make_mllm_instruct.py b/scripts/make_mllm_instruct.py new file mode 100644 index 00000000..41e13b8e --- /dev/null +++ b/scripts/make_mllm_instruct.py @@ -0,0 +1,95 @@ +import json +import os.path + +import fire +from datasets import Dataset, concatenate_datasets, load_dataset, Value, Image, Features, Sequence + +"""usage +python3 scripts/make_mllm_instruct.py \ +--json_path data/llava_instruct_example.json \ +--image_path data/images \ +--output_path data/mllm_example_dataset +""" + + +def make_one_json(json_path, image_path) -> Dataset: + with open(json_path) as f: + raw_data_ls = json.loads(f.read()) + data_ls = [] + for i, data in enumerate(raw_data_ls): + for j, message in enumerate(data['messages']): + text = message['content'] + message['content'] = [{'index': None, 'text': text, 'type': 'text'}] + if j == 0: + message['content'].append({'index': 0, 'text': None, 'type': 'image'}) + image = data['image'] + if image_path: + image = os.path.join(image_path, data['image']) + data['images'] = [image] + del data['image'] + data_ls.append(data) + + def gen(): + for data in data_ls: + yield data + + features = Features({'messages': [{'content': [ + {'index': Value(dtype='int64', id=None), 'text': Value(dtype='string', id=None), + 'type': Value(dtype='string', id=None)}], 'role': Value(dtype='string', id=None)}], + 'images': Sequence(feature=Image(decode=True, id=None), length=-1, id=None)}) + dataset = Dataset.from_generator(gen, features=features) + return dataset + + +yaml_content = """--- +dataset_info: + features: + - name: messages + list: + - name: content + list: + - name: index + dtype: int64 + - name: text + dtype: string + - name: type + dtype: string + - name: role + dtype: string + - name: images + sequence: image +configs: +- config_name: default + data_files: + - split: train + path: data/train-* + - split: test + path: data/test-* +---""" + + +def main( + json_path: str, + image_path: str, + output_path: str, +): + json_path_list = json_path.split() + dataset_list = [] + for json_path in json_path_list: + dataset = make_one_json(json_path, image_path) + dataset_list.append(dataset) + dataset = concatenate_datasets(dataset_list) + print(dataset[0]) + data_path = os.path.join(output_path, "data") + os.makedirs(os.path.join(data_path), exist_ok=True) + 
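+    # note: the same concatenated dataset is written to both the train and the test parquet file in this example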
parquet_path = os.path.join(data_path, "train-0.parquet") + dataset.to_parquet(parquet_path) + parquet_path = os.path.join(data_path, "test-0.parquet") + dataset.to_parquet(parquet_path) + readme_path = os.path.join(output_path, "README.md") + with open(readme_path, 'w') as f: + f.write(yaml_content) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/scripts/test_mllm.py b/scripts/test_mllm.py new file mode 100644 index 00000000..c03525b8 --- /dev/null +++ b/scripts/test_mllm.py @@ -0,0 +1,84 @@ +import os.path + +import fire +import torch +from datasets import load_dataset +from peft import PeftModel +from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor + +"""usage +python3 scripts/test_mllm.py \ +--base_model_path llava-hf/llava-1.5-7b-hf \ +--lora_model_path saves/llava-1.5-7b/lora/sft \ +--model_path saves/llava-1.5-7b/lora/merged \ +--dataset_name data/mllm_example_dataset \ +--do_merge 1 +""" + + +def get_processor(model_path): + CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + tokenizer.chat_template = CHAT_TEMPLATE + processor = AutoProcessor.from_pretrained(model_path) + processor.tokenizer = tokenizer + return processor + + +def apply_lora(base_model_path, model_path, lora_path): + print(f"Loading the base model from {base_model_path}") + base_model = AutoModelForVision2Seq.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda", + ) + processor = get_processor(base_model_path) + tokenizer = processor.tokenizer + print(f"Loading the LoRA adapter from {lora_path}") + + lora_model = PeftModel.from_pretrained( + base_model, + lora_path, + torch_dtype=torch.float16, + ) + + print("Applying the LoRA") + model = lora_model.merge_and_unload() + + print(f"Saving the target model to {model_path}") + model.save_pretrained(model_path) + tokenizer.save_pretrained(model_path) + processor.image_processor.save_pretrained(model_path) + + +def main( + model_path: str, + dataset_name: str, + base_model_path: str = "", + lora_model_path: str = "", + do_merge: bool = False, +): + if not os.path.exists(model_path) or do_merge: + apply_lora(base_model_path, model_path, lora_model_path) + model = AutoModelForVision2Seq.from_pretrained( + model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda" + ) + processor = get_processor(model_path) + raw_datasets = load_dataset(dataset_name) + train_dataset = raw_datasets['train'] + examples = train_dataset.select(range(3)) + texts = [] + images = [] + for example in examples: + messages = example["messages"][:1] + text = processor.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + texts.append(text) + images.append(example["images"][0]) + batch = processor(texts, images, return_tensors="pt", padding=True).to("cuda") + output = model.generate(**batch, 
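+        # cap generation at 100 new tokens; enough for a quick sanity check of the merged model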
max_new_tokens=100) + res = processor.batch_decode(output, skip_special_tokens=True) + print(res) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index b3af434b..18665731 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -199,8 +199,7 @@ def get_mm_dataset( with training_args.main_process_first(desc="load dataset"): all_datasets = [] for dataset_attr in get_dataset_list(data_args): - local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) - all_datasets.append(load_dataset("json", data_files=local_path)['train']) + all_datasets.append(load_dataset(dataset_attr.dataset_name)['train']) dataset = merge_dataset(all_datasets, data_args, training_args) return dataset diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py index b8edfa10..8494ba7e 100644 --- a/src/llmtuner/data/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -275,4 +275,4 @@ def get_preprocess_and_print_func( ) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) - return preprocess_func, print_function + return preprocess_func, print_function \ No newline at end of file diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 3b52f1ea..f5f75c77 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -88,10 +88,6 @@ class DataArguments: default=None, metadata={"help": "Path to save or load the tokenized datasets."}, ) - image_path: Optional[str] = field( - default=None, - metadata={"help": "Path to images."}, - ) def __post_init__(self): if self.reserved_label_len >= self.cutoff_len: diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 32637f59..0e42033f 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -165,10 +165,6 @@ class ModelArguments: default=False, metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, ) - use_qformer: bool = field( - default=False, - metadata={"help": "Whether use qformer for Multimodal LLM."}, - ) def __post_init__(self): self.compute_dtype = None diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index 624d8a85..e66a984b 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -182,7 +182,8 @@ def init_adapter( def init_mm_adapter( model: "AutoModelForVision2Seq", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", - is_trainable: bool + is_trainable: bool, + use_clm=True, ) -> "AutoModelForVision2Seq": if finetuning_args.finetuning_type == "lora": logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) @@ -253,12 +254,19 @@ def init_mm_adapter( } model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) else: - lora_config = LoraConfig( - # task_type=TaskType.CAUSAL_LM, - inference_mode=False, - use_dora=finetuning_args.use_dora, - **peft_kwargs, - ) + if use_clm: + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + use_dora=finetuning_args.use_dora, + **peft_kwargs, + ) + else: + lora_config = LoraConfig( + inference_mode=False, + use_dora=finetuning_args.use_dora, + **peft_kwargs, + ) model = get_peft_model(model, lora_config) if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam): diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 
eeee69a6..917f11c9 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -191,6 +191,7 @@ def load_mm_model( finetuning_args: "FinetuningArguments", is_trainable: bool = False, add_valuehead: bool = False, + use_clm=True, ) -> "AutoModelForVision2Seq": r""" Loads pretrained model. Must after load_tokenizer. @@ -231,7 +232,7 @@ def load_mm_model( patch_model(model, tokenizer, model_args, is_trainable) register_autoclass(config, model, tokenizer) - model = init_mm_adapter(model, model_args, finetuning_args, is_trainable) + model = init_mm_adapter(model, model_args, finetuning_args, is_trainable, use_clm) if not is_trainable: model.requires_grad_(False) diff --git a/src/llmtuner/train/sftmm/collator.py b/src/llmtuner/train/sftmm/collator.py index e91374bc..95dbd939 100644 --- a/src/llmtuner/train/sftmm/collator.py +++ b/src/llmtuner/train/sftmm/collator.py @@ -1,69 +1,29 @@ -import json -import os from dataclasses import dataclass - -import torch -from torch.utils.data import Dataset as Dataset_torch -from datasets import Dataset -from PIL import Image from transformers import AutoProcessor -class ImageCaptioningDataset(Dataset_torch): - def __init__(self, dataset: Dataset, image_path: str, processor: AutoProcessor): - self.processor = processor - self.dataset = dataset - self.image_path = image_path - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - source = self.dataset[idx] - image_id = source['image'] - image = Image.open(os.path.join(self.image_path, image_id)) - convs = source['conversations'] - prompt = convs[0]['value'] - label = convs[1]['value'] - image_inputs = self.processor(image, return_tensors="pt") - image_inputs = {k: v.squeeze() for k, v in image_inputs.items()} - inputs = { - "input_ids": prompt, - "labels": label, - } - for key in image_inputs: - inputs[key] = image_inputs[key] - return inputs - - @dataclass class DataCollatorForVis2Seq: processor: AutoProcessor - use_qformer: bool = False - def __call__(self, features, return_tensors=None): - processed_batch = {} - for key in features[0].keys(): - if key == 'pixel_values': - processed_batch[key] = torch.stack([example[key] for example in features]) - elif key == 'input_ids': - text_inputs = self.processor.tokenizer( - [example[key] for example in features], padding="max_length", return_tensors="pt", - max_length=512, - ) - processed_batch["input_ids"] = text_inputs["input_ids"] - processed_batch["attention_mask"] = text_inputs["attention_mask"] - if self.use_qformer: - qformer_text_inputs = self.processor.qformer_tokenizer( - [example[key] for example in features], padding="max_length", return_tensors="pt", - max_length=512, - ) - processed_batch["qformer_input_ids"] = qformer_text_inputs["input_ids"] - processed_batch["qformer_attention_mask"] = qformer_text_inputs["attention_mask"] - elif key == 'labels': - text_inputs = self.processor.tokenizer( - [example[key] for example in features], padding="max_length", return_tensors="pt", - max_length=512, - ) - processed_batch["labels"] = text_inputs["input_ids"] - return processed_batch + def __call__(self, examples): + texts = [] + images = [] + for example in examples: + if len(example["images"]) > 1: + raise ValueError("This collator only supports one image per example") + messages = example["messages"] + text = self.processor.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + texts.append(text) + images.append(example["images"][0]) + + batch = self.processor(text=texts, 
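+            # one processor call tokenizes the chat-formatted texts and turns the images into pixel_values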
images=images, return_tensors="pt", padding=True) + + labels = batch["input_ids"].clone() + if self.processor.tokenizer.pad_token_id is not None: + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + + return batch diff --git a/src/llmtuner/train/sftmm/trainer.py b/src/llmtuner/train/sftmm/trainer.py index 96b86b44..f094e609 100644 --- a/src/llmtuner/train/sftmm/trainer.py +++ b/src/llmtuner/train/sftmm/trainer.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import Seq2SeqTrainer +from transformers import Seq2SeqTrainer, Trainer from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger @@ -32,23 +32,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) - # def compute_loss(self, model, inputs, return_outputs=False): - # print(inputs.keys()) - # device = "cuda" - # input_ids = inputs.get("input_ids").to(device) - # pixel_values = inputs.get("pixel_values").to(device, torch.float16) - # attention_mask = inputs.get("attention_mask").to(device) - # labels = inputs.get("labels").to(device) - # - # outputs = model(input_ids=input_ids, - # pixel_values=pixel_values, - # labels=labels, - # # attention_mask=attention_mask, - # ) - # loss = outputs.loss - # print("Loss:", loss.item()) - # return (loss, outputs) if return_outputs else loss - def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) @@ -59,79 +42,3 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): ) -> "torch.optim.lr_scheduler.LRScheduler": create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) - - def prediction_step( - self, - model: "torch.nn.Module", - inputs: Dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[List[str]] = None, - ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: - r""" - Removes the prompt part in the generated tokens. - - Subclass and override to inject custom behavior. - """ - labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels - if self.args.predict_with_generate: - assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." - prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) - if prompt_len > label_len: - inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) - if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) - inputs["labels"] = inputs["labels"][:, :prompt_len] - - loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) - model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys - ) - if generated_tokens is not None and self.args.predict_with_generate: - generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id - generated_tokens = generated_tokens.contiguous() - - return loss, generated_tokens, labels - - def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor: - r""" - Pads the tensor to the same length as the target tensor. 
- """ - assert self.tokenizer.pad_token_id is not None, "Pad token is required." - padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) - padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding - return padded_tensor.contiguous() # in contiguous memory - - def save_predictions(self, predict_results: "PredictionOutput") -> None: - r""" - Saves model predictions to `output_dir`. - - A custom behavior that not contained in Seq2SeqTrainer. - """ - if not self.is_world_process_zero(): - return - - output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") - logger.info(f"Saving prediction results to {output_prediction_file}") - - labels = np.where( - predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id - ) - preds = np.where( - predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id - ) - - for i in range(len(preds)): - pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] - if len(pad_len): - preds[i] = np.concatenate( - (preds[i][pad_len[0]:], preds[i][: pad_len[0]]), axis=-1 - ) # move pad token to last - - decoded_labels = self.tokenizer.batch_decode( - labels, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) - - with open(output_prediction_file, "w", encoding="utf-8") as writer: - res: List[str] = [] - for label, pred in zip(decoded_labels, decoded_preds): - res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) - writer.write("\n".join(res)) diff --git a/src/llmtuner/train/sftmm/workflow.py b/src/llmtuner/train/sftmm/workflow.py index 9f952772..21f4aebf 100644 --- a/src/llmtuner/train/sftmm/workflow.py +++ b/src/llmtuner/train/sftmm/workflow.py @@ -1,21 +1,14 @@ # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py import os from typing import TYPE_CHECKING, List, Optional - -import torch -from PIL import Image -from torch.utils.data import Dataset -from transformers import DataCollatorForSeq2Seq, LlavaNextForConditionalGeneration, AutoModelForVision2Seq - from ...data import split_dataset, get_mm_dataset -from ...extras.constants import IGNORE_INDEX from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss -from ...model import load_model, load_tokenizer, load_processor, load_mm_model +from ...model import load_tokenizer, load_processor, load_mm_model from ..utils import create_modelcard_and_push from .metric import ComputeMetrics from .trainer import CustomSeq2SeqTrainer -from .collator import DataCollatorForVis2Seq, ImageCaptioningDataset +from .collator import DataCollatorForVis2Seq if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -32,28 +25,27 @@ def run_sft_mm( callbacks: Optional[List["TrainerCallback"]] = None, ): processor = load_processor(model_args) - tokenizer = processor.tokenizer - model = load_mm_model(processor, model_args, finetuning_args, training_args.do_train) + tokenizer = load_tokenizer(model_args) + CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" + tokenizer.chat_template = CHAT_TEMPLATE + processor.tokenizer = tokenizer + use_clm = True + if "blip" in model_args.model_name_or_path: + use_clm = False + model = load_mm_model(processor, model_args, finetuning_args, training_args.do_train, use_clm=use_clm) dataset = get_mm_dataset(processor, model_args, data_args, training_args, stage="sft") - if training_args.predict_with_generate: - tokenizer.padding_side = "left" # use left-padding in generation if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction - splited_dataset = split_dataset(dataset, data_args, training_args) - splited_dataset['train_dataset'].set_format(type=splited_dataset['train_dataset'].format["type"], - columns=list(splited_dataset['train_dataset'].features.keys())) - splited_dataset['eval_dataset'].set_format(type=splited_dataset['eval_dataset'].format["type"], - columns=list(splited_dataset['eval_dataset'].features.keys())) - train_dataset = ImageCaptioningDataset(splited_dataset['train_dataset'], data_args.image_path, processor) - eval_dataset = ImageCaptioningDataset(splited_dataset['eval_dataset'], data_args.image_path, processor) + train_dataset = dataset + eval_dataset = dataset data_collator = DataCollatorForVis2Seq( processor=processor, - use_qformer=model_args.use_qformer, ) # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams + training_args.remove_unused_columns = False # Initialize our Trainer trainer = CustomSeq2SeqTrainer( @@ -67,7 +59,6 @@ def run_sft_mm( train_dataset=train_dataset, eval_dataset=eval_dataset, ) - # Keyword arguments for `model.generate` gen_kwargs = generating_args.to_dict() gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
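
As a quick illustration of what the new collation path produces, the following minimal sketch (not part of the patch) builds a single LLaVA training batch by hand. It assumes the llava-hf/llava-1.5-7b-hf checkpoint used in examples/mllm/sft_llava.sh, a local copy of data/mllm_example_dataset, and the usual user/assistant message pair in the example data; it also hand-writes the Vicuna-style prompt (with an explicit <image> placeholder) that the patch's chat template is meant to render, instead of calling apply_chat_template.

from datasets import load_dataset
from transformers import AutoProcessor

# Assumed model ID and dataset path, matching the examples in this patch.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
example = load_dataset("data/mllm_example_dataset")["train"][0]

# Hand-written Vicuna-style prompt covering one user turn (text + image) and the
# assistant reply, terminated by the eos token as in the patch's chat template.
question = example["messages"][0]["content"][0]["text"]
answer = example["messages"][1]["content"][0]["text"]
prompt = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    f"USER: <image>\n{question} ASSISTANT: {answer}</s>"
)

# One processor call tokenizes the text and preprocesses the image; labels are a copy
# of input_ids with padding positions masked to -100, mirroring DataCollatorForVis2Seq.
batch = processor(text=[prompt], images=[example["images"][0]], return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
if processor.tokenizer.pad_token_id is not None:
    labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
print({key: value.shape for key, value in batch.items()})

DataCollatorForVis2Seq in src/llmtuner/train/sftmm/collator.py performs the same steps for a whole batch via the tokenizer's chat template and raises a ValueError when an example carries more than one image.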