add very basic abuse prevention limits to chat_web so it's ok to host endpoints

Andrej Karpathy
2025-10-15 19:42:54 +00:00
parent 01fb290f53
commit 52bfeea8bd


@@ -20,6 +20,14 @@ Endpoints:
     POST /chat/completions - Chat API (streaming only)
     GET /health - Health check with worker pool status
     GET /stats - Worker pool statistics and GPU utilization
+
+Abuse Prevention:
+- Maximum 500 messages per request
+- Maximum 8000 characters per message
+- Maximum 32000 characters total conversation length
+- Temperature must be within 0.0-2.0
+- Top-k must be within 1-200
+- Max tokens must be within 1-4096
""" """
import argparse import argparse
@@ -28,7 +36,7 @@ import os
 import torch
 import asyncio
 from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
@@ -39,6 +47,17 @@ from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
+
+# Abuse prevention limits
+MAX_MESSAGES_PER_REQUEST = 500
+MAX_MESSAGE_LENGTH = 8000
+MAX_TOTAL_CONVERSATION_LENGTH = 32000
+MIN_TEMPERATURE = 0.0
+MAX_TEMPERATURE = 2.0
+MIN_TOP_K = 1
+MAX_TOP_K = 200
+MIN_MAX_TOKENS = 1
+MAX_MAX_TOKENS = 4096
 parser = argparse.ArgumentParser(description='NanoChat Web Server')
 parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
 parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
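
For context, the validator added below reads `messages`, `temperature`, `top_k`, and `max_tokens` off the request. A minimal sketch of the Pydantic models this implies, inferred from the fields the validator touches rather than copied from the file:

    from typing import List, Optional
    from pydantic import BaseModel

    class Message(BaseModel):
        role: str      # "user" or "assistant"
        content: str   # raw message text

    class ChatRequest(BaseModel):
        messages: List[Message]
        temperature: Optional[float] = None  # rejected outside 0.0-2.0
        max_tokens: Optional[int] = None     # rejected outside 1-4096
        top_k: Optional[int] = None          # rejected outside 1-200
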
@@ -112,6 +131,69 @@ class ChatRequest(BaseModel):
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
+
+def validate_chat_request(request: ChatRequest):
+    """Validate chat request to prevent abuse."""
+    # Check the number of messages
+    if len(request.messages) == 0:
+        raise HTTPException(status_code=400, detail="At least one message is required")
+    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
+        )
+
+    # Check individual message lengths and the total conversation length
+    total_length = 0
+    for i, message in enumerate(request.messages):
+        if not message.content:
+            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
+        msg_length = len(message.content)
+        if msg_length > MAX_MESSAGE_LENGTH:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
+            )
+        total_length += msg_length
+    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
+        )
+
+    # Validate role values
+    for i, message in enumerate(request.messages):
+        if message.role not in ["user", "assistant"]:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
+            )
+
+    # Validate temperature
+    if request.temperature is not None:
+        if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
+            raise HTTPException(
+                status_code=400,
+                detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
+            )
+
+    # Validate top_k
+    if request.top_k is not None:
+        if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
+            raise HTTPException(
+                status_code=400,
+                detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
+            )
+
+    # Validate max_tokens
+    if request.max_tokens is not None:
+        if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
+            raise HTTPException(
+                status_code=400,
+                detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
+            )
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Load models on all GPUs on startup."""
@@ -187,9 +269,12 @@ async def generate_stream(
 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
     """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
+    worker_pool = app.state.worker_pool
+
+    # Basic validation to prevent abuse
+    validate_chat_request(request)
     # Acquire a worker from the pool (will wait if all are busy)
-    worker_pool = app.state.worker_pool
     worker = await worker_pool.acquire_worker()
     try:
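
End to end, an out-of-range parameter is now rejected with a 400 before any worker is acquired. A sketch of what a client sees, assuming the server runs at a local address (the host and port are assumptions, not fixed by this diff):

    import requests

    resp = requests.post(
        "http://localhost:8000/chat/completions",  # assumed local dev address
        json={
            "messages": [{"role": "user", "content": "hello"}],
            "temperature": 5.0,  # above MAX_TEMPERATURE, triggers validation
        },
    )
    assert resp.status_code == 400
    print(resp.json()["detail"])  # "Temperature must be between 0.0 and 2.0"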