support mllm hf inference

Former-commit-id: 2c7c01282acd7ddabbb17ce3246b8dae4bc4b8cf
Author: hiyouga
Date: 2024-04-26 05:34:58 +08:00
Parent: 10a6c395bb
Commit: 23b881bff1
23 changed files with 128 additions and 49 deletions


@@ -8,7 +8,7 @@ from .utils import Role
 if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image
     from transformers import ProcessorMixin, Seq2SeqTrainingArguments
     from transformers.image_processing_utils import BaseImageProcessor
     from transformers.tokenization_utils import PreTrainedTokenizer
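
The hunk above narrows the TYPE_CHECKING-only import from the PIL.Image module to the Image class it contains, so type annotations can name the actual image object type without importing PIL at runtime. A minimal sketch of that guarded-import pattern (not code from this commit; count_pixels is an illustrative helper):

    from typing import TYPE_CHECKING, List

    if TYPE_CHECKING:
        # Evaluated only by static type checkers; PIL is not imported when the module runs.
        from PIL.Image import Image


    def count_pixels(images: List["Image"]) -> int:
        # String annotation, so the name need not be bound at runtime.
        return sum(image.width * image.height for image in images)
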
@@ -271,7 +271,11 @@ def get_preprocess_and_print_func(
     processor: Optional["ProcessorMixin"],
 ) -> Tuple[Callable, Callable]:
     if stage == "pt":
-        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
+        preprocess_func = partial(
+            preprocess_pretrain_dataset,
+            tokenizer=tokenizer,
+            data_args=data_args,
+        )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
         if data_args.packing:
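
The second hunk only reflows an existing call, but it shows the dispatch pattern used by get_preprocess_and_print_func: functools.partial binds the shared arguments (here tokenizer and data_args) into the stage-specific preprocessing function, so the caller gets back a single-argument callable it can hand straight to dataset.map(). A minimal sketch of that pattern, using stand-in functions and parameters rather than the project's real signatures:

    from functools import partial
    from typing import Any, Callable, Dict, List


    def preprocess_pretrain(examples: Dict[str, List[str]], tokenizer: Any, cutoff_len: int) -> Dict[str, Any]:
        # Stand-in: tokenize raw text, truncating to cutoff_len tokens.
        return tokenizer(examples["text"], truncation=True, max_length=cutoff_len)


    def preprocess_supervised(examples: Dict[str, List[str]], tokenizer: Any, cutoff_len: int) -> Dict[str, Any]:
        # Stand-in for the prompt/response preprocessing used in the sft stage.
        return tokenizer(examples["prompt"], text_target=examples["response"], truncation=True, max_length=cutoff_len)


    def get_preprocess_func(stage: str, tokenizer: Any, cutoff_len: int) -> Callable:
        # Bind the shared arguments now; dataset.map() later supplies only `examples`.
        if stage == "pt":
            return partial(preprocess_pretrain, tokenizer=tokenizer, cutoff_len=cutoff_len)
        return partial(preprocess_supervised, tokenizer=tokenizer, cutoff_len=cutoff_len)

Because every branch returns a callable with the same (examples) -> features shape, the downstream dataset.map(preprocess_func, batched=True) call stays identical across stages.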