fix import bug

Former-commit-id: 2356029cdd120d5f7bf630b80681ce8c53bff90d
Author: hiyouga
Date: 2023-11-16 02:27:03 +08:00
parent 7a3a0144a5
commit f81a8a5e5c
6 changed files with 91 additions and 84 deletions


@@ -6,10 +6,10 @@ from transformers import Seq2SeqTrainingArguments
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.hparams import ModelArguments
-from llmtuner.model import create_ref_model, generate_model_card, load_model_and_tokenizer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.train.utils import create_ref_model
 from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
 from llmtuner.train.dpo.trainer import CustomDPOTrainer
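
The hunk above simply points `create_ref_model` at its new home in `llmtuner.train.utils`. A hedged sketch of how a caller could pick up the relocated symbol; only the two import paths are taken from the hunk, and the fallback for pre-refactor versions is an illustrative assumption, not part of the commit:

# Illustrative only: prefer the location used after this commit, and fall
# back to the old path for older llmtuner layouts (assumed, not from the diff).
try:
    from llmtuner.train.utils import create_ref_model  # location after this commit
except ImportError:
    from llmtuner.model import create_ref_model  # pre-refactor location (assumed)
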
@@ -18,9 +18,6 @@ if TYPE_CHECKING:
     from llmtuner.hparams import DataArguments, FinetuningArguments

-logger = get_logger(__name__)

 def run_dpo(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -74,7 +71,6 @@ def run_dpo(
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
         if id(model) == id(ref_model):  # unable to compute rewards without a reference model
-            logger.warning("Specify `ref_model` for computing rewards at evaluation.")
             remove_keys = [key for key in metrics.keys() if "rewards" in key]
             for key in remove_keys:
                 metrics.pop(key)
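
The guard in the last hunk drops the reward metrics whenever `model` and `ref_model` are the same object, since DPO rewards cannot be computed without a separate reference model. A minimal standalone sketch of that pattern; the helper name and the metric values are hypothetical, only the filtering logic mirrors the hunk:

# Standalone sketch of the metric-filtering pattern shown above.
# When no separate reference model exists (model is ref_model), the
# reward-related entries are removed from the evaluation metrics.
def strip_reward_metrics(metrics: dict, model: object, ref_model: object) -> dict:
    if id(model) == id(ref_model):  # no reference model: rewards are not meaningful
        for key in [k for k in metrics if "rewards" in k]:
            metrics.pop(key)
    return metrics

# Hypothetical usage with made-up metric values:
m = {"eval_loss": 0.42, "eval_rewards/chosen": 1.3, "eval_rewards/rejected": -0.7}
shared = object()
print(strip_reward_metrics(m, shared, shared))  # {'eval_loss': 0.42}
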