Former-commit-id: 54d5f62d29456a8d9d0c0dd3d0bbfffe48935803
hiyouga
2024-03-20 17:59:45 +08:00
parent d8073488be
commit c7af26a9e3
12 changed files with 104 additions and 48 deletions


@@ -1,4 +1,3 @@
-import math
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 if TYPE_CHECKING:
-    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -156,9 +154,9 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> "torch.optim.Optimizer":
     require_version("galore_torch", "To fix: pip install galore-torch")
@@ -209,12 +207,6 @@ def _create_galore_optimizer(
     if training_args.gradient_accumulation_steps != 1:
         raise ValueError("Per-layer GaLore does not support gradient accumulation.")
 
-    if training_args.max_steps > 0:
-        num_training_steps = training_args.max_steps
-    else:
-        total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
-        num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
-
     optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
     for param in nodecay_params:
         param_groups = [dict(params=[param])]
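
The step-count logic deleted above is not gone for good: after this commit the caller has to compute it once and pass the result in as max_steps. A minimal caller-side sketch, assuming the same training_args fields the removed lines used (compute_max_steps is a hypothetical helper name, not part of this diff):

import math


def compute_max_steps(training_args, dataset) -> int:
    # Mirrors the deleted branch: an explicit max_steps wins outright.
    if training_args.max_steps > 0:
        return training_args.max_steps

    # Otherwise derive the step budget from epochs, dataset size, and the
    # global batch size, exactly as the removed lines did.
    total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
    return int(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
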
@@ -231,8 +223,8 @@ def _create_galore_optimizer(
         scheduler_dict[param] = get_scheduler(
             training_args.lr_scheduler_type,
             optimizer=optimizer_dict[param],
-            num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
-            num_training_steps=num_training_steps * 2,
+            num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
+            num_training_steps=max_steps * 2,
         )
 
     def optimizer_hook(param: "torch.Tensor"):
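
For context, per-layer GaLore keeps one optimizer and one scheduler per parameter and steps them from a gradient hook, which is why the hunk above indexes optimizer_dict[param] and scheduler_dict[param]. A sketch of that hook pattern, assuming the per-parameter dicts built earlier in this function and a trainable_params list (a hypothetical name); register_post_accumulate_grad_hook requires PyTorch 2.1+:

from typing import Dict


def attach_per_layer_hooks(
    trainable_params: list,
    optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"],
    scheduler_dict: dict,
) -> None:
    def optimizer_hook(param: "torch.Tensor"):
        # Fires right after this parameter's gradient is accumulated:
        # step and reset its dedicated optimizer, then advance its scheduler.
        if param.grad is not None:
            optimizer_dict[param].step()
            optimizer_dict[param].zero_grad()
            scheduler_dict[param].step()

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
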
@@ -259,7 +251,6 @@ def _create_galore_optimizer(
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
@@ -302,12 +293,12 @@ def _create_loraplus_optimizer(
 def create_custom_optimzer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
-        return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
 
     if finetuning_args.loraplus_lr_ratio is not None:
-        return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_loraplus_optimizer(model, training_args, finetuning_args)
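
Put together, callers now precompute the step budget once and thread it through instead of handing over the dataset. A hedged sketch of the new call shape; only create_custom_optimzer comes from this diff, the surrounding trainer code and compute_max_steps (see the sketch above) are illustrative:

max_steps = compute_max_steps(training_args, train_dataset)
optimizer = create_custom_optimzer(model, training_args, finetuning_args, max_steps)
if optimizer is None:
    # Neither GaLore nor LoRA+ is enabled: fall back to the trainer's default optimizer.
    pass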