refactor evaluation, upgrade trl to 074

Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
This commit is contained in:
hiyouga
2023-11-13 22:20:35 +08:00
parent 989eccd286
commit 64fc9ba678
21 changed files with 341 additions and 247 deletions

View File

@@ -226,7 +226,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",