fix generation bug #532

Former-commit-id: c071121e67374e5f09798db57cfc8668617a36ae
Author: hiyouga
Date: 2023-08-17 22:21:34 +08:00
Parent: e993e717a5
Commit: fa1893b59c
5 changed files with 15 additions and 46 deletions

@@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
@@ -52,10 +52,9 @@ def run_sft(
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)

     # Training
     if training_args.do_train:
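
In effect, the fix drops the custom stopping-criteria hook and instead passes a list of end-of-sequence token ids directly to `model.generate`: Hugging Face transformers accepts a list for `eos_token_id`, so generation halts on the tokenizer's EOS token or on any additional special token. A minimal standalone sketch of the same pattern, assuming a recent transformers version; the model name and prompt below are illustrative and are not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model, not from this repo
model = AutoModelForCausalLM.from_pretrained("gpt2")

if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

gen_kwargs = {
    # `eos_token_id` may be a list: stop on EOS or on any additional special token
    "eos_token_id": [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
    "pad_token_id": tokenizer.pad_token_id,
    "max_new_tokens": 32,
}

inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, **gen_kwargs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))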