update flashattn, fix ppo save model
Former-commit-id: 0b08bc3dac246d4aa3f89afb7172529dcad9c39f
This commit is contained in:
@@ -5,6 +5,7 @@ from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from trl import PPOTrainer
|
||||
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
|
||||
@@ -96,7 +97,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
# Cast to inference mode
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
|
||||
self.model.eval()
|
||||
|
||||
# Get inputs
|
||||
@@ -107,7 +107,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
# Cast to training mode
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
unwrapped_model.config.use_cache = False
|
||||
unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
|
||||
self.model.train()
|
||||
|
||||
# Run PPO step
|
||||
@@ -134,7 +133,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
reward_meter.reset()
|
||||
|
||||
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
|
||||
self.save_model(os.path.join(
|
||||
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
|
||||
))
|
||||
self.save_callback.on_save(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
@@ -165,8 +169,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
)
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
|
||||
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
|
||||
query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
|
||||
|
||||
queries, responses = [], []
|
||||
@@ -294,6 +300,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
"""
|
||||
if self.args.should_save:
|
||||
self._save(output_dir)
|
||||
self.save_callback.on_save(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user