fix #4209
DeepSpeed ZeRO3 has inflight param error when calling model.eval() Former-commit-id: 4be013f18ea6a35b5a11db98db5f0670ffb41619
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -136,6 +137,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
device_type = unwrapped_model.pretrained_model.device.type
|
||||
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
|
||||
warnings.simplefilter("ignore") # remove gc warnings on ref model
|
||||
|
||||
if finetuning_args.reward_model_type == "full":
|
||||
if self.is_deepspeed_enabled:
|
||||
|
||||
Reference in New Issue
Block a user