support RM metrics, add generating Args
Former-commit-id: c461c6190bc124e98dde7f3cf96a59ce40b26fb0
This commit is contained in:
@@ -21,7 +21,7 @@ from transformers.utils.versions import require_version
|
||||
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
|
||||
|
||||
|
||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
@@ -87,9 +87,9 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
|
||||
"do_sample": True,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"num_beams": 1,
|
||||
"num_beams": generating_args.infer_num_beams,
|
||||
"max_length": max_length,
|
||||
"repetition_penalty": 1.0,
|
||||
"repetition_penalty": generating_args.repetition_penalty,
|
||||
"logits_processor": get_logits_processor(),
|
||||
"streamer": streamer
|
||||
}
|
||||
@@ -133,8 +133,8 @@ with gr.Blocks() as demo:
|
||||
with gr.Column(scale=1):
|
||||
emptyBtn = gr.Button("Clear History")
|
||||
max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
|
||||
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
|
||||
temperature = gr.Slider(0, 1.5, value=0.95, step=0.01, label="Temperature", interactive=True)
|
||||
top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
|
||||
temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
|
||||
|
||||
history = gr.State([])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user