mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-05 09:33:09 +00:00
Compare commits
2 Commits
f9f11dcb97
...
9640f79ae5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9640f79ae5 | ||
|
|
7ef19eea00 |
@@ -356,7 +356,7 @@ _register_composite_model(
|
|||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="qwen3_vl",
|
model_type="qwen3_vl",
|
||||||
projector_key="visual.merger",
|
projector_key="visual.merger",
|
||||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||||
language_model_keys=["language_model", "lm_head"],
|
language_model_keys=["language_model", "lm_head"],
|
||||||
lora_conflict_keys=["patch_embed"],
|
lora_conflict_keys=["patch_embed"],
|
||||||
)
|
)
|
||||||
@@ -365,7 +365,7 @@ _register_composite_model(
|
|||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="qwen3_vl_moe",
|
model_type="qwen3_vl_moe",
|
||||||
projector_key="visual.merger",
|
projector_key="visual.merger",
|
||||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||||
language_model_keys=["language_model", "lm_head"],
|
language_model_keys=["language_model", "lm_head"],
|
||||||
lora_conflict_keys=["patch_embed"],
|
lora_conflict_keys=["patch_embed"],
|
||||||
)
|
)
|
||||||
@@ -374,7 +374,7 @@ _register_composite_model(
|
|||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="qwen3_omni_moe_thinker",
|
model_type="qwen3_omni_moe_thinker",
|
||||||
projector_key="visual.merger",
|
projector_key="visual.merger",
|
||||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
|
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
|
||||||
language_model_keys=["model", "lm_head"],
|
language_model_keys=["model", "lm_head"],
|
||||||
lora_conflict_keys=["patch_embed"],
|
lora_conflict_keys=["patch_embed"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
|
||||||
|
if self.args.save_safetensors:
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
ptrs = defaultdict(list)
|
||||||
|
for name, tensor in state_dict.items():
|
||||||
|
if isinstance(tensor, torch.Tensor):
|
||||||
|
ptrs[id(tensor)].append(name)
|
||||||
|
|
||||||
|
for names in ptrs.values():
|
||||||
|
if len(names) > 1:
|
||||||
|
names.sort()
|
||||||
|
for name in names[1:]:
|
||||||
|
state_dict.pop(name, None)
|
||||||
|
|
||||||
|
super()._save(output_dir, state_dict)
|
||||||
|
|
||||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||||
r"""Save model predictions to `output_dir`.
|
r"""Save model predictions to `output_dir`.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user