2 Commits

Author SHA1 Message Date
Jewon Lee
9640f79ae5 [fix] add visual.pos_embed to Qwen3-VL visual model keys (#10139) 2026-01-27 16:33:01 +08:00
jiaqiw09
7ef19eea00 [v0] Fix reward model training safetensors saving (#10137) 2026-01-27 16:27:14 +08:00
2 changed files with 24 additions and 3 deletions

View File

@@ -356,7 +356,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -365,7 +365,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -374,7 +374,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)

View File

@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
else:
return loss
@override
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if state_dict is None:
state_dict = self.model.state_dict()
if self.args.save_safetensors:
from collections import defaultdict
ptrs = defaultdict(list)
for name, tensor in state_dict.items():
if isinstance(tensor, torch.Tensor):
ptrs[id(tensor)].append(name)
for names in ptrs.values():
if len(names) > 1:
names.sort()
for name in names[1:]:
state_dict.pop(name, None)
super()._save(output_dir, state_dict)
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""Save model predictions to `output_dir`.