allow multiple GPUs to do inference in a data parallel way

Author: Andrej Karpathy
Date: 2025-10-15 19:12:19 +00:00
parent 190d9515d0
commit 01fb290f53
2 changed files with 145 additions and 73 deletions
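
The API stays streaming-only over server-sent events; the visible change for clients is that each token event now also carries the id of the GPU worker that produced it. A minimal client sketch (assumptions: the server listens on localhost:8000, the port the old docstring mentions, and the third-party `requests` package is available):

```python
# Minimal SSE client sketch for the streaming-only /chat/completions endpoint.
# Assumptions: server reachable at localhost:8000 (the --port default is not shown
# in this diff) and the `requests` package is installed.
import json
import requests

def stream_chat(prompt: str, url: str = "http://localhost:8000/chat/completions"):
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.8,
        "max_tokens": 512,
    }
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank SSE separators
            event = json.loads(line[len("data: "):])
            if event.get("done"):
                break
            # each token event also reports which GPU worker produced it
            print(f"[gpu {event.get('gpu')}] {event['token']}", end="", flush=True)

if __name__ == "__main__":
    stream_chat("Hello!")
```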

View File

@@ -327,7 +327,6 @@
             },
             body: JSON.stringify({
                 messages: messages,
-                stream: true,
                 temperature: 0.8,
                 max_tokens: 512
             }),

View File

@@ -1,26 +1,46 @@
 #!/usr/bin/env python3
 """
 Unified web chat server - serves both UI and API from a single FastAPI instance.
-Run with: python web_chat.py
-Then open http://localhost:8000 in your browser.
+
+Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
+a full copy of the model, and incoming requests are distributed to available workers.
+
+Launch examples:
+
+- single available GPU (default)
+python -m scripts.chat_web
+
+- 4 GPUs
+python -m scripts.chat_web --num-gpus 4
+
+To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
+
+Endpoints:
+  GET /                 - Chat UI
+  POST /chat/completions - Chat API (streaming only)
+  GET /health           - Health check with worker pool status
+  GET /stats            - Worker pool statistics and GPU utilization
 """

 import argparse
 import json
 import os
 import torch
+import asyncio
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
+from dataclasses import dataclass

 from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine

 parser = argparse.ArgumentParser(description='NanoChat Web Server')
+parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
 parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
 parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
 parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
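
Each GPU holds a full replica of the model, so `--num-gpus` is limited by `torch.cuda.device_count()` and by whether one copy fits on a single device. A quick way to check that the pool finished loading is the extended `/health` response added further down in this diff; a polling sketch under the same host/port assumption as above:

```python
# Poll /health until the worker pool reports it is ready.
# Assumption: server at localhost:8000; field names match the health() handler below.
import time
import requests

def wait_until_ready(url: str = "http://localhost:8000/health", timeout_s: float = 300.0):
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            info = requests.get(url, timeout=5).json()
            if info.get("ready"):
                print(f"ready: {info['num_gpus']} GPU(s), "
                      f"{info['available_workers']} worker(s) idle")
                return info
        except requests.ConnectionError:
            pass  # server not up yet
        time.sleep(2)
    raise TimeoutError("worker pool did not become ready in time")

if __name__ == "__main__":
    wait_until_ready()
```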
@@ -32,7 +52,55 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
 args = parser.parse_args()

 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+@dataclass
+class Worker:
+    """A worker with a model loaded on a specific GPU."""
+    gpu_id: int
+    device: torch.device
+    engine: Engine
+    tokenizer: object
+    autocast_ctx: torch.amp.autocast
+
+class WorkerPool:
+    """Pool of workers, each with a model replica on a different GPU."""
+
+    def __init__(self, num_gpus: Optional[int] = None):
+        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
+        self.workers: List[Worker] = []
+        self.available_workers: asyncio.Queue = asyncio.Queue()
+
+    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
+        """Load model on each GPU."""
+        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
+        for gpu_id in range(self.num_gpus):
+            device = torch.device(f"cuda:{gpu_id}")
+            print(f"Loading model on GPU {gpu_id}...")
+            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
+            engine = Engine(model, tokenizer)
+            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+            worker = Worker(
+                gpu_id=gpu_id,
+                device=device,
+                engine=engine,
+                tokenizer=tokenizer,
+                autocast_ctx=autocast_ctx
+            )
+            self.workers.append(worker)
+            await self.available_workers.put(worker)
+        print(f"All {self.num_gpus} workers initialized!")
+
+    async def acquire_worker(self) -> Worker:
+        """Get an available worker from the pool."""
+        return await self.available_workers.get()
+
+    async def release_worker(self, worker: Worker):
+        """Return a worker to the pool."""
+        await self.available_workers.put(worker)

 class ChatMessage(BaseModel):
     role: str
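
The pool is a plain `asyncio.Queue` of pre-loaded replicas: `acquire_worker` awaits until a GPU is free, so requests beyond `num_gpus` queue up in FIFO order instead of failing, and `release_worker` puts the replica back when a request finishes. A toy, GPU-free sketch of that pattern (all names here are illustrative, not part of the commit):

```python
# Toy illustration of the queue-based worker pool pattern used above.
# No models or GPUs here: "workers" are just labels, and each fake request
# holds one briefly so you can watch requests wait for a free slot.
import asyncio

async def fake_request(pool: asyncio.Queue, request_id: int) -> None:
    worker = await pool.get()          # like acquire_worker(): waits if all busy
    try:
        print(f"request {request_id} running on {worker}")
        await asyncio.sleep(0.1)       # stand-in for streaming a reply
    finally:
        pool.put_nowait(worker)        # like release_worker(): always give it back

async def main() -> None:
    pool: asyncio.Queue = asyncio.Queue()
    for gpu_id in range(2):            # pretend we have 2 GPUs
        pool.put_nowait(f"worker-gpu{gpu_id}")
    # 6 concurrent "requests" share the 2 workers, two at a time
    await asyncio.gather(*(fake_request(pool, i) for i in range(6)))

asyncio.run(main())
```

Note that a worker stays checked out for the full lifetime of its stream, which is why the request handler below releases it both in the generator's `finally` block and on the exception path.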
@@ -43,14 +111,13 @@ class ChatRequest(BaseModel):
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
-    stream: Optional[bool] = True

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Load model on startup."""
-    print("Loading nanochat model...")
-    app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
-    app.state.engine = Engine(app.state.model, app.state.tokenizer)
+    """Load models on all GPUs on startup."""
+    print("Loading nanochat models across GPUs...")
+    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
+    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
     print(f"Server ready at http://localhost:{args.port}")
     yield
@@ -85,8 +152,7 @@ async def logo():
     return FileResponse(logo_path, media_type="image/svg+xml")

 async def generate_stream(
-    engine,
-    tokenizer,
+    worker: Worker,
     tokens,
     temperature=None,
     max_new_tokens=None,
@@ -97,11 +163,11 @@ async def generate_stream(
     max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
     top_k = top_k if top_k is not None else args.top_k

-    assistant_end = tokenizer.encode_special("<|assistant_end|>")
-    bos = tokenizer.get_bos_token_id()
+    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+    bos = worker.tokenizer.get_bos_token_id()

-    with autocast_ctx:
-        for token_column, token_masks in engine.generate(
+    with worker.autocast_ctx:
+        for token_column, token_masks in worker.engine.generate(
             tokens,
             num_samples=1,
             max_tokens=max_new_tokens,
@@ -113,82 +179,89 @@ async def generate_stream(
             if token == assistant_end or token == bos:
                 break
-            token_text = tokenizer.decode([token])
-            yield f"data: {json.dumps({'token': token_text})}\n\n"
+            token_text = worker.tokenizer.decode([token])
+            yield f"data: {json.dumps({'token': token_text, 'gpu': worker.gpu_id})}\n\n"

     yield f"data: {json.dumps({'done': True})}\n\n"

 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
-    """Chat completion endpoint with streaming."""
-    engine = app.state.engine
-    tokenizer = app.state.tokenizer
-
-    # Build conversation tokens
-    bos = tokenizer.get_bos_token_id()
-    user_start = tokenizer.encode_special("<|user_start|>")
-    user_end = tokenizer.encode_special("<|user_end|>")
-    assistant_start = tokenizer.encode_special("<|assistant_start|>")
-    assistant_end = tokenizer.encode_special("<|assistant_end|>")
-
-    conversation_tokens = [bos]
-    for message in request.messages:
-        if message.role == "user":
-            conversation_tokens.append(user_start)
-            conversation_tokens.extend(tokenizer.encode(message.content))
-            conversation_tokens.append(user_end)
-        elif message.role == "assistant":
-            conversation_tokens.append(assistant_start)
-            conversation_tokens.extend(tokenizer.encode(message.content))
-            conversation_tokens.append(assistant_end)
-
-    conversation_tokens.append(assistant_start)
-
-    if request.stream:
+    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
+    worker_pool = app.state.worker_pool
+
+    # Acquire a worker from the pool (will wait if all are busy)
+    worker = await worker_pool.acquire_worker()
+
+    try:
+        # Build conversation tokens
+        bos = worker.tokenizer.get_bos_token_id()
+        user_start = worker.tokenizer.encode_special("<|user_start|>")
+        user_end = worker.tokenizer.encode_special("<|user_end|>")
+        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
+        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+
+        conversation_tokens = [bos]
+        for message in request.messages:
+            if message.role == "user":
+                conversation_tokens.append(user_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(user_end)
+            elif message.role == "assistant":
+                conversation_tokens.append(assistant_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(assistant_end)
+
+        conversation_tokens.append(assistant_start)
+
+        # Streaming response with worker release after completion
+        async def stream_and_release():
+            try:
+                async for chunk in generate_stream(
+                    worker,
+                    conversation_tokens,
+                    temperature=request.temperature,
+                    max_new_tokens=request.max_tokens,
+                    top_k=request.top_k
+                ):
+                    yield chunk
+            finally:
+                # Release worker back to pool after streaming is done
+                await worker_pool.release_worker(worker)
+
         return StreamingResponse(
-            generate_stream(
-                engine,
-                tokenizer,
-                conversation_tokens,
-                temperature=request.temperature,
-                max_new_tokens=request.max_tokens,
-                top_k=request.top_k
-            ),
+            stream_and_release(),
             media_type="text/event-stream"
         )
-    else:
-        # Non-streaming response
-        temperature = request.temperature if request.temperature is not None else args.temperature
-        max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
-        top_k = request.top_k if request.top_k is not None else args.top_k
-        with autocast_ctx:
-            result_tokens, masks = engine.generate_batch(
-                conversation_tokens,
-                num_samples=1,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                top_k=top_k
-            )[0]
-        response_tokens = result_tokens[len(conversation_tokens):]
-        response_text = tokenizer.decode(response_tokens)
-        return {
-            "choices": [{
-                "message": {
-                    "role": "assistant",
-                    "content": response_text
-                },
-                "finish_reason": "stop"
-            }]
-        }
+    except Exception as e:
+        # Make sure to release worker even on error
+        await worker_pool.release_worker(worker)
+        raise e

 @app.get("/health")
 async def health():
     """Health check endpoint."""
+    worker_pool = getattr(app.state, 'worker_pool', None)
     return {
         "status": "ok",
-        "ready": hasattr(app.state, 'model') and app.state.model is not None
+        "ready": worker_pool is not None and len(worker_pool.workers) > 0,
+        "num_gpus": worker_pool.num_gpus if worker_pool else 0,
+        "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
     }
+
+@app.get("/stats")
+async def stats():
+    """Get worker pool statistics."""
+    worker_pool = app.state.worker_pool
+    return {
+        "total_workers": len(worker_pool.workers),
+        "available_workers": worker_pool.available_workers.qsize(),
+        "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
+        "workers": [
+            {
+                "gpu_id": w.gpu_id,
+                "device": str(w.device)
+            } for w in worker_pool.workers
+        ]
+    }

 if __name__ == "__main__":