@@ -5,6 +5,7 @@ from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import FixValueHeadModelCallback
|
||||
from llmtuner.extras.misc import fix_valuehead_checkpoint
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
|
||||
@@ -49,6 +50,8 @@ def run_rm(
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
|
||||
Reference in New Issue
Block a user