tiny fix
Former-commit-id: 830511a6d0216da99520aee8b3a753d347a71fa9
This commit is contained in:
@@ -68,7 +68,7 @@ class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
image_grid_thw = None
|
||||
image_grid_thw = None # TODO: better handle various VLMs
|
||||
if "image_grid_thw" in features[0]:
|
||||
image_grid_thw_list = [
|
||||
torch.Tensor(feature["image_grid_thw"]).long()
|
||||
|
||||
@@ -74,6 +74,9 @@ class BasePlugin:
|
||||
images: Sequence["ImageObject"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Pre-processes input messages before tokenization for VLMs.
|
||||
"""
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
@@ -83,6 +86,9 @@ class BasePlugin:
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
r"""
|
||||
Pre-processes token ids after tokenization for VLMs.
|
||||
"""
|
||||
return input_ids, labels
|
||||
|
||||
def get_mm_inputs(
|
||||
@@ -91,6 +97,9 @@ class BasePlugin:
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Any]:
|
||||
r"""
|
||||
Builds batched multimodal inputs for VLMs.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def process_model_inputs(
|
||||
@@ -100,6 +109,9 @@ class BasePlugin:
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> None:
|
||||
r"""
|
||||
Appends multimodal inputs to model inputs for VLMs.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ def preprocess_feedback_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = defaultdict(list)
|
||||
|
||||
@@ -70,7 +70,7 @@ def preprocess_pairwise_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
|
||||
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
|
||||
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
|
||||
|
||||
@@ -62,7 +62,7 @@ def preprocess_unsupervised_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
|
||||
Reference in New Issue
Block a user