upgrade peft, fix #1088 #1411

Former-commit-id: aa7d104f8e050d12cb8f585bc8a52c850995500f
This commit is contained in:
hiyouga
2023-11-07 16:13:36 +08:00
parent 37a0d62a82
commit 2eb65d21ac
15 changed files with 133 additions and 99 deletions

View File

@@ -8,7 +8,7 @@ from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
@@ -52,13 +52,18 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card())
else:
trainer.create_model_card(**generate_model_card())
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")