[data] fix qwen2.5 omni plugin (#7573)
* align key with qwen2vl
* nit && change scripts
```diff
@@ -203,7 +203,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
                 # avoid conflict
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
                 new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
                 features["position_ids"], features["rope_deltas"] = (
                     new_position_ids.clone(),
```
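In this hunk the collator builds the keyword arguments for the model's `get_rope_index`: `delta0` counts the padding positions of each row (`(1 - attention_mask).sum(-1)`), and the processor output `video_second_per_grid` is handed over as the `second_per_grids` argument that Qwen2.5-Omni's rope-index routine expects. A minimal sketch of that key mapping, assuming an Omni-style `get_rope_index` signature; the helper name and the signature check are illustrative, not the repository's actual code:

```python
import inspect

import torch


# Illustrative helper (not from the repo): map processor outputs onto the
# keyword arguments of the model's get_rope_index.
def build_rope_index_kwargs(
    model, mm_inputs: dict, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> dict:
    rope_index_kwargs = {
        "input_ids": input_ids,
        "image_grid_thw": mm_inputs.get("image_grid_thw"),
        "video_grid_thw": mm_inputs.get("video_grid_thw"),
        "attention_mask": attention_mask,
    }
    # Qwen2.5-Omni's get_rope_index takes `second_per_grids`, but the
    # processor emits the value under `video_second_per_grid`; only pass
    # it along when the model actually accepts the parameter.
    if "second_per_grids" in inspect.signature(model.get_rope_index).parameters:
        rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)

    return rope_index_kwargs
```

Gating on the signature rather than on the model type would keep the same collator usable for qwen2vl-style models, which take `second_per_grid_ts` instead.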
```diff
@@ -1405,7 +1405,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                         video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
                     )
                     .flatten()
-                    * mm_inputs["video_second_per_grid"][num_video_tokens]
+                    * mm_inputs["second_per_grid_ts"][num_video_tokens]
                     * 25  # FIXME hardcode of position_id_per_seconds=25
                 ).long()
                 t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
```
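The renamed key feeds the per-token timestamp computation: each temporal grid index is expanded over the spatial patches of its frame (`(h // merge_size) * (w // merge_size)` tokens), flattened, and scaled from grid steps to seconds and then to position ids at the hardcoded 25 ids per second. A runnable sketch with toy numbers, mirroring the hunk above (the grid shape, merge size, and seconds-per-grid are made up for illustration):

```python
import torch

# Toy stand-ins for one video's processor outputs.
video_grid_thw = torch.tensor([[4, 6, 8]])  # (t, h, w) patch grid
merge_size = 2                # spatial merge size of the image processor
second_per_grid = 0.5         # seconds of video covered by one temporal grid step
position_id_per_seconds = 25  # hardcoded in the plugin (see FIXME above)

t, h, w = video_grid_thw[0].tolist()
tokens_per_frame = (h // merge_size) * (w // merge_size)  # 3 * 4 = 12

# Repeat each temporal index over its frame's tokens, then convert
# grid steps -> seconds -> position ids.
t_index = (
    torch.arange(t)
    .view(-1, 1)
    .expand(-1, tokens_per_frame)
    .flatten()
    * second_per_grid
    * position_id_per_seconds
).long()

print(t_index.shape)     # torch.Size([48]) == t * tokens_per_frame
print(t_index.unique())  # tensor([ 0, 12, 25, 37]): one 0.5 s step is 12.5 ids
```

`t_ntoken_per_chunk = 50` then corresponds to a 2-second chunk (25 ids per second times 2 seconds), which is what the second FIXME records.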
```diff
@@ -157,7 +157,7 @@ def load_model(
             model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
         else:
             model = load_class.from_pretrained(**init_kwargs)

-        if load_class is AutoModelForTextToWaveform:
+        if getattr(model.config, "model_type", None) == "qwen2_5_omni":
             model = model.thinker  # use part of Omni model

     if model_args.mixture_of_depths == "convert":
```
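For Qwen2.5-Omni, the checkpoint that `from_pretrained` returns is the full Omni model, which bundles a text "thinker" and a speech "talker"; only the thinker is needed for LLM-style fine-tuning, and branching on `config.model_type` makes the extraction independent of which auto class performed the load. A hedged sketch of the resulting loading path (the checkpoint name is illustrative):

```python
from transformers import AutoConfig, AutoModelForTextToWaveform

model_id = "Qwen/Qwen2.5-Omni-7B"  # illustrative checkpoint name
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForTextToWaveform.from_pretrained(model_id, config=config)

# Branch on the config's model_type instead of the auto class, so the
# thinker is extracted no matter which load_class produced the model.
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
    model = model.thinker  # keep only the text "thinker" part of the Omni model
```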