Former-commit-id: 067ba6e6cb4d8a1d95bba0a108f73008416a2865
This commit is contained in:
hiyouga
2024-12-19 12:16:30 +00:00
parent 0a465fc3ca
commit 0385c60177
6 changed files with 22 additions and 16 deletions

View File

@@ -91,7 +91,7 @@ def run_sft(
)
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict()
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
@@ -130,7 +130,7 @@ def run_sft(
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)