support BLOOM models
Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements user interface in browser for LLaMA fine-tuned with PEFT.
+# Implements user interface in browser for fine-tuned models.
 # Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
 
 
@@ -7,7 +7,7 @@ import torch
 import mdtex2html
 import gradio as gr
 
-from utils import ModelArguments, auto_configure_device_map, load_pretrained
+from utils import ModelArguments, load_pretrained
 from transformers import HfArgumentParser
 from transformers.utils.versions import require_version
 
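For reference, a minimal standalone sketch of the `HfArgumentParser` pattern these imports support; the dataclass and its fields below are illustrative stand-ins, not the repository's actual `ModelArguments`.

```python
# Sketch only: HfArgumentParser with a stand-in dataclass (not the repo's ModelArguments).
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoModelArguments:
    # Hypothetical fields chosen for illustration.
    model_name_or_path: str = field(default="bigscience/bloom-560m")
    checkpoint_dir: Optional[str] = field(default=None)


if __name__ == "__main__":
    parser = HfArgumentParser(DemoModelArguments)
    # parse_args_into_dataclasses() returns a tuple of dataclass instances,
    # hence the trailing comma used to unpack a single element.
    demo_args, = parser.parse_args_into_dataclasses()
    print(demo_args)
```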
@@ -17,8 +17,8 @@ parser = HfArgumentParser(ModelArguments)
 model_args, = parser.parse_args_into_dataclasses()
 model, tokenizer = load_pretrained(model_args)
 if torch.cuda.device_count() > 1:
-    from accelerate import dispatch_model
-    device_map = auto_configure_device_map(torch.cuda.device_count())
+    from accelerate import dispatch_model, infer_auto_device_map
+    device_map = infer_auto_device_map(model)
     model = dispatch_model(model, device_map)
 else:
     model = model.cuda()
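The hunk above swaps the ChatGLM-specific `auto_configure_device_map` helper for accelerate's generic `infer_auto_device_map`, which works for any architecture, including BLOOM. Below is a standalone sketch of that dispatch pattern; the checkpoint name and the plain `transformers` loading are illustrative assumptions, not the repository's `load_pretrained`.

```python
# Sketch only: generic multi-GPU dispatch with accelerate.
# The checkpoint name is an illustrative assumption, not taken from the repository.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import dispatch_model, infer_auto_device_map

model_name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

if torch.cuda.device_count() > 1:
    # Let accelerate decide how to split the layers across the visible GPUs.
    device_map = infer_auto_device_map(model)
    model = dispatch_model(model, device_map=device_map)
else:
    model = model.cuda()

model.eval()
```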
@@ -111,7 +111,7 @@ def reset_state():
 
 
 with gr.Blocks() as demo:
-    gr.HTML("""<h1 align="center">ChatGLM-Efficient-Tuning</h1>""")
+    gr.HTML("""<h1 align="center">LLaMA-Efficient-Tuning</h1>""")
 
     chatbot = gr.Chatbot()
     with gr.Row():
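For context, a minimal self-contained gradio Blocks layout in the same shape as the demo UI touched above; the echo handler is a placeholder, not the repository's prediction function.

```python
# Sketch only: minimal gradio Blocks chat layout; the echo handler is a placeholder.
import gradio as gr


def respond(user_message, chat_history):
    # Placeholder "model": echo the user input back as the bot reply.
    chat_history = chat_history + [(user_message, "You said: " + user_message)]
    return "", chat_history


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">LLaMA-Efficient-Tuning</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        user_input = gr.Textbox(show_label=False, placeholder="Input...")
        submit_btn = gr.Button("Submit")

    submit_btn.click(respond, [user_input, chatbot], [user_input, chatbot])
    user_input.submit(respond, [user_input, chatbot], [user_input, chatbot])

demo.launch()
```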