add eval acc

Former-commit-id: 7ffde76fbfb6192e3aac31ccc098f31ce89181ae
This commit is contained in:
hiyouga
2024-07-01 03:51:20 +08:00
parent 38c94d2e9c
commit 884b49e662
3 changed files with 31 additions and 17 deletions

View File

@@ -25,7 +25,7 @@ from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeMetrics
from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
@@ -72,7 +72,8 @@ def run_sft(
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
)
@@ -91,7 +92,7 @@ def run_sft(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
# Evaluation
if training_args.do_eval: