DeepSpeed ZeRO3 has inflight param error when calling model.eval()


Former-commit-id: 4be013f18ea6a35b5a11db98db5f0670ffb41619
This commit is contained in:
hiyouga
2024-06-13 02:25:50 +08:00
parent 0a75224f62
commit 103a507b39
4 changed files with 12 additions and 17 deletions

View File

@@ -1,3 +1,4 @@
import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
@@ -10,7 +11,7 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -61,6 +62,8 @@ class CustomDPOTrainer(DPOTrainer):
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
warnings.simplefilter("ignore") # remove gc warnings on ref model
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
@@ -176,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
if self.ref_model is None:
ref_model = model
ref_context = get_ref_context(self.accelerator, model)
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()