fix int8 inference
Former-commit-id: d05202943e9634526f96d189288f67852d3d1c40
This commit is contained in:
@@ -15,15 +15,6 @@ require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher
|
||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
device_map = infer_auto_device_map(model)
|
||||
model = dispatch_model(model, device_map)
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
model.eval()
|
||||
|
||||
|
||||
"""Override Chatbot.postprocess"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user