video datasets
Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
This commit is contained in:
@@ -22,7 +22,7 @@ import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
@@ -30,11 +30,11 @@ from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from ..data import Template
|
||||
from ..data.mm_plugin import ImageInput, VideoInput
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
@@ -78,20 +78,30 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||
if image is not None:
|
||||
mm_input_dict.update({"images": [image], "imglens": [1]})
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(messages, [image], processor)
|
||||
if video is not None:
|
||||
mm_input_dict.update({"videos": [video], "vidlens": [1]})
|
||||
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(
|
||||
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
||||
)
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or generating_args["default_system"]
|
||||
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
||||
if image is not None:
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(
|
||||
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
|
||||
)
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
@@ -154,13 +164,10 @@ class HuggingfaceEngine(BaseEngine):
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(
|
||||
images=[image], imglens=[1], seqlens=[prompt_length], processor=processor
|
||||
)
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@@ -175,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
@@ -210,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
@@ -267,7 +276,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
if not self.can_generate:
|
||||
@@ -284,6 +294,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
@@ -295,7 +306,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
@@ -312,6 +324,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
|
||||
Reference in New Issue
Block a user