mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-04-06 17:23:08 +00:00
fix: gemma4 mm_token_type_ids padding (#10359)
This commit is contained in:
@@ -380,6 +380,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
for i, feature in enumerate(features):
|
||||
feature["token_type_ids"] = token_type_ids[i]
|
||||
|
||||
if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4
|
||||
mm_token_type_ids = mm_inputs.pop("mm_token_type_ids")
|
||||
max_len = max(len(ids) for ids in mm_token_type_ids)
|
||||
padded = []
|
||||
for ids in mm_token_type_ids:
|
||||
pad_len = max_len - len(ids)
|
||||
if self.tokenizer.padding_side == "right":
|
||||
padded.append(ids + [0] * pad_len)
|
||||
else:
|
||||
padded.append([0] * pad_len + ids)
|
||||
|
||||
mm_inputs["mm_token_type_ids"] = torch.tensor(padded, dtype=torch.long)
|
||||
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
bsz, seq_len = features["input_ids"].shape[:2]
|
||||
|
||||
Reference in New Issue
Block a user