upcast logits

Former-commit-id: df61660351c8af30591471807a20869a45bb055a
This commit is contained in:
hiyouga
2024-07-02 22:32:05 +08:00
parent e6ba7ef3e6
commit 579997688f
2 changed files with 2 additions and 2 deletions

View File

@@ -407,7 +407,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
values = torch.transpose(values, 0, 1)
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return rewards.to(torch.float32).detach().cpu() # use fp32 type
return rewards.float().detach() # use fp32 type
@PPODecorators.empty_device_cache()
def batched_forward_pass(