Feature BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
@@ -37,7 +37,7 @@ def init_adapter(
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
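The change above (and the two matching changes below) widens the guard around the float32 upcast so that it is also skipped when BAdam is enabled, not only for pure bf16 training. A minimal sketch of the resulting gate, assuming a boolean `use_badam` field on the finetuning-arguments object (the field itself is defined elsewhere in this change set); the `_ToyFinetuningArgs` dataclass and `_should_upcast_to_fp32` helper below are illustrative only:

# Illustrative sketch only: `_should_upcast_to_fp32` is a hypothetical helper and
# `use_badam` is assumed to be a boolean flag on the finetuning arguments.
from dataclasses import dataclass


@dataclass
class _ToyFinetuningArgs:
    pure_bf16: bool = False
    use_badam: bool = False


def _should_upcast_to_fp32(finetuning_args: _ToyFinetuningArgs) -> bool:
    # Upcast trainable weights to float32 only when neither pure bf16 training
    # nor the BAdam block-wise optimizer is requested.
    return (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam)


assert _should_upcast_to_fp32(_ToyFinetuningArgs()) is True
assert _should_upcast_to_fp32(_ToyFinetuningArgs(use_badam=True)) is False
assert _should_upcast_to_fp32(_ToyFinetuningArgs(pure_bf16=True)) is False

The same condition is applied in the freeze and LoRA branches in the next two hunks, so every code path leaves trainable parameters in their loaded dtype when BAdam is active.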
@@ -82,7 +82,7 @@ def init_adapter(
 
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if not finetuning_args.pure_bf16:
+                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
@@ -162,7 +162,7 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)
 
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             for param in filter(lambda p: p.requires_grad, model.parameters()):
                 param.data = param.data.to(torch.float32)
 
@@ -17,7 +17,7 @@ from ..extras.logging import get_logger
 from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module
+from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
 
 
 if TYPE_CHECKING:
@@ -266,8 +266,9 @@ def _prepare_model_for_training(
     else:
         # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
         # According to: https://github.com/huggingface/transformers/issues/28339
+        model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-        model.enable_input_require_grads()
+        # model.enable_input_require_grads()
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
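Before calling `gradient_checkpointing_enable`, the hunk above rebinds that attribute on the model instance with `MethodType`, so the call dispatches to the customized function imported from `.utils` (defined in the next hunk) rather than the stock Transformers method. A small self-contained sketch of how `types.MethodType` turns a free function into a bound method on a single instance; the `Greeter` class and `patched_hello` function are illustrative only:

from types import MethodType


class Greeter:
    def hello(self):
        return "original"


def patched_hello(self):
    # `self` is bound to the instance the method is attached to.
    return f"patched for {type(self).__name__}"


g = Greeter()
g.hello = MethodType(patched_hello, g)  # rebinding affects only this instance
print(g.hello())          # -> "patched for Greeter"
print(Greeter().hello())  # other instances keep the original method -> "original"

Because the attribute is set on the instance rather than the class, other model objects created in the same process keep the default implementation.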
@@ -135,3 +135,45 @@ def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tok
         model.__class__.register_for_auto_class()
     if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
         tokenizer.__class__.register_for_auto_class()
+
+def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+    """
+    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
+
+    Activates gradient checkpointing for the current model.
+
+    We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+    the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+    Args:
+        gradient_checkpointing_kwargs (dict, *optional*):
+            Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
+    """
+    from torch.utils.checkpoint import checkpoint
+
+    if not self.supports_gradient_checkpointing:
+        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {}
+
+    # gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+
+    def gradient_checkpointing_func(func, *args, **kwargs):
+        module = func.__self__
+
+        if any([p.requires_grad for p in module.parameters()]):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return checkpoint(func, *args, **kwargs)
+
+    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
+
+    if getattr(self, "_hf_peft_config_loaded", False):
+        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
+        # the gradients to make sure the gradient flows.
+        self.enable_input_require_grads()
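The wrapper above promotes `requires_grad` on floating-point inputs only when the checkpointed module still owns trainable parameters, which is what lets a block-wise optimizer such as BAdam freeze and unfreeze layers between steps: the active block's inputs typically come from frozen layers and would otherwise not require grad, which reentrant checkpointing cannot backpropagate through. A self-contained sketch of that behaviour using a toy module rather than LLaMA-Factory code; the module and function names below are illustrative only:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def gradient_checkpointing_func(func, *args, **kwargs):
    # Same idea as the patched function above: only make floating-point inputs
    # require grad when the checkpointed module owns a trainable parameter.
    module = func.__self__
    if any(p.requires_grad for p in module.parameters()):
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)
    return checkpoint(func, *args, use_reentrant=True, **kwargs)


trainable = nn.Linear(4, 4)
frozen = nn.Linear(4, 4)
frozen.requires_grad_(False)

x = torch.randn(2, 4)  # plain activation, requires_grad=False
y = gradient_checkpointing_func(trainable.__call__, x)
print(x.requires_grad)  # True: promoted so checkpointing can backprop through this block

x2 = torch.randn(2, 4)
y2 = gradient_checkpointing_func(frozen.__call__, x2)
print(x2.requires_grad)  # False: block is frozen, nothing to update here
# (PyTorch may warn that none of the inputs require grad; that is expected for the frozen block.)

The stock `enable_input_require_grads` would instead force gradients from the embedding output onward at every step, which appears to be why that call is commented out in the patcher hunk above.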