mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-01-30 06:12:04 +00:00
[v0] Fix reward model training safetensors saving (#10137)
This commit is contained in:
@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
|
||||||
|
if self.args.save_safetensors:
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
ptrs = defaultdict(list)
|
||||||
|
for name, tensor in state_dict.items():
|
||||||
|
if isinstance(tensor, torch.Tensor):
|
||||||
|
ptrs[id(tensor)].append(name)
|
||||||
|
|
||||||
|
for names in ptrs.values():
|
||||||
|
if len(names) > 1:
|
||||||
|
names.sort()
|
||||||
|
for name in names[1:]:
|
||||||
|
state_dict.pop(name, None)
|
||||||
|
|
||||||
|
super()._save(output_dir, state_dict)
|
||||||
|
|
||||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||||
r"""Save model predictions to `output_dir`.
|
r"""Save model predictions to `output_dir`.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user