add very basic abuse prevention limits to chat_web so it's ok to host endpoints
@@ -20,6 +20,14 @@ Endpoints:
 POST /chat/completions - Chat API (streaming only)
 GET /health - Health check with worker pool status
 GET /stats - Worker pool statistics and GPU utilization
+
+Abuse Prevention:
+- Maximum 500 messages per request
+- Maximum 8000 characters per message
+- Maximum 32000 characters total conversation length
+- Temperature must be within 0.0-2.0
+- Top-k must be within 1-200
+- Max tokens must be within 1-4096
 """
 
 import argparse
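
For illustration, a request that violates one of these limits should come back as an HTTP 400 before it ever reaches the model. A minimal client sketch using the requests library; the host and port are assumptions, not part of this commit:

    import requests

    # One message of 9000 characters exceeds the 8000-character per-message
    # limit, so the server should reject the request before any generation.
    payload = {"messages": [{"role": "user", "content": "x" * 9000}]}
    resp = requests.post("http://localhost:8000/chat/completions", json=payload)
    print(resp.status_code)   # expected: 400
    print(resp.json())        # expected: {"detail": "Message 0 is too long. ..."}
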
@@ -28,7 +36,7 @@ import os
 import torch
 import asyncio
 from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
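
Aside: FastAPI serializes a raised HTTPException into a JSON error body, so every rejection below surfaces to the client as {"detail": "..."} with the given status code. A self-contained illustration (throwaway app, not part of this commit):

    from fastapi import FastAPI, HTTPException

    demo = FastAPI()  # throwaway app just to show the error shape

    @demo.get("/boom")
    def boom():
        # FastAPI turns this into a 400 response with body
        # {"detail": "example rejection"}
        raise HTTPException(status_code=400, detail="example rejection")
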
@@ -39,6 +47,17 @@ from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
 
+# Abuse prevention limits
+MAX_MESSAGES_PER_REQUEST = 500
+MAX_MESSAGE_LENGTH = 8000
+MAX_TOTAL_CONVERSATION_LENGTH = 32000
+MIN_TEMPERATURE = 0.0
+MAX_TEMPERATURE = 2.0
+MIN_TOP_K = 1
+MAX_TOP_K = 200
+MIN_MAX_TOKENS = 1
+MAX_MAX_TOKENS = 4096
+
 parser = argparse.ArgumentParser(description='NanoChat Web Server')
 parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
 parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
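
These constants are enforced by rejection (HTTP 400) in validate_chat_request below, not by coercion. If silently clamping out-of-range sampling parameters were preferred instead, a sketch might look like the following (hypothetical helper, not part of this commit):

    def clamp(value, lo, hi):
        """Coerce value into [lo, hi] instead of rejecting the request."""
        return max(lo, min(hi, value))

    print(clamp(5.0, MIN_TEMPERATURE, MAX_TEMPERATURE))  # 2.0
    print(clamp(500, MIN_TOP_K, MAX_TOP_K))              # 200
    print(clamp(0, MIN_MAX_TOKENS, MAX_MAX_TOKENS))      # 1
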
@@ -112,6 +131,69 @@ class ChatRequest(BaseModel):
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
+
+def validate_chat_request(request: ChatRequest):
+    """Validate chat request to prevent abuse."""
+    # Check number of messages
+    if len(request.messages) == 0:
+        raise HTTPException(status_code=400, detail="At least one message is required")
+    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
+        )
+
+    # Check individual message lengths and total conversation length
+    total_length = 0
+    for i, message in enumerate(request.messages):
+        if not message.content:
+            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
+
+        msg_length = len(message.content)
+        if msg_length > MAX_MESSAGE_LENGTH:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
+            )
+        total_length += msg_length
+
+    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
+        )
+
+    # Validate role values
+    for i, message in enumerate(request.messages):
+        if message.role not in ["user", "assistant"]:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
+            )
+
+    # Validate temperature
+    if request.temperature is not None:
+        if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
+            raise HTTPException(
+                status_code=400,
+                detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
+            )
+
+    # Validate top_k
+    if request.top_k is not None:
+        if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
+            raise HTTPException(
+                status_code=400,
+                detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
+            )
+
+    # Validate max_tokens
+    if request.max_tokens is not None:
+        if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
+            raise HTTPException(
+                status_code=400,
+                detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
+            )
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Load models on all GPUs on startup."""
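
Because the validator is a plain function, it can be sanity-checked without spinning up the server. A rough check, assuming the ChatRequest model above accepts message dicts with role/content fields (its exact definition is outside this diff):

    from fastapi import HTTPException

    def expect_400(req: ChatRequest):
        """Assert that the validator rejects req with an HTTP 400."""
        try:
            validate_chat_request(req)
        except HTTPException as e:
            assert e.status_code == 400
            print("rejected:", e.detail)
        else:
            raise AssertionError("expected the request to be rejected")

    expect_400(ChatRequest(messages=[]))                       # no messages
    expect_400(ChatRequest(messages=[{"role": "user", "content": "hi"}],
                           temperature=5.0))                   # out of range
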
@@ -187,9 +269,12 @@ async def generate_stream(
 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
     """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
-    worker_pool = app.state.worker_pool
+
+    # Basic validation to prevent abuse
+    validate_chat_request(request)
 
     # Acquire a worker from the pool (will wait if all are busy)
+    worker_pool = app.state.worker_pool
     worker = await worker_pool.acquire_worker()
 
     try:
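
Note the ordering: validation runs before a worker is acquired, so malformed or oversized requests fail fast with a 400 and never occupy a GPU slot. A well-formed request within all limits streams back as before; a client sketch (host, port, and the exact stream encoding are assumptions, not part of this commit):

    import requests

    payload = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.8,   # within 0.0-2.0
        "top_k": 50,          # within 1-200
        "max_tokens": 256,    # within 1-4096
    }
    with requests.post("http://localhost:8000/chat/completions",
                       json=payload, stream=True) as r:
        r.raise_for_status()
        for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)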