[data] fix qwen2.5 omni plugin (#7573)

* align key with qwen2vl

* nit && change scripts
This commit is contained in:
Kingsley
2025-04-02 21:28:52 +08:00
committed by GitHub
parent 7b9deb9410
commit d32c6c014d
4 changed files with 47 additions and 6 deletions

View File

@@ -203,7 +203,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
# avoid conflict
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
features["position_ids"], features["rope_deltas"] = (
new_position_ids.clone(),

View File

@@ -1405,7 +1405,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
* mm_inputs["second_per_grid_ts"][num_video_tokens]
* 25 # FIXME hardcode of position_id_per_seconds=25
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]

View File

@@ -157,7 +157,7 @@ def load_model(
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
if load_class is AutoModelForTextToWaveform:
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
model = model.thinker # use part of Omni model
if model_args.mixture_of_depths == "convert":