Feature BAdam

Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
This commit is contained in:
Jonery
2024-04-15 23:15:27 +08:00
parent 276f2cb24e
commit d4d471450f
9 changed files with 195 additions and 7 deletions

View File

@@ -135,3 +135,45 @@ def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tok
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
Activates gradient checkpointing for the current model.
We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
# gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
def gradient_checkpointing_func(func, *args, **kwargs):
module = func.__self__
if any([p.requires_grad for p in module.parameters()]):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return checkpoint(func, *args, **kwargs)
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()