add adam_mini to readme

Former-commit-id: d610c6bcf8a8ba6f4236f5d11f79571b83f4fb11
hiyouga
2024-08-09 20:02:03 +08:00
parent 7e755e9cac
commit 59cbce1a46
12 changed files with 94 additions and 34 deletions


@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
@@ -365,18 +366,16 @@ def _create_badam_optimizer(
     return optimizer


-def _create_adammini_optimizer(
+def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
-    finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     from adam_mini import Adam_mini

-    n_embd = model.config.hidden_size
-    n_head = model.config.num_attention_heads
-    n_query_groups = getattr(model.config, "num_key_value_heads", n_head)
-    print("n_embd", n_embd, "n_head", n_head, "n_query_groups", n_query_groups)
+    hidden_size = getattr(model.config, "hidden_size", None)
+    num_q_head = getattr(model.config, "num_attention_heads", None)
+    num_kv_head = getattr(model.config, "num_key_value_heads", None)

     optimizer = Adam_mini(
         named_parameters=model.named_parameters(),
         lr=training_args.learning_rate,
@@ -384,14 +383,15 @@ def _create_adammini_optimizer(
         betas=(training_args.adam_beta1, training_args.adam_beta2),
         eps=training_args.adam_epsilon,
         weight_decay=training_args.weight_decay,
-        model_sharding=False,
-        dim=n_embd,
-        n_heads=n_head,
-        n_kv_heads=n_query_groups,
+        model_sharding=is_fsdp_enabled() or is_deepspeed_zero3_enabled(),
+        dim=hidden_size,
+        n_heads=num_q_head,
+        n_kv_heads=num_kv_head,
     )
     logger.info("Using Adam-mini optimizer.")
     return optimizer


 def create_custom_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
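
Note (not part of the diff): below is a minimal standalone sketch of what the renamed helper builds, mirroring the keyword arguments shown in the hunks above. The checkpoint name and hyperparameter values are illustrative placeholders rather than anything taken from this commit, and it assumes the adam-mini package is installed so that the import used above works.

# Standalone sketch: build Adam_mini the same way _create_adam_mini_optimizer does.
# Checkpoint and hyperparameter values are placeholders.
from adam_mini import Adam_mini
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # any causal LM

optimizer = Adam_mini(
    named_parameters=model.named_parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
    model_sharding=False,  # the helper sets this to True under FSDP or DeepSpeed ZeRO-3
    dim=getattr(model.config, "hidden_size", None),
    n_heads=getattr(model.config, "num_attention_heads", None),
    n_kv_heads=getattr(model.config, "num_key_value_heads", None),
)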
@@ -406,9 +406,9 @@ def create_custom_optimizer(
     if finetuning_args.use_badam:
         return _create_badam_optimizer(model, training_args, finetuning_args)

-    if finetuning_args.use_adammini:
-        return _create_adammini_optimizer(model, training_args, finetuning_args)
+    if finetuning_args.use_adam_mini:
+        return _create_adam_mini_optimizer(model, training_args)


 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
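
Two usage notes on the changes above, hedged since they go slightly beyond what the diff shows: users opt into this optimizer through the finetuning argument use_adam_mini (the dispatcher branch in the last hunk), and model_sharding now tracks the runtime sharding state instead of being hard-coded to False. A small sketch of that check, using only the two helpers imported in the first hunk:

# Sketch: Adam-mini is asked to shard its optimizer state only when the model
# itself is sharded. Outside an FSDP or DeepSpeed ZeRO-3 launch both checks
# return False, so the optimizer keeps unsharded state as before.
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled

model_sharding = is_fsdp_enabled() or is_deepspeed_zero3_enabled()
print("Adam-mini model_sharding:", model_sharding)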