Former-commit-id: 6fcf2f10faf3b1614896b091591eeef96d717e64
Author: hiyouga
Date: 2025-01-07 06:30:44 +00:00
parent 53e41bf2c7
commit d8bd46f1bf
3 changed files with 32 additions and 10 deletions


@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 import torch
@@ -36,7 +37,7 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 if TYPE_CHECKING:
@@ -330,7 +331,7 @@ def _create_badam_optimizer(
     ]
     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore
         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(
@@ -350,7 +351,7 @@ def _create_badam_optimizer(
         )
     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
@@ -374,7 +375,7 @@ def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore
     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -459,12 +460,33 @@ def get_batch_logps(
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     r"""
     Gets the callback for logging to SwanLab.
     """
-    import swanlab
-    from swanlab.integration.transformers import SwanLabCallback
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore
     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)
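
A minimal usage sketch of the new nested_detach helper added above (the module path and the example tensors are illustrative assumptions, not part of the commit): it recurses through lists, tuples, and mappings, detaches every tensor it finds from the autograd graph, and with clone=True also copies the underlying storage.

import torch

from llamafactory.train.trainer_utils import nested_detach  # assumed module path

# a nested structure whose tensors are attached to the autograd graph
outputs = {
    "logits": torch.randn(2, 4, requires_grad=True),
    "aux": [torch.randn(2, requires_grad=True), torch.randn(2, requires_grad=True)],
}

detached = nested_detach(outputs)            # same dict/list structure, tensors detached
cloned = nested_detach(outputs, clone=True)  # detached and copied, sharing no storage

assert not detached["logits"].requires_grad
assert not cloned["aux"][0].requires_grad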
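
A hedged sketch of how the returned SwanLab callback is typically attached to a Hugging Face Trainer; model, training_args, and finetuning_args are assumed to be constructed elsewhere, and only Trainer's standard callbacks argument and add_callback method are relied on.

from transformers import Trainer

swanlab_callback = get_swanlab_callback(finetuning_args)  # logs in to SwanLab when an API key is provided

trainer = Trainer(
    model=model,
    args=training_args,
    callbacks=[swanlab_callback],  # training metrics are mirrored to SwanLab
)
# equivalently, attach it after construction:
# trainer.add_callback(swanlab_callback)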