fix import error

Former-commit-id: b3207a974a45038591b8cbbcf20d1ca1142d6679
This commit is contained in:
hiyouga
2023-08-23 20:45:03 +08:00
parent eb9ac9ee1f
commit 180a05a446
3 changed files with 6 additions and 4 deletions

View File

@@ -3,6 +3,7 @@
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
@@ -12,7 +13,7 @@ from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments