alter rewards data type
Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
This commit is contained in:
@@ -4,22 +4,24 @@
|
||||
|
||||
|
||||
import torch
|
||||
from utils import ModelArguments, load_pretrained
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = HfArgumentParser(ModelArguments)
|
||||
model_args, = parser.parse_args_into_dataclasses()
|
||||
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
|
||||
model_args, finetuning_args = parser.parse_args_into_dataclasses()
|
||||
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
|
||||
model, tokenizer = load_pretrained(model_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
device_map = infer_auto_device_map(model)
|
||||
model = dispatch_model(model, device_map)
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
model.eval()
|
||||
|
||||
def format_example(query):
|
||||
|
||||
Reference in New Issue
Block a user