[assets] update readme (#7644)
This commit is contained in:
@@ -19,7 +19,7 @@ Level:
|
||||
|
||||
Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||
transformers>=4.41.2,<=4.51.1,!=4.46.*,!=4.47.*,!=4.48.0
|
||||
datasets>=2.16.0,<=3.4.1
|
||||
accelerate>=0.34.0,<=1.5.2
|
||||
peft>=0.14.0,<=0.15.0
|
||||
|
||||
@@ -298,8 +298,9 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
if "cross_attention_mask" in kl_batch: # for mllama inputs.
|
||||
if "cross_attention_mask" in kl_batch: # for mllama inputs
|
||||
batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
|
||||
|
||||
if "token_type_ids" in kl_batch:
|
||||
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("transformers>=4.41.2,<=4.51.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.4.1")
|
||||
check_version("accelerate>=0.34.0,<=1.5.2")
|
||||
check_version("peft>=0.14.0,<=0.15.0")
|
||||
|
||||
@@ -147,6 +147,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "image_sizes" in batch:
|
||||
model_inputs["image_sizes"] = batch["image_sizes"]
|
||||
|
||||
if "image_grid_thw" in batch:
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user