Merge pull request #3287 from Ledzy/badam

[Feature] Add BAdam algorithm

Former-commit-id: 10a5e1e65b34b03e5ca2a41bf6ded09a3fb25f0c
This commit is contained in:
hoshi-hiyouga
2024-04-16 17:32:16 +08:00
committed by GitHub
9 changed files with 215 additions and 11 deletions

View File

@@ -37,7 +37,7 @@ def init_adapter(
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -82,7 +82,7 @@ def init_adapter(
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -166,7 +166,7 @@ def init_adapter(
)
model = get_peft_model(model, lora_config)
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

View File

@@ -17,7 +17,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from .utils import QuantizationMethod, add_z3_leaf_module
from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
if TYPE_CHECKING:
@@ -268,8 +268,8 @@ def _prepare_model_for_training(
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")

View File

@@ -1,5 +1,6 @@
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
@@ -100,6 +101,37 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
return module_names
def gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.