add test cases
Former-commit-id: 731176ff34cdf0cbf6b41c40c69f4ceb54c2daf6
This commit is contained in:
@@ -135,8 +135,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
|
||||
|
||||
device_type = unwrapped_model.pretrained_model.device.type
|
||||
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
|
||||
self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
|
||||
warnings.simplefilter("ignore") # remove gc warnings on ref model
|
||||
|
||||
if finetuning_args.reward_model_type == "full":
|
||||
|
||||
Reference in New Issue
Block a user