alter rewards data type

Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
This commit is contained in:
hiyouga
2023-06-02 14:19:51 +08:00
parent 896dbfec16
commit e9ab06678f
12 changed files with 40 additions and 50 deletions

View File

@@ -75,7 +75,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.finetuning_args = finetuning_args
self.log_callback = callbacks[0]
self.state = TrainerState()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
def ppo_train(self, max_target_length: int) -> None:
r"""
@@ -148,7 +148,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Compute rewards
replace_model(unwrapped_model, target="reward")
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1]]
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default") # make sure the model is default at the end
# Run PPO step
@@ -214,13 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
return response[:, inputs["input_ids"].size(1):]
return response
def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data
@PPODecorators.empty_cuda_cache()
def batched_forward_pass(
self,