Former-commit-id: 9acd5a2b678cd07f8e3b48eca76c4cbacb559e37
This commit is contained in:
hiyouga
2024-01-11 17:04:13 +08:00
parent 64246d42d2
commit 73cab9d9d4
4 changed files with 75 additions and 71 deletions

View File

@@ -5,6 +5,7 @@ from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import FixValueHeadModelCallback
from llmtuner.extras.misc import fix_valuehead_checkpoint
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
@@ -49,6 +50,8 @@ def run_rm(
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()