alter rewards data type

Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
This commit is contained in:
hiyouga
2023-06-02 14:19:51 +08:00
parent 896dbfec16
commit e9ab06678f
12 changed files with 40 additions and 50 deletions

View File

@@ -7,21 +7,23 @@ import torch
import mdtex2html
import gradio as gr
from utils import ModelArguments, load_pretrained
from utils import ModelArguments, FinetuningArguments, load_pretrained
from transformers import HfArgumentParser
from transformers.utils.versions import require_version
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model, infer_auto_device_map
device_map = infer_auto_device_map(model)
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
@@ -74,10 +76,10 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
def format_example(query):
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
return prompt
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
return prompt
def predict(input, chatbot, max_length, top_p, temperature, history):