@@ -99,6 +99,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
self.model.eval()

# Get inputs
queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
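The hunk above brackets generation with `cast_layernorm_dtype`, which saves the (typically float32) LayerNorm weights, casts them to the compute dtype before generation, and later restores them. Its definition is not shown in this diff; below is a minimal sketch of what such a helper could look like, assuming LayerNorm/RMSNorm parameters can be identified by their 1-D shape and a few name substrings (the names and signature details here are assumptions, not the repository's code):

```python
from typing import Dict, List, Optional, Tuple
import torch

# Assumed substrings identifying LayerNorm / RMSNorm parameters; the real
# helper may use a different list depending on the model architecture.
LAYERNORM_NAMES: List[str] = ["norm", "ln_f", "ln_attn"]

def cast_layernorm_dtype(
    model: torch.nn.Module,
    compute_dtype: torch.dtype,
    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.nn.Module, Dict[str, torch.Tensor]]:
    saved: Dict[str, torch.Tensor] = {}
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(key in name for key in LAYERNORM_NAMES):
            if layer_norm_params is not None:
                # Second call: restore the weights saved before generation.
                param.data = layer_norm_params[name]
            else:
                # First call: keep a copy, then cast to the compute dtype.
                saved[name] = param.data.detach().clone()
                param.data = param.data.to(compute_dtype)
    return model, saved
```

Called once without `layer_norm_params` it casts and returns the saved tensors; called again with that dictionary it restores the originals, which matches the pairing visible in the hunks above and below.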
@@ -108,6 +110,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
self.model.train()

# Run PPO step
stats = self.step(queries, responses, rewards)
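The two hunks form a symmetric toggle: switch to inference-friendly settings for generation, then back to training settings before `self.step`. The same pattern can be restated as a context manager; this is only an illustrative sketch under the assumptions above, not code from the PR:

```python
from contextlib import contextmanager

@contextmanager
def generation_mode(trainer, unwrapped_model):
    # Inference-friendly settings: no gradient checkpointing, KV cache on,
    # LayerNorm weights cast to the compute dtype, model in eval mode.
    unwrapped_model.gradient_checkpointing_disable()
    unwrapped_model.config.use_cache = True
    _, saved = cast_layernorm_dtype(unwrapped_model, trainer.compute_dtype)
    trainer.model.eval()
    try:
        yield
    finally:
        # Back to training mode: checkpointing on, cache off (the two are
        # mutually exclusive), LayerNorm weights restored, model in train mode.
        unwrapped_model.gradient_checkpointing_enable()
        unwrapped_model.config.use_cache = False
        cast_layernorm_dtype(unwrapped_model, trainer.compute_dtype, saved)
        trainer.model.train()
```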
@@ -157,10 +161,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
if length_sampler is not None:
    generation_kwargs["max_new_tokens"] = length_sampler()

self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)

# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
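Two notes on the hunk above. First, generation runs on `self.accelerator.unwrap_model(self.model)` because `generate` lives on the underlying model, not on the distributed wrapper. Second, the `length_sampler` checked at the top is typically TRL's `LengthSampler`, which returns a fresh target length on every call; a small usage sketch (the import path and the min/max values are assumptions for illustration):

```python
from trl.core import LengthSampler  # assumed import path

length_sampler = LengthSampler(16, 128)  # samples a response length in [16, 128)
generation_kwargs = {
    "do_sample": True,  # illustrative sampling settings, not the PR's defaults
    "top_k": 0,
    "top_p": 1.0,
}
generation_kwargs["max_new_tokens"] = length_sampler()  # resampled per batch
```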