diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 2643417..f8e807c 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -20,6 +20,14 @@ Endpoints:
   POST /chat/completions - Chat API (streaming only)
   GET /health - Health check with worker pool status
   GET /stats - Worker pool statistics and GPU utilization
+
+Abuse Prevention:
+  - Maximum 500 messages per request
+  - Maximum 8000 characters per message
+  - Maximum 32000 characters total conversation length
+  - Temperature clamped to 0.0-2.0
+  - Top-k clamped to 1-200
+  - Max tokens clamped to 1-4096
 """
 
 import argparse
@@ -28,7 +36,7 @@ import os
 import torch
 import asyncio
 from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
@@ -39,6 +47,17 @@ from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
 
+# Abuse prevention limits
+MAX_MESSAGES_PER_REQUEST = 500
+MAX_MESSAGE_LENGTH = 8000
+MAX_TOTAL_CONVERSATION_LENGTH = 32000
+MIN_TEMPERATURE = 0.0
+MAX_TEMPERATURE = 2.0
+MIN_TOP_K = 1
+MAX_TOP_K = 200
+MIN_MAX_TOKENS = 1
+MAX_MAX_TOKENS = 4096
+
 parser = argparse.ArgumentParser(description='NanoChat Web Server')
 parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
 parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
@@ -112,6 +131,69 @@ class ChatRequest(BaseModel):
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
 
+def validate_chat_request(request: ChatRequest):
+    """Validate chat request to prevent abuse."""
+    # Check number of messages
+    if len(request.messages) == 0:
+        raise HTTPException(status_code=400, detail="At least one message is required")
+    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
+        )
+
+    # Check individual message lengths and total conversation length
+    total_length = 0
+    for i, message in enumerate(request.messages):
+        if not message.content:
+            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
+
+        msg_length = len(message.content)
+        if msg_length > MAX_MESSAGE_LENGTH:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
+            )
+        total_length += msg_length
+
+    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
+        )
+
+    # Validate role values
+    for i, message in enumerate(request.messages):
+        if message.role not in ["user", "assistant"]:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
+            )
+
+    # Validate temperature
+    if request.temperature is not None:
+        if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
+            raise HTTPException(
+                status_code=400,
+                detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
+            )
+
+    # Validate top_k
+    if request.top_k is not None:
+        if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
+            raise HTTPException(
+                status_code=400,
+                detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
+            )
+
+    # Validate max_tokens
+    if request.max_tokens is not None:
+        if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
+            raise HTTPException(
+                status_code=400,
+                detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
+            )
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Load models on all GPUs on startup."""
@@ -187,9 +269,12 @@ async def generate_stream(
 
 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
     """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
-    worker_pool = app.state.worker_pool
+
+    # Basic validation to prevent abuse
+    validate_chat_request(request)
 
     # Acquire a worker from the pool (will wait if all are busy)
+    worker_pool = app.state.worker_pool
     worker = await worker_pool.acquire_worker()
     try:
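
Review note: to smoke-test the new limits by hand, something like the sketch below works against a running server. The `BASE` URL is an assumption (adjust it to wherever chat_web.py is actually listening), and it uses the third-party `requests` package.

```python
# Hypothetical smoke test for the new validation; BASE is a guess,
# match it to the host/port the server actually binds.
import requests

BASE = "http://localhost:8000"

# An out-of-range temperature should be rejected before any GPU work happens.
resp = requests.post(f"{BASE}/chat/completions", json={
    "messages": [{"role": "user", "content": "hi"}],
    "temperature": 5.0,  # outside the 0.0-2.0 clamp range
})
print(resp.status_code, resp.json())
# Expected: 400 {'detail': 'Temperature must be between 0.0 and 2.0'}

# A single message over MAX_MESSAGE_LENGTH (8000 chars) should also 400.
resp = requests.post(f"{BASE}/chat/completions", json={
    "messages": [{"role": "user", "content": "x" * 8001}],
})
print(resp.status_code, resp.json())
# Expected: 400 {'detail': 'Message 0 is too long. Maximum 8000 characters allowed per message'}
```

Because `validate_chat_request` runs before `acquire_worker()`, malformed requests fail fast and never tie up a GPU worker.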
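A design note: the patch enforces the limits imperatively in a helper, which gives clients uniform 400s with readable messages. Most of the same bounds could instead be declared on the Pydantic models; a sketch of that alternative (not what this patch does, and assuming Pydantic v2 `Field` constraints):

```python
# Alternative sketch only: declarative bounds on the request models.
# Message/ChatRequest here mirror the script's models in shape but are
# written from scratch for illustration.
from typing import List, Optional
from pydantic import BaseModel, Field

class Message(BaseModel):
    role: str = Field(pattern="^(user|assistant)$")
    content: str = Field(min_length=1, max_length=8000)

class ChatRequest(BaseModel):
    messages: List[Message] = Field(min_length=1, max_length=500)
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_k: Optional[int] = Field(default=None, ge=1, le=200)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
```

Two trade-offs keep the imperative version attractive here: Pydantic violations surface as 422s with a structured error array rather than the plain 400 messages above, and the cross-message total-conversation cap would still need a custom validator either way.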