add docstrings, refactor logger

Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 93d4570a59
commit 7f71276ad8
30 changed files with 334 additions and 57 deletions

View File

@@ -3,6 +3,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available
@@ -209,6 +210,7 @@ class BasePlugin:
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
@@ -233,6 +235,7 @@ class LlavaPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
@@ -247,6 +250,7 @@ class LlavaPlugin(BasePlugin):
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
@@ -270,6 +274,7 @@ class PaliGemmaPlugin(BasePlugin):
return messages
@override
def process_token_ids(
self,
input_ids: List[int],
@@ -289,6 +294,7 @@ class PaliGemmaPlugin(BasePlugin):
return input_ids, labels
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
@@ -305,6 +311,7 @@ class PaliGemmaPlugin(BasePlugin):
class Qwen2vlPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
@@ -359,6 +366,7 @@ class Qwen2vlPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],