[example] add bash usage (#7794)
This commit is contained in:
0
src/llamafactory/third_party/__init__.py
vendored
Normal file
0
src/llamafactory/third_party/__init__.py
vendored
Normal file
30
src/llamafactory/third_party/muon/muon.py
vendored
30
src/llamafactory/third_party/muon/muon.py
vendored
@@ -2,6 +2,8 @@
|
||||
#
|
||||
# This code is based on the MoonshotAI's Moonlight library.
|
||||
# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
|
||||
# and the Keller Jordan's Muon library.
|
||||
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -18,6 +20,7 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2025 Moonshot AI
|
||||
# Copyright (c) 2024 Keller Jordan
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
@@ -36,22 +39,20 @@
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# This code snippet is a modified version adapted from the following GitHub repository:
|
||||
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G, steps):
|
||||
def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor":
|
||||
"""Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
|
||||
|
||||
We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero.
|
||||
For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing
|
||||
the slope at zero even beyond the point where the iteration no longer converges all the way to
|
||||
one everywhere on the interval. This iteration therefore does not produce UV^T but rather something
|
||||
like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert len(G.shape) == 2
|
||||
@@ -133,7 +134,7 @@ class Muon(torch.optim.Optimizer):
|
||||
# Do not use Muon for parameters in adamw_params
|
||||
self.state[p]["use_muon"] = False
|
||||
|
||||
def adjust_lr_for_muon(self, lr, param_shape):
|
||||
def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float:
|
||||
A, B = param_shape[:2]
|
||||
# We adjust the learning rate and weight decay based on the size of the parameter matrix
|
||||
# as describted in the paper
|
||||
@@ -154,12 +155,8 @@ class Muon(torch.optim.Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
############################
|
||||
# Muon #
|
||||
############################
|
||||
|
||||
# Muon loop
|
||||
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
||||
# import pdb; pdb.set_trace()
|
||||
lr = group["lr"]
|
||||
wd = group["wd"]
|
||||
momentum = group["momentum"]
|
||||
@@ -195,10 +192,7 @@ class Muon(torch.optim.Optimizer):
|
||||
# apply update
|
||||
p.data.add_(u, alpha=-adjusted_lr)
|
||||
|
||||
############################
|
||||
# AdamW backup #
|
||||
############################
|
||||
|
||||
# Adam backup
|
||||
params = [p for p in group["params"] if not self.state[p]["use_muon"]]
|
||||
lr = group["lr"]
|
||||
beta1, beta2 = group["adamw_betas"]
|
||||
|
||||
Reference in New Issue
Block a user