improve lora+ impl.

Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
hiyouga
2024-03-13 23:32:51 +08:00
parent 73f4513c84
commit 46f99ff277
12 changed files with 165 additions and 169 deletions

View File

@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
from ..utils import create_custom_optimzer, create_lora_plus_optimizer
from ..utils import create_custom_optimzer
if TYPE_CHECKING:
@@ -51,8 +51,6 @@ def run_sft(
# Initialize our Trainer
optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
if finetuning_args.lora_lr_ratio:
optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
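
With LoRA+ folded into create_custom_optimzer, the call site needs only a single optimizer line. How the (possibly None) result reaches the trainer is not shown in this hunk; the sketch below assumes it is passed through the HF Trainer's standard optimizers tuple:

optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
trainer = CustomSeq2SeqTrainer(
    model=model,
    args=training_args,
    optimizers=(optimizer, None),  # (optimizer, lr_scheduler); None falls back to Trainer defaults
    # ...remaining run_sft kwargs omitted in this sketch
)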

View File

@@ -43,8 +43,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
def export_model(args: Optional[Dict[str, Any]] = None):
model_args, data_args, finetuning_args, _ = get_infer_args(args)
model_args.device_map = {"": "cpu"}
if model_args.export_dir is None:
raise ValueError("Please specify `export_dir`.")
raise ValueError("Please specify `export_dir` to save model.")
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
raise ValueError("Please merge adapters before quantizing the model.")
@@ -58,13 +60,10 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
if getattr(model, "quantization_method", None):
model = model.to("cpu")
elif hasattr(model.config, "torch_dtype"):
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
model = model.to(torch.float16).to("cpu")
setattr(model.config, "torch_dtype", torch.float16)
if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
model = model.to(output_dtype)
setattr(model.config, "torch_dtype", output_dtype)
model.save_pretrained(
save_directory=model_args.export_dir,
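
The new export logic reduces to one rule: never cast a quantized model; otherwise cast to the config's torch_dtype (falling back to float16) and write the chosen dtype back into the config. A hedged sketch of that rule outside the repo (the checkpoint name and the explicit None guard are ours, not from this commit):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in checkpoint
if getattr(model, "quantization_method", None) is None:  # quantized weights cannot be cast
    output_dtype = getattr(model.config, "torch_dtype", None) or torch.float16
    model = model.to(output_dtype)
    model.config.torch_dtype = output_dtype
model.save_pretrained("./exported-model")
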

View File

@@ -1,15 +1,17 @@
import math
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
from transformers.trainer import Trainer
import torch
from torch import nn
from transformers import Trainer
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import load_model_and_tokenizer, load_valuehead_params
from ..model import find_all_linear_modules, load_model_and_tokenizer, load_valuehead_params
if is_galore_available():
@@ -29,9 +31,10 @@ logger = get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
def __init__(self, *args, **kwargs):
def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None:
dummy_tensor = torch.randn(1, 1)
super().__init__([dummy_tensor], {"lr": 1e-3})
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr})
def zero_grad(self, set_to_none: bool = True) -> None:
pass
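
The dummy exists only so the HF Trainer and its LR scheduler have an optimizer object to hold; the real per-parameter optimizers live in optimizer_dict and are stepped from gradient hooks (see the layer-wise GaLore branch below). A quick illustration with made-up values:

dummy = DummyOptimizer(lr=1e-4, optimizer_dict={})  # illustrative values
print(dummy.param_groups[0]["lr"])  # 0.0001 -- what schedulers and logging will see
dummy.zero_grad()                   # intentionally a no-op
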
@@ -142,59 +145,33 @@ def create_reward_model(
return reward_model
def create_custom_optimzer(
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
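
This helper mirrors the Trainer's own weight-decay rule. A tiny worked example on a toy model (not from the repo):

import torch
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
names = get_parameter_names(toy, ALL_LAYERNORM_LAYERS)
decay = [n for n in names if "bias" not in n]
print(decay)  # ['0.weight'] -- LayerNorm parameters and all biases get no weight decay
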
def _create_galore_optimizer(
model: "PreTrainedModel",
dataset: Union["Dataset", "IterableDataset"],
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if not finetuning_args.use_galore:
return None
) -> "torch.optim.Optimizer":
require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
galore_params: List[torch.nn.Parameter] = []
galore_targets = finetuning_args.galore_target.split(",")
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
galore_targets = find_all_linear_modules(model)
galore_params: List["torch.nn.Parameter"] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
for param in module.parameters():
if param.requires_grad and len(param.shape) > 1:
galore_params.append(param)
id_galore_params = {id(param) for param in galore_params}
trainable_params = filter(lambda param: param.requires_grad, model.parameters())
non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]
if training_args.optim == "adamw_torch":
optim_class = GaLoreAdamW
optim_kwargs = {
"lr": training_args.learning_rate,
"eps": training_args.adam_epsilon,
"betas": (training_args.adam_beta1, training_args.adam_beta2),
"weight_decay": training_args.weight_decay,
}
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
optim_class = GaLoreAdamW8bit
optim_kwargs = {
"lr": training_args.learning_rate,
"eps": training_args.adam_epsilon,
"betas": (training_args.adam_beta1, training_args.adam_beta2),
"weight_decay": training_args.weight_decay,
"optim_bits": 8,
"is_paged": "paged" in training_args.optim,
}
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
optim_kwargs = {
"lr": training_args.learning_rate,
"weight_decay": training_args.weight_decay,
}
else:
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
galore_kwargs = {
"rank": finetuning_args.galore_rank,
"update_proj_gap": finetuning_args.galore_update_interval,
@@ -202,6 +179,30 @@ def create_custom_optimzer(
"proj_type": finetuning_args.galore_proj_type,
}
id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
trainable_params.append(param)
if id(param) not in id_galore_params:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
if training_args.optim == "adamw_torch":
optim_class = GaLoreAdamW
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
optim_class = GaLoreAdamW8bit
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
else:
raise NotImplementedError("Unknown optim: {}".format(training_args.optim))
if finetuning_args.galore_layerwise:
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@@ -213,15 +214,18 @@ def create_custom_optimzer(
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
for param in non_galore_params:
for param in nodecay_params:
param_groups = [dict(params=[param])]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in galore_params:
param_groups = [dict(params=[param], **galore_kwargs)]
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
for param in non_galore_params + galore_params:
for param in trainable_params:
scheduler_dict[param] = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer_dict[param],
@@ -235,99 +239,72 @@ def create_custom_optimzer(
optimizer_dict[param].zero_grad()
scheduler_dict[param].step()
for param in non_galore_params + galore_params:
for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer = DummyOptimizer()
optimizer = DummyOptimizer(lr=training_args.learning_rate) # display scheduler result
else:
param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
param_groups = [
dict(params=nodecay_params),
dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
return optimizer
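
For readers unfamiliar with the layer-wise trick: instead of one global optimizer, every trainable parameter gets its own small optimizer that is stepped from a post-accumulate gradient hook, which is also why gradient accumulation is rejected above. A self-contained sketch of the pattern, with plain AdamW standing in for the GaLore classes (the hook requires PyTorch 2.1+):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param: torch.Tensor) -> None:
    # fires right after this parameter's gradient has been accumulated
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

# a single ordinary backward pass now updates every parameter in place
loss = model(torch.randn(4, 8)).sum()
loss.backward()
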
def optimizer_group_callback(model, lora_lr_ratio, **defaults):
"lora plus"
params = []
names = set()
def _create_loraplus_optimizer(
model: "PreTrainedModel",
dataset: Union["Dataset", "IterableDataset"],
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if finetuning_args.finetuning_type != "lora":
raise ValueError("You should use LoRA tuning to activate LoRA+.")
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
decay_args = {"weight_decay": training_args.weight_decay}
decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
"lora_a": [],
"lora_b": [],
"lora_b_nodecay": [],
"embedding": [],
}
for name, param in model.named_parameters():
if "default" in name and ('lora_B' in name or
'lora_embedding_B' in name):
params.append(param)
names.add(name)
if params:
assert 'lr' in defaults
return names, {
'params': params,
'lr': defaults['lr'] * lora_lr_ratio,
}
return None, None
def create_lora_plus_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.lora_lr_ratio is None:
return None
all_param_names = set()
param_groups = []
param_names, param_group = optimizer_group_callback(
model, lora_lr_ratio=finetuning_args.lora_lr_ratio,
lr=training_args.learning_rate,
weight_decay=training_args.weight_decay)
if param_names and all_param_names & param_names:
raise ValueError(
'Cannot set one parameter to different param groups')
if param_names and param_group:
all_param_names.update(param_names)
param_groups.append(param_group)
opt_model = model
decay_parameters = Trainer.get_decay_parameter_names(None, opt_model)
param_groups.extend([
{
'params': [
p for n, p in opt_model.named_parameters()
if (n in decay_parameters and n not in all_param_names and p.requires_grad)
],
'weight_decay':
training_args.weight_decay,
},
{
'params': [
p for n, p in opt_model.named_parameters()
if (n not in decay_parameters and n not in all_param_names and p.requires_grad)
],
'weight_decay':
0.0,
},
])
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
optimizer = optimizer_cls(param_groups, **optimizer_kwargs)
if optimizer_cls.__name__ == 'Adam8bit':
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({
p.data_ptr(): p.numel()
for p in module.parameters()
}.values())
logger.info(
f'skipped {module}: {skipped / 2 ** 20}M params')
manager.register_module_override(
module, 'weight', {'optim_bits': 32})
logger.debug(
f'bitsandbytes: will optimize {module} in fp32')
logger.info(f'skipped: {skipped / 2 ** 20}M params')
if param.requires_grad:
if "lora_embedding_B" in name:
param_dict["embedding"].append(param)
elif "lora_B" in name or param.ndim == 1:
if name in decay_param_names:
param_dict["lora_b"].append(param)
else:
param_dict["lora_b_nodecay"].append(param)
else:
param_dict["lora_a"].append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=param_dict["lora_a"], **decay_args),
dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
]
optimizer = optim_class(param_groups, **optim_kwargs)
return optimizer
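
Net effect of the grouping above: lora_A weights keep the base learning rate, lora_B (and other 1-D trainable) parameters run at learning_rate * loraplus_lr_ratio, and lora_embedding_B uses loraplus_lr_embedding. A minimal standalone illustration of the two-speed idea with made-up values:

import torch

lora_A = torch.nn.Parameter(torch.randn(16, 4) * 0.01)  # adapter "A" matrix
lora_B = torch.nn.Parameter(torch.zeros(4, 16))         # adapter "B" matrix
base_lr, loraplus_lr_ratio = 1e-4, 16.0

optimizer = torch.optim.AdamW(
    [
        {"params": [lora_A]},                                     # trained at base_lr
        {"params": [lora_B], "lr": base_lr * loraplus_lr_ratio},  # trained 16x faster
    ],
    lr=base_lr,
    weight_decay=0.01,
)
print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 0.0016]
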
def create_custom_optimzer(
model: "PreTrainedModel",
dataset: Union["Dataset", "IterableDataset"],
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)