[trainer] fix llama3.2 vision kto train (#6904)

Former-commit-id: 1563e89adc8988fc6e4250634a3f1e385979b0e5
This commit is contained in:
marko1616
2025-02-12 19:09:14 +08:00
committed by GitHub
parent 2581cc844b
commit 0c0cdc26bc
2 changed files with 11 additions and 0 deletions

View File

@@ -285,6 +285,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "cross_attention_mask" in kl_batch: # for mllama inputs.
batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]