mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
[model] add Qwen2.5-Omni model (#7537)
* preserve image_sizes * preserve image_sizes * init plugin * support audio-text2text lora * nit * support image/video-text2text, audio-text2text * remove args * remove lines * add docs && nit * remove some comments * fix && add merge part script * add license
This commit is contained in:
@@ -190,10 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": features["attention_mask"],
|
||||
}
|
||||
if "second_per_grid_ts" in mm_inputs:
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
|
||||
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
|
||||
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni": # for qwen2omni
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(
|
||||
feature_attention_mask, dim=1
|
||||
) # FIXME need to get video image lengths
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
|
||||
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
|
||||
# avoid conflict
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
|
||||
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
|
||||
features["position_ids"], features["rope_deltas"] = (
|
||||
new_position_ids.clone(),
|
||||
rope_deltas - delta0,
|
||||
) # avoid inplace operation FIXME
|
||||
else: # for qwen2vl
|
||||
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
|
||||
|
||||
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
|
||||
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
|
||||
|
||||
Reference in New Issue
Block a user