[model] support audio (#6701)
* support qwen2_audio
* improve code
* lint
* fix
* fix
* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
@@ -16,7 +16,14 @@ import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras import logging
|
||||
@@ -142,6 +149,8 @@ def load_model(
|
||||
else:
|
||||
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
|
||||
load_class = AutoModelForVision2Seq
|
||||
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
|
||||
load_class = AutoModelForSeq2SeqLM
|
||||
else:
|
||||
load_class = AutoModelForCausalLM
|
||||
|
||||
|
||||
@@ -280,6 +280,12 @@ _register_composite_model(
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen2_audio",
|
||||
vision_model_keys=["audio_tower"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen2_vl",
|
||||
projector_key="visual.merger",
|
||||
|
||||
@@ -78,13 +78,14 @@ def patch_processor(
|
||||
model_args: "ModelArguments",
|
||||
) -> None:
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||
setattr(processor, "patch_size", get_patch_size(config, processor))
|
||||
setattr(processor, "video_resolution", model_args.video_resolution)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
setattr(processor, "video_maxlen", model_args.video_maxlen)
|
||||
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
|
||||
if getattr(config, "vision_config", None) is not None: # visual models
|
||||
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||
setattr(processor, "patch_size", get_patch_size(config, processor))
|
||||
setattr(processor, "video_resolution", model_args.video_resolution)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
setattr(processor, "video_maxlen", model_args.video_maxlen)
|
||||
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
|
||||
|
||||
|
||||
def patch_config(
|
||||
|
||||
Reference in New Issue
Block a user