merge the model part into the text stream

Former-commit-id: b6fcb832ddaed4647d6f2b926f3dfccd47f3ea84
Author: BUAADreamer
Date: 2024-04-25 08:20:41 +08:00
Parent: 5142349661
Commit: 00e2a272ef
5 changed files with 24 additions and 172 deletions

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, List, Optional
 from ...data import split_dataset, get_mm_dataset
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
-from ...model import load_tokenizer, load_processor, load_mm_model
+from ...model import load_tokenizer, load_processor, load_model
 from ..utils import create_modelcard_and_push
 from .metric import ComputeMetrics
 from .trainer import CustomSeq2SeqTrainer
@@ -29,10 +29,7 @@ def run_sft_mm(
     CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
     tokenizer.chat_template = CHAT_TEMPLATE
     processor.tokenizer = tokenizer
-    use_clm = True
-    if "blip" in model_args.model_name_or_path:
-        use_clm = False
-    model = load_mm_model(processor, model_args, finetuning_args, training_args.do_train, use_clm=use_clm)
+    model = load_model(processor.tokenizer, model_args, finetuning_args, training_args.do_train)
     dataset = get_mm_dataset(processor, model_args, data_args, training_args, stage="sft")
     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction
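For context, here is a minimal sketch (not part of this commit) of how the CHAT_TEMPLATE above renders a multimodal message list through Hugging Face's tokenizer.apply_chat_template; the checkpoint name is illustrative only.

```python
# Minimal sketch, not part of this commit: render the CHAT_TEMPLATE above
# for a single user turn that contains an image and a question.
from transformers import AutoTokenizer

# Any tokenizer works for rendering the template; this checkpoint is an
# illustrative assumption, not one referenced by the commit.
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
tokenizer.chat_template = CHAT_TEMPLATE  # the template string from the diff

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is in this picture?"},
        ],
    },
]

prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# prompt == "A chat between a curious user and an artificial intelligence "
#           "assistant. The assistant gives helpful, detailed, and polite "
#           "answers to the user's questions. USER: <image>What is in this "
#           "picture? ASSISTANT: "
```

Note that the template only emits a literal `<image>` placeholder per image item; expanding that placeholder against the actual pixel inputs is left to the multimodal processor and model at training or inference time.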