From ffbff33af34785b5008d794260188dc0f20095f3 Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Sun, 22 Mar 2026 02:28:52 +0800
Subject: [PATCH] chore: mca workflow compatible with qwen-vl series (#10303)

---
 src/llamafactory/train/mca/trainer.py  | 60 +++++++++++++++++++++++++-
 src/llamafactory/train/mca/workflow.py | 50 ++++++++++++++++++---
 2 files changed, 103 insertions(+), 7 deletions(-)

diff --git a/src/llamafactory/train/mca/trainer.py b/src/llamafactory/train/mca/trainer.py
index 97cc9b713..ad537588a 100644
--- a/src/llamafactory/train/mca/trainer.py
+++ b/src/llamafactory/train/mca/trainer.py
@@ -12,4 +12,62 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO override the original trainer
+from typing import Any
+
+import torch.nn.functional as F
+from mcore_adapter.trainer import McaTrainer
+from torch import Tensor
+from transformers import PreTrainedTokenizerBase
+from typing_extensions import override
+
+from ...extras.constants import IGNORE_INDEX
+
+
+class CustomMcaTrainer(McaTrainer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @override
+    def _pad_batched_inputs(self, inputs: dict[str, Tensor | Any], seq_length: int):
+        r"""Override to avoid padding errors when handling 3D position ids."""
+        padding_inputs = {
+            k: v.tolist() if isinstance(v, Tensor) else v
+            for k, v in inputs.items()
+            if k in self._language_input_names
+        }
+
+        position_ids_3d = None
+        if isinstance(inputs.get("position_ids"), Tensor) and inputs["position_ids"].dim() == 3:
+            position_ids_3d = inputs["position_ids"]
+            padding_inputs.pop("position_ids", None)
+
+        if "labels" in padding_inputs:
+            padding_inputs["labels"] = [
+                labels + [IGNORE_INDEX] * (seq_length - len(labels)) for labels in padding_inputs["labels"]
+            ]
+        tokenizer = (
+            self.processing_class
+            if isinstance(self.processing_class, PreTrainedTokenizerBase)
+            else getattr(self.processing_class, "tokenizer", self.processing_class)
+        )
+        padding_side = getattr(tokenizer, "padding_side", "right")
+        padding_inputs = tokenizer.pad(
+            padding_inputs,
+            padding="max_length",
+            max_length=seq_length,
+            return_tensors="pt",
+        ).to(self.args.device)
+        inputs.update(padding_inputs)
+
+        if position_ids_3d is not None:
+            current_seq_len = position_ids_3d.size(-1)
+            if current_seq_len < seq_length:
+                pad_len = seq_length - current_seq_len
+                if padding_side == "left":
+                    position_ids_3d = F.pad(position_ids_3d, (pad_len, 0), value=0)
+                else:
+                    position_ids_3d = F.pad(position_ids_3d, (0, pad_len), value=0)
+
+            inputs["position_ids"] = position_ids_3d.to(self.args.device)
+
+        return inputs
diff --git a/src/llamafactory/train/mca/workflow.py b/src/llamafactory/train/mca/workflow.py
index 812ae5830..f99c576f9 100644
--- a/src/llamafactory/train/mca/workflow.py
+++ b/src/llamafactory/train/mca/workflow.py
@@ -19,6 +19,7 @@ from collections.abc import Sequence
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Optional
 
+import torch
 from transformers import DataCollatorForSeq2Seq
 
 from ...data import (
@@ -43,9 +44,10 @@ if not is_mcore_adapter_available():
 from mcore_adapter.models import AutoConfig, AutoModel
 from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
-from mcore_adapter.trainer import McaTrainer
 from mcore_adapter.trainer.dpo_config import DPOConfig
 
+from .trainer import CustomMcaTrainer
+
 
 if TYPE_CHECKING:
     from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
 
@@ -72,7 +74,18 @@ def _data_collator_wrapper(data_collator: Any):
         for k in ["attention_mask", "position_ids"]:
             if k in feature:
                 feature[k] = feature[k][:-1]
-        return data_collator(features)
+
+        # For Qwen-VL series models: drop rope_deltas, which the Megatron model does not use.
+        tmp_features = data_collator(features)
+        tmp_features.pop("rope_deltas", None)
+        position_ids = tmp_features.get("position_ids", None)
+
+        if position_ids is not None and position_ids.dim() == 3:
+            if position_ids.shape[0] == 4:  # text positions + 3 M-RoPE rows: keep only the M-RoPE rows
+                position_ids = position_ids[1:]
+            tmp_features["position_ids"] = position_ids
+
+        return tmp_features
 
     return wrapper
 
@@ -103,11 +116,11 @@ def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments")
     params_to_freeze = []
     if finetuning_args.freeze_vision_tower:
         params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
-        if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
+        if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
             params_to_freeze.extend(["vision_model.pos_embed"])
 
     if finetuning_args.freeze_multi_modal_projector:
-        params_to_freeze.extend(["multi_modal_projector"])
+        params_to_freeze.extend(["vision_model.merger"])
 
     if finetuning_args.freeze_language_model:
         params_to_freeze.extend(["embedding", "decoder", "output_layer"])
@@ -118,6 +131,27 @@ def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments")
             p.requires_grad_(False)
 
 
+def _build_meta_hf_model_for_collator(model_args: "ModelArguments") -> Any | None:
+    r"""Build a lightweight HF model on the meta device for compatibility with the data collator."""
+    from transformers import AutoConfig as HfAutoConfig
+    from transformers import AutoModel as HfAutoModel
+    from transformers import AutoModelForImageTextToText
+
+    try:
+        config = HfAutoConfig.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+        )
+        with torch.device("meta"):
+            try:
+                # Prefer the multimodal auto class for VLMs (e.g. qwen2-vl), so get_rope_index is available.
+                return AutoModelForImageTextToText.from_config(config)
+            except Exception:
+                return HfAutoModel.from_config(config)
+    except Exception as exc:
+        logger.warning("Failed to build meta HF model for collator, falling back to None. Error: %s", exc)
Error: %s", exc) + return None + + def run_pt( model_args: "ModelArguments", data_args: "DataArguments", @@ -143,7 +177,7 @@ def run_pt( ) data_collator = _data_collator_wrapper(data_collator) - trainer = McaTrainer( + trainer = CustomMcaTrainer( model=model, args=training_args, tokenizer=tokenizer, @@ -193,6 +227,7 @@ def run_sft( _check_model_support(model_args) model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args) + collator_model = _build_meta_hf_model_for_collator(model_args) # optional freezing for qwen_vl series _freeze_model_parameters(model, finetuning_args) @@ -200,6 +235,7 @@ def run_sft( pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1 data_collator = SFTDataCollatorWith4DAttentionMask( template=template, + model=collator_model, padding="max_length" if pad_to_max else "longest", max_length=data_args.cutoff_len if pad_to_max else None, pad_to_multiple_of=64, @@ -208,7 +244,7 @@ def run_sft( ) data_collator = _data_collator_wrapper(data_collator) - trainer = McaTrainer( + trainer = CustomMcaTrainer( model=model, args=training_args, tokenizer=tokenizer, @@ -247,6 +283,7 @@ def run_dpo( _check_model_support(model_args) model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args) + collator_model = _build_meta_hf_model_for_collator(model_args) _freeze_model_parameters(model, finetuning_args) @@ -270,6 +307,7 @@ def run_dpo( ) data_collator = PairwiseDataCollatorWithPadding( template=template, + model=collator_model, pad_to_multiple_of=64, padding="max_length" if pad_to_max else "longest", max_length=data_args.cutoff_len if pad_to_max else None,