add very basic abuse prevention limits to chat_web so it's ok to host endpoints

This commit is contained in:
Andrej Karpathy
2025-10-15 19:42:54 +00:00
parent 01fb290f53
commit 52bfeea8bd

View File

@@ -20,6 +20,14 @@ Endpoints:
POST /chat/completions - Chat API (streaming only)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
Abuse Prevention:
- Maximum 500 messages per request
- Maximum 8000 characters per message
- Maximum 32000 characters total conversation length
- Temperature clamped to 0.0-2.0
- Top-k clamped to 1-200
- Max tokens clamped to 1-4096
"""
import argparse
@@ -28,7 +36,7 @@ import os
import torch
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
@@ -39,6 +47,17 @@ from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
@@ -112,6 +131,69 @@ class ChatRequest(BaseModel):
max_tokens: Optional[int] = None
top_k: Optional[int] = None
def validate_chat_request(request: ChatRequest):
"""Validate chat request to prevent abuse."""
# Check number of messages
if len(request.messages) == 0:
raise HTTPException(status_code=400, detail="At least one message is required")
if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
raise HTTPException(
status_code=400,
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
)
# Check individual message lengths and total conversation length
total_length = 0
for i, message in enumerate(request.messages):
if not message.content:
raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
msg_length = len(message.content)
if msg_length > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
)
total_length += msg_length
if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
)
# Validate role values
for i, message in enumerate(request.messages):
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
)
# Validate temperature
if request.temperature is not None:
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
raise HTTPException(
status_code=400,
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
)
# Validate top_k
if request.top_k is not None:
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
raise HTTPException(
status_code=400,
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
)
# Validate max_tokens
if request.max_tokens is not None:
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
raise HTTPException(
status_code=400,
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup."""
@@ -187,9 +269,12 @@ async def generate_stream(
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
worker_pool = app.state.worker_pool
# Basic validation to prevent abuse
validate_chat_request(request)
# Acquire a worker from the pool (will wait if all are busy)
worker_pool = app.state.worker_pool
worker = await worker_pool.acquire_worker()
try: