fix generation bug #532
Former-commit-id: c071121e67374e5f09798db57cfc8668617a36ae
This commit is contained in:
@@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
@@ -52,10 +52,9 @@ def run_sft(
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = generating_args.to_dict()
|
||||
gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
|
||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||
gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
||||
Reference in New Issue
Block a user