@@ -7,9 +7,10 @@ from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.sft.metric import ComputeMetrics
 from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
+from llmtuner.train.utils import create_modelcard_and_push
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -90,8 +91,4 @@ def run_sft(
             trainer.save_predictions(predict_results)
 
     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
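
The hunk above collapses the inline model-card branch into a single call to create_modelcard_and_push from llmtuner.train.utils. For context, here is a minimal sketch of what that helper plausibly wraps, reconstructed from the lines removed in this diff; the actual body in llmtuner/train/utils.py may differ (for example, extra error handling), and the continued availability of generate_model_card under llmtuner.model is an assumption.

# Sketch only: reconstructed from the branch removed above, not the exact
# implementation in llmtuner/train/utils.py.
# Assumes generate_model_card is still importable from llmtuner.model.
from llmtuner.model import generate_model_card


def create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args):
    if not training_args.do_train:
        return  # nothing was trained, so there is nothing to document
    card_kwargs = generate_model_card(model_args, data_args, finetuning_args)
    if training_args.push_to_hub:
        trainer.push_to_hub(**card_kwargs)        # upload checkpoint and model card to the Hub
    else:
        trainer.create_model_card(**card_kwargs)  # write README.md into the local output_dir

Centralizing this logic in llmtuner.train.utils presumably lets the other training workflows share one code path instead of each duplicating the do_train / push_to_hub branch.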