refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
@@ -130,6 +130,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in batch:
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
if "{}token_type_ids".format(prefix) in batch:
|
||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user