tiny fix
Former-commit-id: 76177039c8f9ef5a63724a339dae6195d89fa215
@@ -21,7 +21,7 @@
 import inspect
 from functools import partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 
@@ -38,48 +38,51 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class UnslothGradientCheckpointing(torch.autograd.Function):
-    r"""
-    Saves VRAM by smartly offloading to RAM.
-    """
+def get_unsloth_gradient_checkpointing_func() -> Callable:
+    class UnslothGradientCheckpointing(torch.autograd.Function):
+        r"""
+        Saves VRAM by smartly offloading to RAM.
+        """
 
-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(
-        ctx: "torch.autograd.Function",
-        forward_function: "torch.Module",
-        hidden_states: "torch.Tensor",
-        *args: Union["torch.Tensor", Any],
-    ) -> "torch.Tensor":
-        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
-        with torch.no_grad():
-            output = forward_function(hidden_states, *args)
+        @staticmethod
+        @torch.cuda.amp.custom_fwd
+        def forward(
+            ctx: "torch.autograd.Function",
+            forward_function: "torch.Module",
+            hidden_states: "torch.Tensor",
+            *args: Union["torch.Tensor", Any],
+        ) -> "torch.Tensor":
+            saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+            with torch.no_grad():
+                output = forward_function(hidden_states, *args)
 
-        ctx.save_for_backward(saved_hidden_states)
-        ctx.forward_function = forward_function
-        ctx.args = args
-        return output
+            ctx.save_for_backward(saved_hidden_states)
+            ctx.forward_function = forward_function
+            ctx.args = args
+            return output
 
-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
-        (hidden_states,) = ctx.saved_tensors
-        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
-        hidden_states.requires_grad_(True)
-        with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        @staticmethod
+        @torch.cuda.amp.custom_bwd
+        def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+            (hidden_states,) = ctx.saved_tensors
+            hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+            hidden_states.requires_grad_(True)
+            with torch.enable_grad():
+                (output,) = ctx.forward_function(hidden_states, *ctx.args)
 
-        torch.autograd.backward(output, grad_output)
-        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+            torch.autograd.backward(output, grad_output)
+            return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+    return UnslothGradientCheckpointing.apply
 
 
-def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
     r"""
     Only applies gradient checkpointing to trainable layers.
     """
 
     @wraps(gradient_checkpointing_func)
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
         if any(param.requires_grad for param in module.parameters()):
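Note on the hunk above: the refactor moves UnslothGradientCheckpointing inside a factory, so the autograd.Function is defined when get_unsloth_gradient_checkpointing_func() is called and its .apply is handed back to the caller, instead of living as a module-level singleton. A minimal usage sketch under assumed names (TinyLayer, layer, checkpoint_fn, and hidden_states are hypothetical; a CUDA device is assumed because forward offloads the saved activation to CPU and backward reloads it to "cuda"; the toy layer returns a 1-tuple to match the (output,) unpacking in backward):

    import torch

    # hypothetical toy layer; returns a 1-tuple so the (output,) unpacking in backward() holds
    class TinyLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

        def forward(self, hidden_states):
            return (self.proj(hidden_states),)

    layer = TinyLayer().cuda()
    checkpoint_fn = get_unsloth_gradient_checkpointing_func()  # returns UnslothGradientCheckpointing.apply

    hidden_states = torch.randn(2, 8, device="cuda", requires_grad=True)
    (output,) = checkpoint_fn(layer.forward, hidden_states)  # forward offloads hidden_states to CPU
    output.sum().backward()  # backward reloads the activation to GPU, recomputes, then backpropagates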
@@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
 
         return gradient_checkpointing_func(func, *args, **kwargs)
 
-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
         custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
 
     return custom_gradient_checkpointing_func
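Note on the hunk above: custom_gradient_checkpointing_func assumes func is a bound method, so the owning layer is reachable through func.__self__; the hunk copies that attribute onto the wrapper when the wrapped checkpointing function exposes one, which is what the unsloth gc test case inspects. A minimal sketch of that bound-method lookup (layer and bound_forward are hypothetical names):

    import torch

    layer = torch.nn.Linear(4, 4)
    bound_forward = layer.forward            # a bound method carries its owning module
    assert bound_forward.__self__ is layer   # this is how the wrapper recovers the module

    # the wrapper then checks whether any parameter of that module is trainable
    print(any(p.requires_grad for p in bound_forward.__self__.parameters()))  # True here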
@@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
         gradient_checkpointing_kwargs = {"use_reentrant": True}
 
     if use_unsloth_gc:
-        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+        gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
 
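Note on the hunk above: only the Unsloth branch changes; the else branch keeps PyTorch's stock checkpointing, pre-binding the kwargs with functools.partial. A minimal sketch of what that partial resolves to (layer, x, and y are hypothetical names; torch.utils.checkpoint.checkpoint is the standard PyTorch API):

    from functools import partial

    import torch
    from torch.utils.checkpoint import checkpoint

    gradient_checkpointing_func = partial(checkpoint, use_reentrant=True)

    layer = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)
    y = gradient_checkpointing_func(layer, x)  # same as checkpoint(layer, x, use_reentrant=True)
    y.sum().backward()                         # activations are recomputed here instead of being stored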