[trainer] Add Muon Optimizer (#7749)

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Juanxi Tian
2025-04-21 23:38:37 +08:00
committed by GitHub
parent 416853dd25
commit 12ada72ed4
10 changed files with 371 additions and 25 deletions


@@ -489,12 +489,51 @@ def _create_adam_mini_optimizer(
logger.info_rank0("Using Adam-mini optimizer.")
return optimizer
def _create_muon_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
from llamafactory.third_party.muon import Muon # type: ignore
# Separate parameters for Muon (2D parameters) and AdamW (others)
muon_params = []
adamw_params = []
for name, param in model.named_parameters():
if param.requires_grad:
# Use Muon for 2D parameters that aren't embeddings or heads
if param.ndim == 2 and "embed" not in name and "lm_head" not in name:
muon_params.append(param)
else:
adamw_params.append(param)
# Get optimizer settings from training_args
ns_steps = getattr(training_args, "ns_steps", 5)
# Create Muon optimizer
optimizer = Muon(
lr=training_args.learning_rate,
wd=training_args.weight_decay,
muon_params=muon_params,
momentum=0.95, # default momentum for Muon
nesterov=True, # default nesterov for Muon
ns_steps=ns_steps,
adamw_params=adamw_params,
adamw_betas=(training_args.adam_beta1, training_args.adam_beta2),
adamw_eps=training_args.adam_epsilon,
)
logger.info_rank0(f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params.")
return optimizer
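
# --- Illustrative note, not part of this commit: ns_steps above sets how many
# Newton-Schulz iterations Muon runs to approximately orthogonalize each 2D
# parameter's momentum-smoothed gradient before applying the update. The sketch
# below follows the reference Muon implementation's quintic iteration; the helper
# name and coefficients are shown for illustration and are not taken from
# llamafactory.third_party.muon. It assumes `torch` is imported in this module.
def _newton_schulz_orthogonalize(G: "torch.Tensor", steps: int = 5) -> "torch.Tensor":
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients
    X = G.bfloat16()
    transposed = G.size(0) > G.size(1)
    if transposed:  # iterate on the wide orientation
        X = X.T
    X = X / (X.norm() + 1e-7)  # bound the spectral norm via the Frobenius norm
    for _ in range(steps):  # `steps` plays the role of ns_steps
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X
# More steps track the orthogonal factor of the update more closely at the cost
# of extra matmuls; 5, the getattr default above, is the commonly used setting.
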
def create_custom_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if finetuning_args.use_muon:
        return _create_muon_optimizer(model, training_args)

    if finetuning_args.use_galore:
        return _create_galore_optimizer(model, training_args, finetuning_args)
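
Usage note (outside the diff): use_muon is the first branch checked in create_custom_optimizer, so it takes precedence over the other custom optimizers. A minimal sketch of exercising the new branch follows; the import path, the tiny GPT-2 config, and the stand-in for FinetuningArguments are assumptions for illustration and do not come from this commit.

from transformers import GPT2Config, GPT2LMHeadModel, TrainingArguments

# Assumed module path for the function patched above (not shown in this diff).
from llamafactory.train.trainer_utils import create_custom_optimizer

model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=128))
training_args = TrainingArguments(output_dir="outputs", learning_rate=1e-3, weight_decay=0.1)


class _Args:  # stand-in for FinetuningArguments: only the flag read by the first branch is needed
    use_muon = True


optimizer = create_custom_optimizer(model, training_args, _Args())
print(type(optimizer).__name__)  # expected: "Muon"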