Feature BAdam

Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
This commit is contained in:
Jonery
2024-04-15 23:15:27 +08:00
parent 276f2cb24e
commit d4d471450f
9 changed files with 195 additions and 7 deletions

View File

@@ -17,7 +17,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from .utils import QuantizationMethod, add_z3_leaf_module
from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
if TYPE_CHECKING:
@@ -266,8 +266,9 @@ def _prepare_model_for_training(
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
# model.enable_input_require_grads()
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")