remove PeftTrainer

Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
Author: hiyouga
Date: 2023-09-10 22:23:23 +08:00
parent 332d7bbd56
commit a09a7b650d
17 changed files with 75 additions and 259 deletions

src/llmtuner/tuner/sft/trainer.py

@@ -4,10 +4,10 @@ import torch
 import numpy as np
 import torch.nn as nn
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from transformers import Seq2SeqTrainer

 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer

 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-class Seq2SeqPeftTrainer(PeftTrainer):
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """

src/llmtuner/tuner/sft/workflow.py

@@ -9,7 +9,7 @@ from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
-from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
+from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer

 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -45,8 +45,7 @@ def run_sft(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)

     # Initialize our Trainer
-    trainer = Seq2SeqPeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
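
Since finetuning_args is no longer consumed by the trainer, the call site reduces to standard Trainer keyword arguments. A hedged construction sketch, assuming the class from trainer.py above; the placeholder model name, collator, and argument values are illustrative, not taken from this commit:

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
)

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer

# Placeholder model and tokenizer purely for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

training_args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,  # enables generation during evaluation
)

# Note there is no finetuning_args keyword anymore; PEFT-specific state is
# presumably handled outside the trainer after this commit.
trainer = CustomSeq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=IGNORE_INDEX),
)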