DeepSpeed ZeRO3 has inflight param error when calling model.eval()


Former-commit-id: 4be013f18ea6a35b5a11db98db5f0670ffb41619
This commit is contained in:
hiyouga
2024-06-13 02:25:50 +08:00
parent 0a75224f62
commit 103a507b39
4 changed files with 12 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
import math
import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -136,6 +137,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
device_type = unwrapped_model.pretrained_model.device.type
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
warnings.simplefilter("ignore") # remove gc warnings on ref model
if finetuning_args.reward_model_type == "full":
if self.is_deepspeed_enabled: