Update PPO trainer

Former-commit-id: c27136a83e167465d3f825e40f10c7b9fcfbf97a
This commit is contained in:
hiyouga
2023-08-02 18:46:41 +08:00
parent 1dfb28b362
commit e4d0b8ee6e
2 changed files with 46 additions and 44 deletions

View File

@@ -47,7 +47,6 @@ class PeftTrainer(Seq2SeqTrainer):
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if isinstance(model, PreTrainedModelWrapper):
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
model_state_dict = state_dict or model.state_dict()