improve lora+ impl.

Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
This commit is contained in:
hiyouga
2024-03-13 23:32:51 +08:00
parent 73f4513c84
commit 46f99ff277
12 changed files with 165 additions and 169 deletions

View File

@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
from ..utils import create_custom_optimzer, create_lora_plus_optimizer
from ..utils import create_custom_optimzer
if TYPE_CHECKING:
@@ -51,8 +51,6 @@ def run_sft(
# Initialize our Trainer
optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
if finetuning_args.lora_lr_ratio:
optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,