Former-commit-id: 5e710c4ac331f3400534d33b2646c4108c898d98
This commit is contained in:
hiyouga
2024-04-18 15:34:45 +08:00
parent 619264c854
commit 9e1bd6420d
2 changed files with 3 additions and 2 deletions

View File

@@ -132,8 +132,9 @@ def gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else:
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)