@override
def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None:
    r"""Save the model, keeping a single name per shared tensor when using safetensors.

    When `self.args.save_safetensors` is truthy, every group of state-dict names bound to
    the very same tensor object is reduced to its alphabetically first name; the result is
    then handed to the parent class's `_save`.

    Args:
        output_dir: Forwarded unchanged to the parent `_save`.
        state_dict: Mapping from names to values. Defaults to `self.model.state_dict()`.
            A mapping supplied by the caller is never modified: if names have to be
            dropped, a shallow copy is pruned and forwarded instead.
    """
    if state_dict is None:
        state_dict = self.model.state_dict()

    if self.args.save_safetensors:
        import copy

        # Group names by object identity. Non-tensor values are left alone.
        # NOTE(review): `id()` only matches names bound to the same tensor object; distinct
        # tensor objects that share storage are not detected here -- confirm this is intended.
        names_by_tensor = {}
        for name, tensor in state_dict.items():
            if isinstance(tensor, torch.Tensor):
                names_by_tensor.setdefault(id(tensor), []).append(name)

        # Keep the alphabetically first name of each group and mark the rest as redundant.
        redundant = [name for names in names_by_tensor.values() for name in sorted(names)[1:]]

        if redundant:
            # Prune a shallow copy so the caller's mapping keeps all of its keys; the
            # tensors themselves are shared with the copy, not duplicated.
            state_dict = copy.copy(state_dict)
            for name in redundant:
                state_dict.pop(name, None)

    super()._save(output_dir, state_dict)