fix import bug

Former-commit-id: 2356029cdd120d5f7bf630b80681ce8c53bff90d
This commit is contained in:
hiyouga
2023-11-16 02:27:03 +08:00
parent 7a3a0144a5
commit f81a8a5e5c
6 changed files with 91 additions and 84 deletions

View File

@@ -9,9 +9,9 @@ from transformers.optimization import get_scheduler
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import create_ref_model, create_reward_model, load_model_and_tokenizer
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
@@ -19,9 +19,6 @@ if TYPE_CHECKING:
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
def run_ppo(
model_args: "ModelArguments",
data_args: "DataArguments",