Former-commit-id: 263b2b24c8a649b51fa5ae768a24e67def8e0e96
hiyouga
2023-11-19 14:15:47 +08:00
parent 3d1ee27ccd
commit 6889f044fb
8 changed files with 35 additions and 31 deletions

@@ -6,10 +6,11 @@ from transformers import Seq2SeqTrainingArguments
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
 from llmtuner.train.rm.metric import compute_accuracy
 from llmtuner.train.rm.trainer import PairwiseTrainer
+from llmtuner.train.utils import create_modelcard_and_push

 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -68,8 +69,4 @@ def run_rm(
         trainer.save_predictions(predict_results)

     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
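For context, the branch deleted above is factored out into a shared helper in llmtuner/train/utils.py so every workflow (rm, sft, dpo, ...) can call it with one line. Below is a minimal sketch of what that helper plausibly does, reconstructed from the removed inline logic; the body, the card-metadata kwargs, and the hparams attribute names used here (model_name_or_path, finetuning_type) are illustrative assumptions, not the repository's actual implementation.

# Hypothetical sketch of create_modelcard_and_push (llmtuner/train/utils.py),
# reconstructed from the inline code this commit removes; upstream may differ.
from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, Trainer
    from llmtuner.hparams import DataArguments, FinetuningArguments, ModelArguments


def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
    if not training_args.do_train:  # same guard as the removed inline code
        return

    # Card metadata; the deleted code built this via generate_model_card(...).
    # The keys are standard Trainer.create_model_card() kwargs; the attribute
    # names on the hparams objects are assumptions for illustration.
    kwargs: Dict[str, Any] = {
        "tasks": "text-generation",
        "finetuned_from": model_args.model_name_or_path,
        "tags": ["llama-factory", finetuning_args.finetuning_type],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)  # uploads the run and generates the card on the Hub
    else:
        trainer.create_model_card(**kwargs)  # writes README.md into output_dir

Centralizing this also means the do_train / push_to_hub branching lives in one place, which is what lets each workflow shrink from the five deleted lines to the single call shown in the hunk above.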