Merge branch 'main' into main
Former-commit-id: 7be442f37d53a0c6324728fa1fa8e2c84d7f0fa5
This commit is contained in:
@@ -1,4 +1,19 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
@@ -10,7 +25,7 @@ from ...extras.misc import get_logits_processor
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
from .metric import ComputeMetrics
|
||||
from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
|
||||
from .trainer import CustomSeq2SeqTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -56,7 +71,8 @@ def run_sft(
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
|
||||
preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
@@ -75,7 +91,7 @@ def run_sft(
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
@@ -92,7 +108,7 @@ def run_sft(
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
trainer.save_predictions(dataset, predict_results)
|
||||
|
||||
# Create model card
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
|
||||
Reference in New Issue
Block a user