refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
@@ -22,6 +22,7 @@ import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
@@ -31,7 +32,6 @@ from .base_engine import BaseEngine, Response
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from ..data import Template
|
||||
@@ -81,27 +81,19 @@ class HuggingfaceEngine(BaseEngine):
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
if (
|
||||
processor is not None
|
||||
and image is not None
|
||||
and not hasattr(processor, "image_seq_length")
|
||||
and template.image_token not in messages[0]["content"]
|
||||
): # llava-like models
|
||||
messages[0]["content"] = template.image_token + messages[0]["content"]
|
||||
if image is not None:
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(messages, [image], processor)
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or generating_args["default_system"]
|
||||
pixel_values = None
|
||||
prompt_ids, _ = template.encode_oneturn(
|
||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
if processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
batch_feature = image_processor(image, return_tensors="pt")
|
||||
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
if image is not None:
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
@@ -164,8 +156,13 @@ class HuggingfaceEngine(BaseEngine):
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
gen_kwargs["pixel_values"] = pixel_values
|
||||
if image is not None:
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(
|
||||
images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor
|
||||
)
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
|
||||
Reference in New Issue
Block a user