improve lora+ impl.
Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
This commit is contained in:
@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
|
||||
from ...train.sft.metric import ComputeMetrics
|
||||
from ...train.sft.trainer import CustomSeq2SeqTrainer
|
||||
from ...train.utils import create_modelcard_and_push
|
||||
from ..utils import create_custom_optimzer, create_lora_plus_optimizer
|
||||
from ..utils import create_custom_optimzer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -51,8 +51,6 @@ def run_sft(
|
||||
|
||||
# Initialize our Trainer
|
||||
optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
|
||||
if finetuning_args.lora_lr_ratio:
|
||||
optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
||||
Reference in New Issue
Block a user