@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -36,7 +37,7 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
 
 
 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 
 
 if TYPE_CHECKING:
@@ -330,7 +331,7 @@ def _create_badam_optimizer(
     ]
 
     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore
 
         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(
@@ -350,7 +351,7 @@ def _create_badam_optimizer(
         )
 
     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore
 
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
@@ -374,7 +375,7 @@ def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore
 
     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)
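All four import hunks above make the same change: optional third-party optimizer imports (galore_torch, badam, adam_mini) gain a `# type: ignore` comment so static type checkers such as mypy stop flagging packages that ship without type stubs, while runtime behaviour is unchanged. A minimal sketch of that guarded-import pattern follows; the `is_galore_available` stand-in is an assumption re-implemented with importlib so the snippet runs on its own (the project imports its real availability helper from elsewhere in the codebase):

# Sketch only: guarded optional import plus "# type: ignore".
# is_galore_available() here is a stand-in built on importlib; it is not
# the project's own helper.
from importlib.util import find_spec


def is_galore_available() -> bool:
    return find_spec("galore_torch") is not None


if is_galore_available():
    # The comment silences "missing type stubs" errors from checkers such as
    # mypy; it has no effect at runtime.
    from galore_torch import GaLoreAdamW  # type: ignore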
@@ -459,12 +460,33 @@ def get_batch_logps(
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
 
 
+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
+
+
 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     r"""
     Gets the callback for logging to SwanLab.
     """
-    import swanlab
-    from swanlab.integration.transformers import SwanLabCallback
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore
 
     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)
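For reference, a self-contained usage sketch of the new `nested_detach` helper added in the hunk above. The helper body is repeated below (condensed, annotations omitted) so the snippet runs on its own, and the nested output structure is invented purely for illustration:

from collections.abc import Mapping

import torch


def nested_detach(tensors, clone: bool = False):
    """Condensed copy of the helper added in the hunk above."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
    if isinstance(tensors, torch.Tensor):
        return tensors.detach().clone() if clone else tensors.detach()
    return tensors


# Invented nested structure: a dict holding a non-leaf tensor and a list of tensors.
outputs = {
    "logits": torch.randn(2, 4, requires_grad=True).exp(),
    "aux": [torch.ones(3, requires_grad=True), torch.zeros(3)],
}

detached = nested_detach(outputs, clone=True)

assert type(detached) is dict and type(detached["aux"]) is list       # container types are preserved
assert not detached["logits"].requires_grad                           # autograd history is dropped
assert detached["logits"].data_ptr() != outputs["logits"].data_ptr()  # clone=True copies storage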