add adam_mini to readme
Former-commit-id: d610c6bcf8a8ba6f4236f5d11f79571b83f4fb11
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
@@ -365,18 +366,16 @@ def _create_badam_optimizer(
     return optimizer


-def _create_adammini_optimizer(
+def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
-    finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     from adam_mini import Adam_mini

-    n_embd = model.config.hidden_size
-    n_head = model.config.num_attention_heads
-    n_query_groups = getattr(model.config, "num_key_value_heads", n_head)
-
-    print("n_embd", n_embd, "n_head", n_head, "n_query_groups", n_query_groups)
+    hidden_size = getattr(model.config, "hidden_size", None)
+    num_q_head = getattr(model.config, "num_attention_heads", None)
+    num_kv_head = getattr(model.config, "num_key_value_heads", None)
+
     optimizer = Adam_mini(
         named_parameters=model.named_parameters(),
@@ -384,14 +383,15 @@ def _create_adammini_optimizer(
         betas=(training_args.adam_beta1, training_args.adam_beta2),
         eps=training_args.adam_epsilon,
         weight_decay=training_args.weight_decay,
-        model_sharding=False,
-        dim=n_embd,
-        n_heads=n_head,
-        n_kv_heads=n_query_groups,
+        model_sharding=is_fsdp_enabled() or is_deepspeed_zero3_enabled(),
+        dim=hidden_size,
+        n_heads=num_q_head,
+        n_kv_heads=num_kv_head,
     )
+    logger.info("Using Adam-mini optimizer.")
     return optimizer


 def create_custom_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
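The hard-coded model_sharding=False was the substantive bug here: Adam_mini needs to know when parameters are sharded across ranks (FSDP or DeepSpeed ZeRO-3) so it can partition its optimizer state accordingly, and the patched code now derives that from the runtime. For reference, a minimal single-process sketch of the same construction outside the Trainer; the checkpoint name and hyperparameter values are placeholders, and the constructor arguments are only those that appear in this diff plus a learning rate.

# Minimal sketch mirroring the patched construction path (single process,
# so model_sharding stays False). Checkpoint name and hyperparameters are
# placeholders, not part of this commit.
from adam_mini import Adam_mini
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")

hidden_size = getattr(model.config, "hidden_size", None)
num_q_head = getattr(model.config, "num_attention_heads", None)
num_kv_head = getattr(model.config, "num_key_value_heads", None)

optimizer = Adam_mini(
    named_parameters=model.named_parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
    model_sharding=False,    # no FSDP / ZeRO-3 in this standalone example
    dim=hidden_size,         # model hidden size
    n_heads=num_q_head,      # query heads
    n_kv_heads=num_kv_head,  # key/value heads (grouped-query attention aware)
)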
@@ -406,9 +406,9 @@ def create_custom_optimizer(
     if finetuning_args.use_badam:
         return _create_badam_optimizer(model, training_args, finetuning_args)

-    if finetuning_args.use_adammini:
-        return _create_adammini_optimizer(model, training_args, finetuning_args)
+    if finetuning_args.use_adam_mini:
+        return _create_adam_mini_optimizer(model, training_args)


 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
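With the rename, the user-facing switch is the use_adam_mini finetuning argument, and the helper no longer needs finetuning_args since everything it reads comes from training_args or the model config. Because it returns a standard torch.optim.Optimizer, the result behaves like any other optimizer; a small sanity-check sketch, continuing from the construction snippet above (so model and optimizer are assumed to exist, and the input IDs are dummy values):

# Sanity check: Adam_mini drives a normal PyTorch update step.
# Continues from the construction sketch above; input IDs are dummy values.
import torch

input_ids = torch.tensor([[1, 2, 3, 4]])
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(f"one Adam-mini step done, loss={loss.item():.4f}")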