add logits processor
Former-commit-id: f6f4b1554ae1e8849b437d705ffa34ce7ebd56bb
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
import mdtex2html
|
||||
import gradio as gr
|
||||
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
@@ -93,7 +93,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
|
||||
"temperature": temperature,
|
||||
"num_beams": 1,
|
||||
"max_length": max_length,
|
||||
"repetition_penalty": 1.0
|
||||
"repetition_penalty": 1.5,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
with torch.no_grad():
|
||||
generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user