allow multiple GPUs to do inference in a data parallel way
@@ -327,7 +327,6 @@
             },
             body: JSON.stringify({
                 messages: messages,
-                stream: true,
                 temperature: 0.8,
                 max_tokens: 512
             }),
@@ -1,26 +1,46 @@
 #!/usr/bin/env python3
 """
 Unified web chat server - serves both UI and API from a single FastAPI instance.
-Run with: python web_chat.py
-Then open http://localhost:8000 in your browser.
+Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
+a full copy of the model, and incoming requests are distributed to available workers.
+
+Launch examples:
+
+- single available GPU (default)
+python -m scripts.chat_web
+
+- 4 GPUs
+python -m scripts.chat_web --num-gpus 4
+
+To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
+
+Endpoints:
+GET / - Chat UI
+POST /chat/completions - Chat API (streaming only)
+GET /health - Health check with worker pool status
+GET /stats - Worker pool statistics and GPU utilization
 """
 
 import argparse
 import json
 import os
 import torch
+import asyncio
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
+from dataclasses import dataclass
 
 from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
 
 parser = argparse.ArgumentParser(description='NanoChat Web Server')
+parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
 parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
 parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
 parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
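
Note: the docstring above lists the new /health and /stats endpoints. A quick smoke-test sketch against a running server, assuming it is reachable at http://localhost:8000 (use the URL printed at startup) and that the third-party requests package is installed; these two calls are not part of the commit itself:

# Smoke-test sketch for the worker-pool endpoints (illustrative, not repo code).
import requests

base = "http://localhost:8000"  # assumption: default local URL; adjust to the address printed at startup
print(requests.get(f"{base}/health").json())  # status, ready, num_gpus, available_workers
print(requests.get(f"{base}/stats").json())   # total/available/busy workers plus per-worker gpu_id and device
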
@@ -32,7 +52,55 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
 args = parser.parse_args()
 
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+@dataclass
+class Worker:
+    """A worker with a model loaded on a specific GPU."""
+    gpu_id: int
+    device: torch.device
+    engine: Engine
+    tokenizer: object
+    autocast_ctx: torch.amp.autocast
+
+class WorkerPool:
+    """Pool of workers, each with a model replica on a different GPU."""
+
+    def __init__(self, num_gpus: Optional[int] = None):
+        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
+        self.workers: List[Worker] = []
+        self.available_workers: asyncio.Queue = asyncio.Queue()
+
+    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
+        """Load model on each GPU."""
+        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
+
+        for gpu_id in range(self.num_gpus):
+            device = torch.device(f"cuda:{gpu_id}")
+            print(f"Loading model on GPU {gpu_id}...")
+
+            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
+            engine = Engine(model, tokenizer)
+            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+            worker = Worker(
+                gpu_id=gpu_id,
+                device=device,
+                engine=engine,
+                tokenizer=tokenizer,
+                autocast_ctx=autocast_ctx
+            )
+            self.workers.append(worker)
+            await self.available_workers.put(worker)
+
+        print(f"All {self.num_gpus} workers initialized!")
+
+    async def acquire_worker(self) -> Worker:
+        """Get an available worker from the pool."""
+        return await self.available_workers.get()
+
+    async def release_worker(self, worker: Worker):
+        """Return a worker to the pool."""
+        await self.available_workers.put(worker)
+
 class ChatMessage(BaseModel):
     role: str
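
Note: the WorkerPool added here is a plain asyncio.Queue checkout. A request takes a worker off the queue, runs generation on that worker's GPU, and puts it back, so at most num-gpus requests run concurrently and the rest wait. A self-contained sketch of that pattern, with dummy workers and a sleep standing in for generation (none of these names come from the repo):

# Standalone illustration of the asyncio.Queue checkout pattern used by WorkerPool.
import asyncio

class DummyWorker:
    """Stand-in for a Worker; holds only a GPU id."""
    def __init__(self, gpu_id: int):
        self.gpu_id = gpu_id

async def main(num_gpus: int = 2, num_requests: int = 5):
    pool: asyncio.Queue = asyncio.Queue()
    for gpu_id in range(num_gpus):
        await pool.put(DummyWorker(gpu_id))

    async def handle(i: int):
        worker = await pool.get()      # blocks while all workers are busy
        try:
            await asyncio.sleep(0.1)   # stands in for running generation on worker.gpu_id
            print(f"request {i} served by GPU {worker.gpu_id}")
        finally:
            await pool.put(worker)     # always hand the worker back to the pool

    await asyncio.gather(*(handle(i) for i in range(num_requests)))

asyncio.run(main())
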
@@ -43,14 +111,13 @@ class ChatRequest(BaseModel):
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
-    stream: Optional[bool] = True
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Load model on startup."""
-    print("Loading nanochat model...")
-    app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
-    app.state.engine = Engine(app.state.model, app.state.tokenizer)
+    """Load models on all GPUs on startup."""
+    print("Loading nanochat models across GPUs...")
+    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
+    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
     print(f"Server ready at http://localhost:{args.port}")
     yield
 
@@ -85,8 +152,7 @@ async def logo():
     return FileResponse(logo_path, media_type="image/svg+xml")
 
 async def generate_stream(
-    engine,
-    tokenizer,
+    worker: Worker,
     tokens,
     temperature=None,
     max_new_tokens=None,
@@ -97,11 +163,11 @@ async def generate_stream(
     max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
     top_k = top_k if top_k is not None else args.top_k
 
-    assistant_end = tokenizer.encode_special("<|assistant_end|>")
-    bos = tokenizer.get_bos_token_id()
+    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+    bos = worker.tokenizer.get_bos_token_id()
 
-    with autocast_ctx:
-        for token_column, token_masks in engine.generate(
+    with worker.autocast_ctx:
+        for token_column, token_masks in worker.engine.generate(
             tokens,
             num_samples=1,
             max_tokens=max_new_tokens,
@@ -113,82 +179,89 @@ async def generate_stream(
             if token == assistant_end or token == bos:
                 break
 
-            token_text = tokenizer.decode([token])
-            yield f"data: {json.dumps({'token': token_text})}\n\n"
+            token_text = worker.tokenizer.decode([token])
+            yield f"data: {json.dumps({'token': token_text, 'gpu': worker.gpu_id})}\n\n"
 
     yield f"data: {json.dumps({'done': True})}\n\n"
 
 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
-    """Chat completion endpoint with streaming."""
-    engine = app.state.engine
-    tokenizer = app.state.tokenizer
+    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
+    worker_pool = app.state.worker_pool
 
-    # Build conversation tokens
-    bos = tokenizer.get_bos_token_id()
-    user_start = tokenizer.encode_special("<|user_start|>")
-    user_end = tokenizer.encode_special("<|user_end|>")
-    assistant_start = tokenizer.encode_special("<|assistant_start|>")
-    assistant_end = tokenizer.encode_special("<|assistant_end|>")
+    # Acquire a worker from the pool (will wait if all are busy)
+    worker = await worker_pool.acquire_worker()
 
-    conversation_tokens = [bos]
-    for message in request.messages:
-        if message.role == "user":
-            conversation_tokens.append(user_start)
-            conversation_tokens.extend(tokenizer.encode(message.content))
-            conversation_tokens.append(user_end)
-        elif message.role == "assistant":
-            conversation_tokens.append(assistant_start)
-            conversation_tokens.extend(tokenizer.encode(message.content))
-            conversation_tokens.append(assistant_end)
+    try:
+        # Build conversation tokens
+        bos = worker.tokenizer.get_bos_token_id()
+        user_start = worker.tokenizer.encode_special("<|user_start|>")
+        user_end = worker.tokenizer.encode_special("<|user_end|>")
+        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
+        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+
+        conversation_tokens = [bos]
+        for message in request.messages:
+            if message.role == "user":
+                conversation_tokens.append(user_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(user_end)
+            elif message.role == "assistant":
+                conversation_tokens.append(assistant_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(assistant_end)
 
-    conversation_tokens.append(assistant_start)
+        conversation_tokens.append(assistant_start)
 
-    if request.stream:
-        return StreamingResponse(
-            generate_stream(
-                engine,
-                tokenizer,
-                conversation_tokens,
-                temperature=request.temperature,
-                max_new_tokens=request.max_tokens,
-                top_k=request.top_k
-            ),
+        # Streaming response with worker release after completion
+        async def stream_and_release():
+            try:
+                async for chunk in generate_stream(
+                    worker,
+                    conversation_tokens,
+                    temperature=request.temperature,
+                    max_new_tokens=request.max_tokens,
+                    top_k=request.top_k
+                ):
+                    yield chunk
+            finally:
+                # Release worker back to pool after streaming is done
+                await worker_pool.release_worker(worker)
+
+        return StreamingResponse(
+            stream_and_release(),
             media_type="text/event-stream"
         )
-    else:
-        # Non-streaming response
-        temperature = request.temperature if request.temperature is not None else args.temperature
-        max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
-        top_k = request.top_k if request.top_k is not None else args.top_k
-
-        with autocast_ctx:
-            result_tokens, masks = engine.generate_batch(
-                conversation_tokens,
-                num_samples=1,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                top_k=top_k
-            )[0]
-
-        response_tokens = result_tokens[len(conversation_tokens):]
-        response_text = tokenizer.decode(response_tokens)
-        return {
-            "choices": [{
-                "message": {
-                    "role": "assistant",
-                    "content": response_text
-                },
-                "finish_reason": "stop"
-            }]
-        }
+    except Exception as e:
+        # Make sure to release worker even on error
+        await worker_pool.release_worker(worker)
+        raise e
 
 @app.get("/health")
 async def health():
     """Health check endpoint."""
+    worker_pool = getattr(app.state, 'worker_pool', None)
     return {
         "status": "ok",
-        "ready": hasattr(app.state, 'model') and app.state.model is not None
+        "ready": worker_pool is not None and len(worker_pool.workers) > 0,
+        "num_gpus": worker_pool.num_gpus if worker_pool else 0,
+        "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
     }
 
+@app.get("/stats")
+async def stats():
+    """Get worker pool statistics."""
+    worker_pool = app.state.worker_pool
+    return {
+        "total_workers": len(worker_pool.workers),
+        "available_workers": worker_pool.available_workers.qsize(),
+        "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
+        "workers": [
+            {
+                "gpu_id": w.gpu_id,
+                "device": str(w.device)
+            } for w in worker_pool.workers
+        ]
+    }
+
 if __name__ == "__main__":
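
Note: with the non-streaming branch removed, clients must consume the server-sent events emitted by generate_stream. A minimal client sketch, assuming the server is reachable at http://localhost:8000 and the third-party requests package is installed; it parses the "data:" lines, including the new gpu field added in this commit:

# Minimal streaming client sketch for /chat/completions (illustrative, not repo code).
import json
import requests

resp = requests.post(
    "http://localhost:8000/chat/completions",  # assumption: default local URL
    json={"messages": [{"role": "user", "content": "Hello!"}], "temperature": 0.8, "max_tokens": 512},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip the blank separator lines between SSE events
    event = json.loads(line[len("data: "):])
    if event.get("done"):
        break
    # each event carries one decoded token plus the id of the GPU that produced it
    print(event["token"], end="", flush=True)
print()
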