refactor pissa, improve llamaboard

Former-commit-id: 619556e46c19718f702c97df5d570a2a4c5fb13a
This commit is contained in:
hiyouga
2024-06-28 01:04:24 +08:00
parent edc7498111
commit 46f0189e88
16 changed files with 219 additions and 216 deletions

View File

@@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
@@ -34,9 +35,9 @@ from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -131,7 +132,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.current_device = get_current_device() # patch for deepspeed training
self.processor = processor
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
@@ -143,8 +143,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl()
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
self.callback_handler = CallbackHandler(
[callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -165,11 +166,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
self.add_callback(FixValueHeadModelCallback)
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback)
self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
@@ -219,7 +225,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
self.callback_handler.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
try:
@@ -257,7 +263,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)
self.callback_handler.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict(
@@ -269,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control)
self.callback_handler.on_log(self.args, self.state, self.control, logs)
loss_meter.reset()
reward_meter.reset()
@@ -277,17 +283,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
)
self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
self.callback_handler.on_save(self.args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
self.log_callback.on_train_end(self.args, self.state, self.control)
self.save_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
self.callback_handler.on_train_end(self.args, self.state, self.control)
def create_optimizer(
self,
@@ -505,7 +506,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
elif self.args.should_save:
self._save(output_dir)
if self.processor is not None and self.args.should_save:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)