commit 4d6669c268 (parent 89f4ae51f9)
Author: hiyouga
Date: 2024-01-09 18:31:27 +08:00
Former-commit-id: d86455f685fa531e651333e00b4fe54d895cf2e4

9 changed files with 78 additions and 50 deletions

@@ -8,11 +8,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
-from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
+from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
@@ -60,7 +61,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.state, "deepspeed_plugin"
         )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
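
The renamed save callback itself is not part of this diff. For context, a minimal, hypothetical sketch of the pattern such a callback usually follows: a transformers TrainerCallback whose on_save hook post-processes the checkpoint the Trainer just wrote. The class name, weight file names, and the "v_head." key prefix below are assumptions, not the actual llmtuner implementation.

import os

import torch
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class ValueHeadSaveCallback(TrainerCallback):  # hypothetical stand-in for FixValueHeadModelCallback
    def on_save(self, args, state, control, **kwargs):
        # The Trainer writes checkpoints to <output_dir>/checkpoint-<global_step>.
        ckpt_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        weights_file = os.path.join(ckpt_dir, "pytorch_model.bin")
        if not os.path.isfile(weights_file):
            return
        state_dict = torch.load(weights_file, map_location="cpu")
        # Example fix-up: split the value-head weights ("v_head.*") out of the combined state dict.
        v_head = {k: v for k, v in state_dict.items() if k.startswith("v_head.")}
        if v_head:
            torch.save(v_head, os.path.join(ckpt_dir, "value_head.bin"))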
@@ -369,9 +370,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                     " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:  # remove dummy checkpoint
-                    file = os.path.join(output_dir, filename)
-                    if os.path.isfile(file):
-                        os.remove(file)
-                self.model.save_checkpoint(output_dir)  # wrapped model
+                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)
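
The last hunk replaces the hand-rolled cleanup loop with the remove_dummy_checkpoint helper from transformers.trainer_pt_utils. A minimal sketch of what that cleanup amounts to, based on the loop removed above and assuming the first positional argument in the new call acts as an is-main-process guard; the function name below is illustrative, not the library's.

import os

from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME


def remove_dummy_checkpoint_sketch(is_main_process: bool, output_dir: str, filenames: list) -> None:
    # Under ZeRO-3 without 16-bit weight gathering, self._save(output_dir, state_dict={}) leaves a
    # placeholder weight file behind; it is deleted here before self.model.save_checkpoint(output_dir)
    # writes the real DeepSpeed checkpoint.
    if is_main_process:  # only one rank should delete files from the shared output directory
        for filename in filenames:
            file = os.path.join(output_dir, filename)
            if os.path.isfile(file):
                os.remove(file)


# Usage mirroring the new call in the hunk above:
# remove_dummy_checkpoint_sketch(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])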