mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-04-07 05:33:09 +00:00
fix: gemma4 mm_token_type_ids padding (#10359)
This commit is contained in:
@@ -380,6 +380,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
for i, feature in enumerate(features):
|
for i, feature in enumerate(features):
|
||||||
feature["token_type_ids"] = token_type_ids[i]
|
feature["token_type_ids"] = token_type_ids[i]
|
||||||
|
|
||||||
|
# Gemma-4 multimodal models require ``mm_token_type_ids`` as a rectangular
# tensor rather than ragged per-sample lists, so pad every row with zeros
# up to the longest row before stacking.
# NOTE(review): assumes each element of mm_token_type_ids is a plain list of
# ints (list concatenation is used) — confirm against the processor output.
if "mm_token_type_ids" in mm_inputs:  # need tensor-like for gemma4
    raw_rows = mm_inputs.pop("mm_token_type_ids")
    target_len = max(len(row) for row in raw_rows)
    padded_rows = []
    for row in raw_rows:
        filler = [0] * (target_len - len(row))
        # honor the tokenizer's padding side so these ids stay aligned
        # with the padded input_ids
        if self.tokenizer.padding_side == "right":
            padded_rows.append(row + filler)
        else:
            padded_rows.append(filler + row)
    mm_inputs["mm_token_type_ids"] = torch.tensor(padded_rows, dtype=torch.long)
|
||||||
|
|
||||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||||
|
|
||||||
bsz, seq_len = features["input_ids"].shape[:2]
|
bsz, seq_len = features["input_ids"].shape[:2]
|
||||||
|
|||||||
Reference in New Issue
Block a user