add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: a730f89a972f1a9d37c718c716f199cb8d4903b2
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from .collator import PairwiseDataCollatorWithPadding
|
||||
from .loader import get_dataset
|
||||
from .loader import get_dataset, get_mm_dataset
|
||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"get_dataset",
|
||||
"get_mm_dataset",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
"templates",
|
||||
|
||||
@@ -13,23 +13,21 @@ from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import checksum, merge_dataset
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers import Seq2SeqTrainingArguments, AutoProcessor
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_single_dataset(
|
||||
dataset_attr: "DatasetAttr",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
dataset_attr: "DatasetAttr",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
@@ -115,11 +113,11 @@ def load_single_dataset(
|
||||
|
||||
|
||||
def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
@@ -177,3 +175,33 @@ def get_dataset(
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def get_mm_dataset(
|
||||
processor: "AutoProcessor",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
tokenizer = processor.tokenizer
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.tokenized_path)
|
||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
|
||||
if data_args.streaming:
|
||||
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
||||
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
all_datasets = []
|
||||
for dataset_attr in get_dataset_list(data_args):
|
||||
local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
all_datasets.append(load_dataset("json", data_files=local_path)['train'])
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
|
||||
return dataset
|
||||
|
||||
Reference in New Issue
Block a user