mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
remove PeftTrainer
Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
This commit is contained in:
@@ -9,7 +9,7 @@ from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
@@ -45,8 +45,7 @@ def run_sft(
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user