[model] support audio (#6701)

* support qwen2_audio

* improve code

* lint

* fix

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
Zhangchi Feng
2025-02-05 04:59:09 +08:00
committed by GitHub
parent 9feb78e7b4
commit 8f401e37f8
35 changed files with 675 additions and 213 deletions

View File

@@ -22,6 +22,8 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
@@ -58,6 +60,8 @@ METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
MULTIMODAL_SUPPORTED_MODELS = set()
PEFT_METHODS = {"lora"}
RUNNING_LOG = "running_log.txt"
@@ -89,8 +93,6 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
VISION_MODELS = set()
class DownloadSource(str, Enum):
DEFAULT = "hf"
@@ -101,14 +103,16 @@ class DownloadSource(str, Enum):
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
template: Optional[str] = None,
vision: bool = False,
multimodal: bool = False,
) -> None:
for name, path in models.items():
SUPPORTED_MODELS[name] = path
if template is not None and (any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or vision):
if template is not None and (
any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
):
DEFAULT_TEMPLATE[name] = template
if vision:
VISION_MODELS.add(name)
if multimodal:
MULTIMODAL_SUPPORTED_MODELS.add(name)
register_model_group(
@@ -1030,7 +1034,7 @@ register_model_group(
},
},
template="mllama",
vision=True,
multimodal=True,
)
@@ -1046,7 +1050,7 @@ register_model_group(
},
},
template="llava",
vision=True,
multimodal=True,
)
@@ -1062,7 +1066,7 @@ register_model_group(
},
},
template="llava_next",
vision=True,
multimodal=True,
)
@@ -1074,7 +1078,7 @@ register_model_group(
},
},
template="llava_next_mistral",
vision=True,
multimodal=True,
)
@@ -1086,7 +1090,7 @@ register_model_group(
},
},
template="llava_next_llama3",
vision=True,
multimodal=True,
)
@@ -1098,7 +1102,7 @@ register_model_group(
},
},
template="llava_next_yi",
vision=True,
multimodal=True,
)
@@ -1114,7 +1118,7 @@ register_model_group(
},
},
template="llava_next_qwen",
vision=True,
multimodal=True,
)
@@ -1130,7 +1134,7 @@ register_model_group(
},
},
template="llava_next_video",
vision=True,
multimodal=True,
)
@@ -1142,7 +1146,7 @@ register_model_group(
},
},
template="llava_next_video_mistral",
vision=True,
multimodal=True,
)
@@ -1157,7 +1161,7 @@ register_model_group(
},
},
template="llava_next_video_yi",
vision=True,
multimodal=True,
)
@@ -1207,7 +1211,7 @@ register_model_group(
},
},
template="minicpm_v",
vision=True,
multimodal=True,
)
@@ -1219,7 +1223,7 @@ register_model_group(
},
},
template="minicpm_v",
vision=True,
multimodal=True,
)
@@ -1424,7 +1428,7 @@ register_model_group(
},
},
template="paligemma",
vision=True,
multimodal=True,
)
@@ -1468,7 +1472,7 @@ register_model_group(
},
},
template="paligemma",
vision=True,
multimodal=True,
)
@@ -1551,7 +1555,7 @@ register_model_group(
}
},
template="pixtral",
vision=True,
multimodal=True,
)
@@ -2134,6 +2138,22 @@ register_model_group(
)
register_model_group(
models={
"Qwen2-Audio-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B",
},
"Qwen2-Audio-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B-Instruct",
},
},
template="qwen2_audio",
multimodal=True,
)
register_model_group(
models={
"Qwen2-VL-2B-Instruct": {
@@ -2204,7 +2224,7 @@ register_model_group(
},
},
template="qwen2_vl",
vision=True,
multimodal=True,
)
@@ -2329,7 +2349,7 @@ register_model_group(
},
},
template="video_llava",
vision=True,
multimodal=True,
)
@@ -2556,7 +2576,7 @@ register_model_group(
},
},
template="yi_vl",
vision=True,
multimodal=True,
)