[v0] Fix reward model training safetensors saving (#10137)
@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
         else:
             return loss
 
+    @override
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        if state_dict is None:
+            state_dict = self.model.state_dict()
+
+        if self.args.save_safetensors:
+            from collections import defaultdict
+
+            ptrs = defaultdict(list)
+            for name, tensor in state_dict.items():
+                if isinstance(tensor, torch.Tensor):
+                    ptrs[id(tensor)].append(name)
+
+            for names in ptrs.values():
+                if len(names) > 1:
+                    names.sort()
+                    for name in names[1:]:
+                        state_dict.pop(name, None)
+
+        super()._save(output_dir, state_dict)
+
     def save_predictions(self, predict_results: "PredictionOutput") -> None:
         r"""Save model predictions to `output_dir`.