Former-commit-id: 54d5f62d29456a8d9d0c0dd3d0bbfffe48935803
@@ -1,4 +1,3 @@
-import math
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -156,9 +154,9 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> "torch.optim.Optimizer":
     require_version("galore_torch", "To fix: pip install galore-torch")
 
@@ -209,12 +207,6 @@ def _create_galore_optimizer(
     if training_args.gradient_accumulation_steps != 1:
         raise ValueError("Per-layer GaLore does not support gradient accumulation.")
 
-    if training_args.max_steps > 0:
-        num_training_steps = training_args.max_steps
-    else:
-        total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
-        num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
-
     optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
     for param in nodecay_params:
         param_groups = [dict(params=[param])]
@@ -231,8 +223,8 @@ def _create_galore_optimizer(
         scheduler_dict[param] = get_scheduler(
             training_args.lr_scheduler_type,
             optimizer=optimizer_dict[param],
-            num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
-            num_training_steps=num_training_steps * 2,
+            num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
+            num_training_steps=max_steps * 2,
         )
 
     def optimizer_hook(param: "torch.Tensor"):
@@ -259,7 +251,6 @@ def _create_galore_optimizer(
 
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
@@ -302,12 +293,12 @@ def _create_loraplus_optimizer(
 
 def create_custom_optimzer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
-        return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
 
     if finetuning_args.loraplus_lr_ratio is not None:
-        return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_loraplus_optimizer(model, training_args, finetuning_args)
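Note on the signature change: with `dataset` dropped from these factories, the total step count is now resolved by the caller and threaded through as `max_steps`. Below is a minimal caller-side sketch of that computation, mirroring the arithmetic removed from `_create_galore_optimizer`; the helper name `resolve_max_steps` and the commented call site are illustrative, not part of this commit.

import math


def resolve_max_steps(training_args, dataset) -> int:
    # Same arithmetic as the block removed from _create_galore_optimizer:
    # honor an explicit max_steps, otherwise derive it from epochs and batch size.
    if training_args.max_steps > 0:
        return training_args.max_steps
    total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
    return int(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))


# Hypothetical call site (model, dataset, training_args, finetuning_args come from the trainer setup):
# optimizer = create_custom_optimzer(
#     model, training_args, finetuning_args, max_steps=resolve_max_steps(training_args, dataset)
# )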
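For context on the per-layer path these hunks touch: each trainable tensor gets its own single-parameter optimizer and scheduler, stepped from a gradient hook, which is why gradient accumulation is rejected and why the total step count must be known up front when the schedulers are built. The sketch below is a rough, dependency-free illustration of that wiring, not the repository's exact implementation: it uses plain AdamW in place of the GaLore optimizer classes, assumes PyTorch's Tensor.register_post_accumulate_grad_hook (torch >= 2.1), and the helper name is made up.

from typing import Dict

import torch
from transformers import get_scheduler


def attach_per_param_optimizers(params, training_args, max_steps: int) -> None:
    # One optimizer and one scheduler per tensor, keyed by the tensor itself.
    optimizer_dict: Dict[torch.Tensor, torch.optim.Optimizer] = {}
    scheduler_dict: Dict[torch.Tensor, torch.optim.lr_scheduler.LRScheduler] = {}

    for param in params:
        optimizer_dict[param] = torch.optim.AdamW([param], lr=training_args.learning_rate)
        scheduler_dict[param] = get_scheduler(
            training_args.lr_scheduler_type,
            optimizer=optimizer_dict[param],
            num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
            num_training_steps=max_steps * 2,
        )

    def optimizer_hook(param: torch.Tensor):
        # Fires once this tensor's gradient is fully accumulated:
        # step and reset its private optimizer and scheduler immediately.
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()
        scheduler_dict[param].step()

    for param in params:
        param.register_post_accumulate_grad_hook(optimizer_hook)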