Former-commit-id: be76f6cbe5143f781b6b39603b80392253b3080a
This commit is contained in:
hiyouga
2023-09-08 20:22:18 +08:00
parent 612d97db6f
commit e70b3e8947
3 changed files with 5 additions and 3 deletions

View File

@@ -99,6 +99,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
self.model.eval()
# Get inputs
queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
@@ -108,6 +110,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
self.model.train()
# Run PPO step
stats = self.step(queries, responses, rewards)
@@ -157,10 +161,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273