[example] add bash usage (#7794)

Author: hoshi-hiyouga
Date: 2025-04-22 00:25:51 +08:00
Committed by: GitHub
parent 12ada72ed4
commit b07628dea5
13 changed files with 184 additions and 98 deletions


@@ -489,16 +489,14 @@ def _create_adam_mini_optimizer(
     logger.info_rank0("Using Adam-mini optimizer.")
     return optimizer
 
 
 def _create_muon_optimizer(
     model: "PreTrainedModel",
     training_args: "TrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from llamafactory.third_party.muon import Muon  # type: ignore
-
-    # Separate parameters for Muon (2D parameters) and AdamW (others)
-    muon_params = []
-    adamw_params = []
+    from ..third_party.muon import Muon
+
+    muon_params, adamw_params = [], []
     for name, param in model.named_parameters():
         if param.requires_grad:
             # Use Muon for 2D parameters that aren't embeddings or heads
@@ -506,34 +504,26 @@ def _create_muon_optimizer(
                 muon_params.append(param)
             else:
                 adamw_params.append(param)
 
-    # Get optimizer settings from training_args
-    ns_steps = getattr(training_args, "ns_steps", 5)
-
-    # Create Muon optimizer
     optimizer = Muon(
         lr=training_args.learning_rate,
         wd=training_args.weight_decay,
         muon_params=muon_params,
-        momentum=0.95,  # default momentum for Muon
-        nesterov=True,  # default nesterov for Muon
-        ns_steps=ns_steps,
         adamw_params=adamw_params,
         adamw_betas=(training_args.adam_beta1, training_args.adam_beta2),
         adamw_eps=training_args.adam_epsilon,
     )
-    logger.info_rank0(f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params.")
+    logger.info_rank0(
+        f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params."
+    )
     return optimizer
 
 
 def create_custom_optimizer(
     model: "PreTrainedModel",
     training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
-    if finetuning_args.use_muon:
-        return _create_muon_optimizer(model, training_args)
-
     if finetuning_args.use_galore:
         return _create_galore_optimizer(model, training_args, finetuning_args)
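
For reference, the split performed in `_create_muon_optimizer` routes trainable 2D weight matrices (except embeddings and the output head) to Muon and leaves everything else (biases, norms, embeddings, `lm_head`) to the embedded AdamW. Below is a minimal, self-contained sketch of that selection rule on a toy module; `ToyLM` and the exact exclusion substrings (`"embed_tokens"`, `"lm_head"`) are illustrative assumptions, since the condition line itself falls outside this hunk.

```python
from torch import nn


class ToyLM(nn.Module):
    """Tiny stand-in for a causal LM, used only to exercise the naming/shape rule."""

    def __init__(self, vocab: int = 100, dim: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)      # 2D weight, excluded by name
        self.proj = nn.Linear(dim, dim)                   # 2D weight -> Muon, 1D bias -> AdamW
        self.norm = nn.LayerNorm(dim)                     # 1D weight/bias -> AdamW
        self.lm_head = nn.Linear(dim, vocab, bias=False)  # 2D weight, excluded by name


model = ToyLM()
muon_params, adamw_params = [], []
for name, param in model.named_parameters():
    if param.requires_grad:
        # Assumed form of the check: 2D tensors that are neither embeddings nor the head.
        if param.ndim == 2 and "embed_tokens" not in name and "lm_head" not in name:
            muon_params.append(name)
        else:
            adamw_params.append(name)

print("Muon :", muon_params)    # ['proj.weight']
print("AdamW:", adamw_params)   # embed_tokens.weight, proj.bias, norm.weight, norm.bias, lm_head.weight
```

Keeping embeddings, the head, and 1D parameters under AdamW matches the usual Muon guidance, which applies the orthogonalized update only to hidden-layer weight matrices.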
@@ -549,6 +539,9 @@ def create_custom_optimizer(
     if finetuning_args.use_adam_mini:
         return _create_adam_mini_optimizer(model, training_args)
 
+    if finetuning_args.use_muon:
+        return _create_muon_optimizer(model, training_args)
+
 
 def create_custom_scheduler(
     training_args: "TrainingArguments",
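
The last hunk moves the `use_muon` branch from the top of `create_custom_optimizer` to after the `use_adam_mini` check, so Muon is selected only when none of the earlier optimizer flags apply. A minimal sketch of that dispatch order, using a `SimpleNamespace` stand-in for `FinetuningArguments` and modeling only the flags visible in this diff:

```python
from types import SimpleNamespace
from typing import Optional

# Hypothetical stand-in: the real FinetuningArguments carries many more options.
finetuning_args = SimpleNamespace(use_galore=False, use_adam_mini=False, use_muon=True)


def pick_optimizer(args: SimpleNamespace) -> Optional[str]:
    """Mirror the dispatch order after this change: use_muon is checked
    after the other custom-optimizer flags instead of first."""
    if args.use_galore:
        return "galore"
    if args.use_adam_mini:
        return "adam_mini"
    if args.use_muon:
        return "muon"
    return None  # fall back to the trainer's default optimizer


print(pick_optimizer(finetuning_args))  # -> muon
```

In an actual run the flag would come from the finetuning config (e.g. `use_muon: true` in the training YAML), and the selected branch would return the Muon instance built by `_create_muon_optimizer` above.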