mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-01-30 06:12:04 +00:00
Compare commits
2 Commits
f9f11dcb97
...
9640f79ae5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9640f79ae5 | ||
|
|
7ef19eea00 |
@@ -356,7 +356,7 @@ _register_composite_model(
|
||||
_register_composite_model(
|
||||
model_type="qwen3_vl",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
@@ -365,7 +365,7 @@ _register_composite_model(
|
||||
_register_composite_model(
|
||||
model_type="qwen3_vl_moe",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
@@ -374,7 +374,7 @@ _register_composite_model(
|
||||
_register_composite_model(
|
||||
model_type="qwen3_omni_moe_thinker",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
|
||||
language_model_keys=["model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
|
||||
else:
|
||||
return loss
|
||||
|
||||
@override
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if self.args.save_safetensors:
|
||||
from collections import defaultdict
|
||||
|
||||
ptrs = defaultdict(list)
|
||||
for name, tensor in state_dict.items():
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
ptrs[id(tensor)].append(name)
|
||||
|
||||
for names in ptrs.values():
|
||||
if len(names) > 1:
|
||||
names.sort()
|
||||
for name in names[1:]:
|
||||
state_dict.pop(name, None)
|
||||
|
||||
super()._save(output_dir, state_dict)
|
||||
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""Save model predictions to `output_dir`.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user