update flashattn, fix ppo save model

Former-commit-id: 0b08bc3dac246d4aa3f89afb7172529dcad9c39f
Author: hiyouga
Date: 2023-09-11 17:25:36 +08:00
parent a09a7b650d
commit 42e0b30476
5 changed files with 105 additions and 518 deletions


@@ -5,6 +5,7 @@ from tqdm import tqdm
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
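
Note: PREFIX_CHECKPOINT_DIR is transformers' standard checkpoint-folder prefix (the literal string "checkpoint"), so PPO checkpoints now follow the same checkpoint-<step> naming the Trainer uses elsewhere. A minimal sketch of the naming it produces (output_dir and global_step are hypothetical values):

    import os
    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

    output_dir = "ppo_out"  # hypothetical output directory
    global_step = 500       # hypothetical global step
    # Same expression as in the save hunk below:
    ckpt_dir = os.path.join(output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, global_step))
    print(ckpt_dir)  # ppo_out/checkpoint-500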
@@ -96,7 +97,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
-            unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
             self.model.eval()
 
             # Get inputs
@@ -107,7 +107,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
             self.model.train()
 
             # Run PPO step
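
Note: both deletions drop the manual LayerNorm dtype round-trip around the eval/train transitions; with the updated FlashAttention patch the norm weights no longer need to be temporarily cast to the compute dtype for generation. The helper itself is not part of this diff, so the sketch below only reconstructs its apparent contract from the call sites (the name-matching heuristic and save/restore behavior are assumptions, not the repo's implementation):

    from typing import Dict, Optional, Tuple
    import torch

    def cast_layernorm_dtype(
        model: torch.nn.Module,
        compute_dtype: torch.dtype,
        layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.nn.Module, Dict[str, torch.Tensor]]:
        # First call (layer_norm_params is None): stash the norm weights and
        # cast them to compute_dtype for generation.
        # Second call: restore the stashed weights for training.
        saved: Dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            if param.ndim == 1 and "norm" in name.lower():  # LayerNorm / RMSNorm weights
                if layer_norm_params is None:
                    saved[name] = param.data.detach().clone()
                    param.data = param.data.to(compute_dtype)
                else:
                    param.data = layer_norm_params[name]
        return model, saved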
@@ -134,7 +133,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 reward_meter.reset()
 
             if (step+1) % self.args.save_steps == 0: # save checkpoint
-                self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
+                self.save_model(os.path.join(
+                    self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
+                ))
+                self.save_callback.on_save(
+                    self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+                )
 
             if self.control.should_epoch_stop or self.control.should_training_stop:
                 break
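
Note: periodic saves now land in Trainer-style checkpoint-<global_step> folders and immediately notify the save callback with the unwrapped model. The callback class is not shown in this diff; below is a hypothetical TrainerCallback-style sketch of an on_save handler compatible with the call above (SaveAdapterCallback and its body are assumptions):

    import os
    from transformers import TrainerCallback
    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

    class SaveAdapterCallback(TrainerCallback):
        # transformers passes (args, state, control); the trainer above also
        # forwards the unwrapped model through kwargs.
        def on_save(self, args, state, control, **kwargs):
            model = kwargs.pop("model", None)
            ckpt_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
            if model is not None:
                model.save_pretrained(ckpt_dir)  # e.g. persist adapter weights per checkpoint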
@@ -165,8 +169,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
         input_ids = batch["input_ids"]
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
-        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
         query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
         queries, responses = [], []
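
Note: under accelerate, self.model may be wrapped (e.g. in DistributedDataParallel), which hides generate(); unwrapping first calls generate() on the underlying value-head model. A self-contained sketch of the same pattern (the tiny public model and max_new_tokens are illustrative choices only):

    import torch
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, AutoTokenizer

    accelerator = Accelerator()
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2"))

    batch = tokenizer("Hello", return_tensors="pt").to(accelerator.device)
    # Unwrap the distributed wrapper before generating, as in the hunk above.
    unwrapped_model = accelerator.unwrap_model(model)
    with torch.no_grad():
        response = unwrapped_model.generate(**batch, max_new_tokens=8)
    # Keep only the generated continuation, mirroring the slicing above.
    response = response[:, batch["input_ids"].size(-1):]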
@@ -294,6 +300,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         """
         if self.args.should_save:
             self._save(output_dir)
-        self.save_callback.on_save(
-            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
-        )
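
Note: with this removal, save_model only persists the weights; the on_save notification now fires from the training loop at each periodic checkpoint (see the hunk above), so the callback is no longer triggered by every save_model call.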