mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 08:13:38 +00:00
[breaking change] refactor data pipeline (#6901)
* refactor data * rename file Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b
This commit is contained in:
@@ -4,7 +4,7 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -1241,14 +1241,26 @@ PLUGINS = {
|
||||
}
|
||||
|
||||
|
||||
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
|
||||
r"""
|
||||
Registers a multimodal plugin.
|
||||
"""
|
||||
if name in PLUGINS:
|
||||
raise ValueError(f"Multimodal plugin {name} already exists.")
|
||||
|
||||
PLUGINS[name] = plugin_class
|
||||
|
||||
|
||||
def get_mm_plugin(
|
||||
name: str,
|
||||
image_token: Optional[str] = None,
|
||||
video_token: Optional[str] = None,
|
||||
audio_token: Optional[str] = None,
|
||||
) -> "BasePlugin":
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class is None:
|
||||
r"""
|
||||
Gets plugin for multimodal inputs.
|
||||
"""
|
||||
if name not in PLUGINS:
|
||||
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
||||
|
||||
return plugin_class(image_token, video_token, audio_token)
|
||||
return PLUGINS[name](image_token, video_token, audio_token)
|
||||
|
||||
Reference in New Issue
Block a user