release v0.7.0
Former-commit-id: 45bb89cb4d26a6b3fb5360bc90ab950738fe4920
@@ -1,14 +1,20 @@
+from functools import partial
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
+from ..extras.packages import is_pillow_available
 from .utils import Role


+if is_pillow_available():
+    from PIL import Image
+
+
 if TYPE_CHECKING:
-    from PIL.Image import Image
+    from numpy.typing import NDArray
+    from PIL.Image import Image as ImageObject
     from transformers import ProcessorMixin, Seq2SeqTrainingArguments
     from transformers.image_processing_utils import BaseImageProcessor
     from transformers.tokenization_utils import PreTrainedTokenizer
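The guarded Pillow import above keeps PIL an optional dependency: it is imported at runtime only when installed, while the annotation-only names (ImageObject, NDArray) live under TYPE_CHECKING and cost nothing at runtime. A minimal sketch of the same pattern, assuming a find_spec-based check (the commit's actual is_pillow_available helper is not shown in this diff):

import importlib.util
from typing import TYPE_CHECKING


def is_pillow_available() -> bool:
    # True when Pillow is importable; nothing is actually imported yet.
    return importlib.util.find_spec("PIL") is not None


if is_pillow_available():
    from PIL import Image

if TYPE_CHECKING:
    # Type-only imports: seen by type checkers, skipped at runtime.
    from PIL.Image import Image as ImageObject  # noqa: F401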
@@ -20,12 +26,11 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def _preprocess_visual_inputs(model_inputs: Dict[str, Any], processor: "ProcessorMixin", image: "Image") -> None:
+def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+    # process visual inputs (currently only supports a single image)
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"][0]
-    if "pixel_values" not in model_inputs:
-        model_inputs["pixel_values"] = []
-    model_inputs["pixel_values"].append(pixel_values)
+    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+    return image_processor(image, return_tensors="pt")["pixel_values"][0]


 def preprocess_pretrain_dataset(
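The helper no longer mutates model_inputs in place; it now returns the pixel array for the first image of an example, substituting a 100x100 white placeholder when the example has no image, so every row still yields a tensor of the expected shape. A hypothetical usage sketch of the rewritten helper (the model id and images are illustrative, not from the commit):

from functools import partial

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Bind the processor once, as the preprocess_* functions below do.
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

# An example with an image: the first image's pixel values come back.
pixel_values = preprocess_visual_inputs([Image.new("RGB", (336, 336))])

# An example without images: the white placeholder is encoded instead.
placeholder_values = preprocess_visual_inputs([])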
@@ -66,11 +71,17 @@ def preprocess_supervised_dataset(
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             continue

+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
+
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
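Note that the mutation happens on examples itself: when a multimodal processor is configured, the "<image>" placeholder is prepended to the first user message of each example exactly once, before the template encodes the conversation. A self-contained sketch of the effect (message content is made up):

# Stand-in for a configured multimodal processor.
processor = object()

prompt = [{"role": "user", "content": "What is in this photo?"}]
if processor is not None:
    prompt[0]["content"] = "<image>" + prompt[0]["content"]

print(prompt[0]["content"])  # <image>What is in this photo?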
@@ -100,8 +111,8 @@ def preprocess_supervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None and "images" in examples:
-            _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))

     return model_inputs
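After this hunk, pixel values are appended in lockstep with the text fields, so every key in the returned dict holds one entry per kept example. An illustrative sketch of the resulting structure (token ids and the tensor shape are made up):

model_inputs = {
    "input_ids": [[1, 910, 338, 263, 2]],    # one token list per example
    "attention_mask": [[1, 1, 1, 1, 1]],
    "labels": [[-100, -100, -100, 263, 2]],  # prompt positions masked with IGNORE_INDEX
    "pixel_values": ["tensor of shape (3, 336, 336)"],  # one tensor per example
}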
@@ -161,11 +172,17 @@ def preprocess_unsupervised_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1:
             continue

+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
+
         if len(examples["response"][i]) == 1:
             messages = examples["prompt"][i] + examples["response"][i]
         else:
@@ -186,8 +203,8 @@ def preprocess_unsupervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None and "images" in examples:
-            _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))

     return model_inputs
@@ -201,10 +218,17 @@ def preprocess_pairwise_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
             continue

+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
+
         chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
         prompt_ids, chosen_ids = template.encode_oneturn(
@@ -231,8 +255,8 @@ def preprocess_pairwise_dataset(
         model_inputs["prompt_ids"].append(prompt_ids)
         model_inputs["chosen_ids"].append(chosen_ids)
         model_inputs["rejected_ids"].append(rejected_ids)
-        if processor is not None and "images" in examples:
-            _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))

     return model_inputs