add test cases

Former-commit-id: 731176ff34cdf0cbf6b41c40c69f4ceb54c2daf6
This commit is contained in:
hiyouga
2024-06-15 04:05:54 +08:00
parent f4f315fd11
commit 3ff9b87012
9 changed files with 184 additions and 34 deletions

View File

@@ -135,8 +135,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
device_type = unwrapped_model.pretrained_model.device.type
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
warnings.simplefilter("ignore") # remove gc warnings on ref model
if finetuning_args.reward_model_type == "full":