Former-commit-id: 263b2b24c8a649b51fa5ae768a24e67def8e0e96
This commit is contained in:
hiyouga
2023-11-19 14:15:47 +08:00
parent 3d1ee27ccd
commit 6889f044fb
8 changed files with 35 additions and 31 deletions

View File

@@ -7,9 +7,10 @@ from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -90,8 +91,4 @@ def run_sft(
trainer.save_predictions(predict_results)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)