nanochat/scripts/chat_web.py
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Launch examples:
- single available GPU (default)
python -m scripts.chat_web
- 4 GPUs
python -m scripts.chat_web --num-gpus 4
To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
Endpoints:
GET / - Chat UI
POST /chat/completions - Chat API (streaming only)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
"""
import argparse
import json
import os
import torch
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
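# Note: the DDP fields returned by compute_init() are not used by this server;
# the WorkerPool below assigns one model replica per GPU explicitly.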

@dataclass
class Worker:
    """A worker with a model loaded on a specific GPU."""
    gpu_id: int
    device: torch.device
    engine: Engine
    tokenizer: object
    autocast_ctx: torch.amp.autocast

class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU."""

    def __init__(self, num_gpus: Optional[int] = None):
        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load model on each GPU."""
        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
        for gpu_id in range(self.num_gpus):
            device = torch.device(f"cuda:{gpu_id}")
            print(f"Loading model on GPU {gpu_id}...")
            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
            engine = Engine(model, tokenizer)
            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            worker = Worker(
                gpu_id=gpu_id,
                device=device,
                engine=engine,
                tokenizer=tokenizer,
                autocast_ctx=autocast_ctx
            )
            self.workers.append(worker)
            await self.available_workers.put(worker)
        print(f"All {self.num_gpus} workers initialized!")

    async def acquire_worker(self) -> Worker:
        """Get an available worker from the pool."""
        return await self.available_workers.get()

    async def release_worker(self, worker: Worker):
        """Return a worker to the pool."""
        await self.available_workers.put(worker)
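
# Concurrency model: each request exclusively holds one Worker (one GPU) for the
# full duration of its streamed response; acquire_worker() awaits on the queue, so
# requests beyond the number of GPUs simply wait until a worker is released.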

class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None
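
# A request body for /chat/completions looks like the following (temperature,
# max_tokens, and top_k are optional and fall back to the CLI defaults above), e.g.:
#   {"messages": [{"role": "user", "content": "Why is the sky blue?"}],
#    "temperature": 0.8, "max_tokens": 256, "top_k": 50}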

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on all GPUs on startup."""
    print("Loading nanochat models across GPUs...")
    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    print(f"Server ready at http://localhost:{args.port}")
    yield

app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
"""Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r") as f:
html_content = f.read()
# Replace the API_URL to use the same origin
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';"
)
return HTMLResponse(content=html_content)
@app.get("/logo.svg")
async def logo():
"""Serve the NanoChat logo for favicon and header."""
logo_path = os.path.join("nanochat", "logo.svg")
return FileResponse(logo_path, media_type="image/svg+xml")

async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Generate assistant response with streaming."""
    temperature = temperature if temperature is not None else args.temperature
    max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
    top_k = top_k if top_k is not None else args.top_k
    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
    bos = worker.tokenizer.get_bos_token_id()
    with worker.autocast_ctx:
        for token_column, token_masks in worker.engine.generate(
            tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        ):
            token = token_column[0]
            if token == assistant_end or token == bos:
                break
            token_text = worker.tokenizer.decode([token])
            yield f"data: {json.dumps({'token': token_text, 'gpu': worker.gpu_id})}\n\n"
    yield f"data: {json.dumps({'done': True})}\n\n"

@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
    worker_pool = app.state.worker_pool

    # Acquire a worker from the pool (will wait if all are busy)
    worker = await worker_pool.acquire_worker()

    try:
        # Build conversation tokens
        bos = worker.tokenizer.get_bos_token_id()
        user_start = worker.tokenizer.encode_special("<|user_start|>")
        user_end = worker.tokenizer.encode_special("<|user_end|>")
        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
        conversation_tokens = [bos]
        for message in request.messages:
            if message.role == "user":
                conversation_tokens.append(user_start)
                conversation_tokens.extend(worker.tokenizer.encode(message.content))
                conversation_tokens.append(user_end)
            elif message.role == "assistant":
                conversation_tokens.append(assistant_start)
                conversation_tokens.extend(worker.tokenizer.encode(message.content))
                conversation_tokens.append(assistant_end)
        conversation_tokens.append(assistant_start)

        # Streaming response with worker release after completion
        async def stream_and_release():
            try:
                async for chunk in generate_stream(
                    worker,
                    conversation_tokens,
                    temperature=request.temperature,
                    max_new_tokens=request.max_tokens,
                    top_k=request.top_k
                ):
                    yield chunk
            finally:
                # Release worker back to pool after streaming is done
                await worker_pool.release_worker(worker)

        return StreamingResponse(
            stream_and_release(),
            media_type="text/event-stream"
        )
    except Exception as e:
        # Make sure to release worker even on error
        await worker_pool.release_worker(worker)
        raise e
@app.get("/health")
async def health():
"""Health check endpoint."""
worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
}
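
# Example /health response once all workers are loaded (values depend on --num-gpus):
#   {"status": "ok", "ready": true, "num_gpus": 1, "available_workers": 1}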
@app.get("/stats")
async def stats():
"""Get worker pool statistics."""
worker_pool = app.state.worker_pool
return {
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [
{
"gpu_id": w.gpu_id,
"device": str(w.device)
} for w in worker_pool.workers
]
}

if __name__ == "__main__":
    import uvicorn
    print("Starting NanoChat Web Server")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
    uvicorn.run(app, host=args.host, port=args.port)