mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-01-30 22:32:06 +00:00
basic ui
This commit is contained in:
8
server/__init__.py
Normal file
8
server/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
FastAPI Backend Server
|
||||
======================
|
||||
|
||||
Web UI server for the Autonomous Coding Agent.
|
||||
Provides REST API and WebSocket endpoints for project management,
|
||||
feature tracking, and agent control.
|
||||
"""
|
||||
171
server/main.py
Normal file
171
server/main.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
FastAPI Main Application
|
||||
========================
|
||||
|
||||
Main entry point for the Autonomous Coding UI server.
|
||||
Provides REST API, WebSocket, and static file serving.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Request, WebSocket, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .routers import projects_router, features_router, agent_router
|
||||
from .websocket import project_websocket
|
||||
from .services.process_manager import cleanup_all_managers
|
||||
from .schemas import SetupStatus
|
||||
|
||||
|
||||
# Paths
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
UI_DIST_DIR = ROOT_DIR / "ui" / "dist"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown."""
|
||||
# Startup
|
||||
yield
|
||||
# Shutdown - cleanup all running agents
|
||||
await cleanup_all_managers()
|
||||
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Autonomous Coding UI",
|
||||
description="Web UI for the Autonomous Coding Agent",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS - allow only localhost origins for security
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:5173", # Vite dev server
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:8000", # Production
|
||||
"http://127.0.0.1:8000",
|
||||
],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Security Middleware
|
||||
# ============================================================================
|
||||
|
||||
@app.middleware("http")
|
||||
async def require_localhost(request: Request, call_next):
|
||||
"""Only allow requests from localhost."""
|
||||
client_host = request.client.host if request.client else None
|
||||
|
||||
# Allow localhost connections
|
||||
if client_host not in ("127.0.0.1", "::1", "localhost", None):
|
||||
raise HTTPException(status_code=403, detail="Localhost access only")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Include Routers
|
||||
# ============================================================================
|
||||
|
||||
app.include_router(projects_router)
|
||||
app.include_router(features_router)
|
||||
app.include_router(agent_router)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WebSocket Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@app.websocket("/ws/projects/{project_name}")
|
||||
async def websocket_endpoint(websocket: WebSocket, project_name: str):
|
||||
"""WebSocket endpoint for real-time project updates."""
|
||||
await project_websocket(websocket, project_name)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Setup & Health Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.get("/api/setup/status", response_model=SetupStatus)
|
||||
async def setup_status():
|
||||
"""Check system setup status."""
|
||||
# Check for Claude CLI
|
||||
claude_cli = shutil.which("claude") is not None
|
||||
|
||||
# Check for credentials file
|
||||
credentials_path = Path.home() / ".claude" / ".credentials.json"
|
||||
credentials = credentials_path.exists()
|
||||
|
||||
# Check for Node.js and npm
|
||||
node = shutil.which("node") is not None
|
||||
npm = shutil.which("npm") is not None
|
||||
|
||||
return SetupStatus(
|
||||
claude_cli=claude_cli,
|
||||
credentials=credentials,
|
||||
node=node,
|
||||
npm=npm,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Static File Serving (Production)
|
||||
# ============================================================================
|
||||
|
||||
# Serve React build files if they exist
|
||||
if UI_DIST_DIR.exists():
|
||||
# Mount static assets
|
||||
app.mount("/assets", StaticFiles(directory=UI_DIST_DIR / "assets"), name="assets")
|
||||
|
||||
@app.get("/")
|
||||
async def serve_index():
|
||||
"""Serve the React app index.html."""
|
||||
return FileResponse(UI_DIST_DIR / "index.html")
|
||||
|
||||
@app.get("/{path:path}")
|
||||
async def serve_spa(path: str):
|
||||
"""
|
||||
Serve static files or fall back to index.html for SPA routing.
|
||||
"""
|
||||
# Check if the path is an API route (shouldn't hit this due to router ordering)
|
||||
if path.startswith("api/") or path.startswith("ws/"):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
# Try to serve the file directly
|
||||
file_path = UI_DIST_DIR / path
|
||||
if file_path.exists() and file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
|
||||
# Fall back to index.html for SPA routing
|
||||
return FileResponse(UI_DIST_DIR / "index.html")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Entry Point
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"server.main:app",
|
||||
host="127.0.0.1", # Localhost only for security
|
||||
port=8000,
|
||||
reload=True,
|
||||
)
|
||||
12
server/routers/__init__.py
Normal file
12
server/routers/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
API Routers
|
||||
===========
|
||||
|
||||
FastAPI routers for different API endpoints.
|
||||
"""
|
||||
|
||||
from .projects import router as projects_router
|
||||
from .features import router as features_router
|
||||
from .agent import router as agent_router
|
||||
|
||||
__all__ = ["projects_router", "features_router", "agent_router"]
|
||||
128
server/routers/agent.py
Normal file
128
server/routers/agent.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Agent Router
|
||||
============
|
||||
|
||||
API endpoints for agent control (start/stop/pause/resume).
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from ..schemas import AgentStatus, AgentActionResponse
|
||||
from ..services.process_manager import get_manager
|
||||
|
||||
# Lazy import to avoid sys.path manipulation at module level
|
||||
_GENERATIONS_DIR = None
|
||||
|
||||
|
||||
def _get_generations_dir():
|
||||
"""Lazy import of GENERATIONS_DIR."""
|
||||
global _GENERATIONS_DIR
|
||||
if _GENERATIONS_DIR is None:
|
||||
import sys
|
||||
root = Path(__file__).parent.parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
from start import GENERATIONS_DIR
|
||||
_GENERATIONS_DIR = GENERATIONS_DIR
|
||||
return _GENERATIONS_DIR
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/projects/{project_name}/agent", tags=["agent"])
|
||||
|
||||
# Root directory for process manager
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> str:
|
||||
"""Validate and sanitize project name to prevent path traversal."""
|
||||
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid project name"
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def get_project_manager(project_name: str):
|
||||
"""Get the process manager for a project."""
|
||||
project_name = validate_project_name(project_name)
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{project_name}' not found")
|
||||
|
||||
return get_manager(project_name, ROOT_DIR)
|
||||
|
||||
|
||||
@router.get("/status", response_model=AgentStatus)
|
||||
async def get_agent_status(project_name: str):
|
||||
"""Get the current status of the agent for a project."""
|
||||
manager = get_project_manager(project_name)
|
||||
|
||||
# Run healthcheck to detect crashed processes
|
||||
await manager.healthcheck()
|
||||
|
||||
return AgentStatus(
|
||||
status=manager.status,
|
||||
pid=manager.pid,
|
||||
started_at=manager.started_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/start", response_model=AgentActionResponse)
|
||||
async def start_agent(project_name: str):
|
||||
"""Start the agent for a project."""
|
||||
manager = get_project_manager(project_name)
|
||||
|
||||
success, message = await manager.start()
|
||||
|
||||
return AgentActionResponse(
|
||||
success=success,
|
||||
status=manager.status,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stop", response_model=AgentActionResponse)
|
||||
async def stop_agent(project_name: str):
|
||||
"""Stop the agent for a project."""
|
||||
manager = get_project_manager(project_name)
|
||||
|
||||
success, message = await manager.stop()
|
||||
|
||||
return AgentActionResponse(
|
||||
success=success,
|
||||
status=manager.status,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/pause", response_model=AgentActionResponse)
|
||||
async def pause_agent(project_name: str):
|
||||
"""Pause the agent for a project."""
|
||||
manager = get_project_manager(project_name)
|
||||
|
||||
success, message = await manager.pause()
|
||||
|
||||
return AgentActionResponse(
|
||||
success=success,
|
||||
status=manager.status,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/resume", response_model=AgentActionResponse)
|
||||
async def resume_agent(project_name: str):
|
||||
"""Resume a paused agent."""
|
||||
manager = get_project_manager(project_name)
|
||||
|
||||
success, message = await manager.resume()
|
||||
|
||||
return AgentActionResponse(
|
||||
success=success,
|
||||
status=manager.status,
|
||||
message=message,
|
||||
)
|
||||
282
server/routers/features.py
Normal file
282
server/routers/features.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Features Router
|
||||
===============
|
||||
|
||||
API endpoints for feature/test case management.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from ..schemas import (
|
||||
FeatureCreate,
|
||||
FeatureResponse,
|
||||
FeatureListResponse,
|
||||
)
|
||||
|
||||
# Lazy imports to avoid circular dependencies
|
||||
_GENERATIONS_DIR = None
|
||||
_create_database = None
|
||||
_Feature = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_generations_dir():
|
||||
"""Lazy import of GENERATIONS_DIR."""
|
||||
global _GENERATIONS_DIR
|
||||
if _GENERATIONS_DIR is None:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
root = Path(__file__).parent.parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
from start import GENERATIONS_DIR
|
||||
_GENERATIONS_DIR = GENERATIONS_DIR
|
||||
return _GENERATIONS_DIR
|
||||
|
||||
|
||||
def _get_db_classes():
|
||||
"""Lazy import of database classes."""
|
||||
global _create_database, _Feature
|
||||
if _create_database is None:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
root = Path(__file__).parent.parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
from api.database import create_database, Feature
|
||||
_create_database = create_database
|
||||
_Feature = Feature
|
||||
return _create_database, _Feature
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"])
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> str:
|
||||
"""Validate and sanitize project name to prevent path traversal."""
|
||||
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid project name"
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session(project_dir: Path):
|
||||
"""
|
||||
Context manager for database sessions.
|
||||
Ensures session is always closed, even on exceptions.
|
||||
"""
|
||||
create_database, _ = _get_db_classes()
|
||||
_, SessionLocal = create_database(project_dir)
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def feature_to_response(f) -> FeatureResponse:
|
||||
"""Convert a Feature model to a FeatureResponse."""
|
||||
return FeatureResponse(
|
||||
id=f.id,
|
||||
priority=f.priority,
|
||||
category=f.category,
|
||||
name=f.name,
|
||||
description=f.description,
|
||||
steps=f.steps if isinstance(f.steps, list) else [],
|
||||
passes=f.passes,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=FeatureListResponse)
|
||||
async def list_features(project_name: str):
|
||||
"""
|
||||
List all features for a project organized by status.
|
||||
|
||||
Returns features in three lists:
|
||||
- pending: passes=False, not currently being worked on
|
||||
- in_progress: features currently being worked on (tracked via agent output)
|
||||
- done: passes=True
|
||||
"""
|
||||
project_name = validate_project_name(project_name)
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{project_name}' not found")
|
||||
|
||||
db_file = project_dir / "features.db"
|
||||
if not db_file.exists():
|
||||
return FeatureListResponse(pending=[], in_progress=[], done=[])
|
||||
|
||||
_, Feature = _get_db_classes()
|
||||
|
||||
try:
|
||||
with get_db_session(project_dir) as session:
|
||||
all_features = session.query(Feature).order_by(Feature.priority).all()
|
||||
|
||||
pending = []
|
||||
done = []
|
||||
|
||||
for f in all_features:
|
||||
feature_response = feature_to_response(f)
|
||||
if f.passes:
|
||||
done.append(feature_response)
|
||||
else:
|
||||
pending.append(feature_response)
|
||||
|
||||
return FeatureListResponse(
|
||||
pending=pending,
|
||||
in_progress=[],
|
||||
done=done,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Database error in list_features")
|
||||
raise HTTPException(status_code=500, detail="Database error occurred")
|
||||
|
||||
|
||||
@router.post("", response_model=FeatureResponse)
|
||||
async def create_feature(project_name: str, feature: FeatureCreate):
|
||||
"""Create a new feature/test case manually."""
|
||||
project_name = validate_project_name(project_name)
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{project_name}' not found")
|
||||
|
||||
_, Feature = _get_db_classes()
|
||||
|
||||
try:
|
||||
with get_db_session(project_dir) as session:
|
||||
# Get next priority if not specified
|
||||
if feature.priority is None:
|
||||
max_priority = session.query(Feature).order_by(Feature.priority.desc()).first()
|
||||
priority = (max_priority.priority + 1) if max_priority else 1
|
||||
else:
|
||||
priority = feature.priority
|
||||
|
||||
# Create new feature
|
||||
db_feature = Feature(
|
||||
priority=priority,
|
||||
category=feature.category,
|
||||
name=feature.name,
|
||||
description=feature.description,
|
||||
steps=feature.steps,
|
||||
passes=False,
|
||||
)
|
||||
|
||||
session.add(db_feature)
|
||||
session.commit()
|
||||
session.refresh(db_feature)
|
||||
|
||||
return feature_to_response(db_feature)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create feature")
|
||||
raise HTTPException(status_code=500, detail="Failed to create feature")
|
||||
|
||||
|
||||
@router.get("/{feature_id}", response_model=FeatureResponse)
|
||||
async def get_feature(project_name: str, feature_id: int):
|
||||
"""Get details of a specific feature."""
|
||||
project_name = validate_project_name(project_name)
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{project_name}' not found")
|
||||
|
||||
db_file = project_dir / "features.db"
|
||||
if not db_file.exists():
|
||||
raise HTTPException(status_code=404, detail="No features database found")
|
||||
|
||||
_, Feature = _get_db_classes()
|
||||
|
||||
try:
|
||||
with get_db_session(project_dir) as session:
|
||||
feature = session.query(Feature).filter(Feature.id == feature_id).first()
|
||||
|
||||
if not feature:
|
||||
raise HTTPException(status_code=404, detail=f"Feature {feature_id} not found")
|
||||
|
||||
return feature_to_response(feature)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Database error in get_feature")
|
||||
raise HTTPException(status_code=500, detail="Database error occurred")
|
||||
|
||||
|
||||
@router.delete("/{feature_id}")
|
||||
async def delete_feature(project_name: str, feature_id: int):
|
||||
"""Delete a feature."""
|
||||
project_name = validate_project_name(project_name)
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{project_name}' not found")
|
||||
|
||||
_, Feature = _get_db_classes()
|
||||
|
||||
try:
|
||||
with get_db_session(project_dir) as session:
|
||||
feature = session.query(Feature).filter(Feature.id == feature_id).first()
|
||||
|
||||
if not feature:
|
||||
raise HTTPException(status_code=404, detail=f"Feature {feature_id} not found")
|
||||
|
||||
session.delete(feature)
|
||||
session.commit()
|
||||
|
||||
return {"success": True, "message": f"Feature {feature_id} deleted"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to delete feature")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete feature")
|
||||
|
||||
|
||||
@router.patch("/{feature_id}/skip")
|
||||
async def skip_feature(project_name: str, feature_id: int):
|
||||
"""
|
||||
Mark a feature as skipped by moving it to the end of the priority queue.
|
||||
|
||||
This doesn't delete the feature but gives it a very high priority number
|
||||
so it will be processed last.
|
||||
"""
|
||||
project_name = validate_project_name(project_name)
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{project_name}' not found")
|
||||
|
||||
_, Feature = _get_db_classes()
|
||||
|
||||
try:
|
||||
with get_db_session(project_dir) as session:
|
||||
feature = session.query(Feature).filter(Feature.id == feature_id).first()
|
||||
|
||||
if not feature:
|
||||
raise HTTPException(status_code=404, detail=f"Feature {feature_id} not found")
|
||||
|
||||
# Set priority to max + 1000 to push to end
|
||||
max_priority = session.query(Feature).order_by(Feature.priority.desc()).first()
|
||||
feature.priority = (max_priority.priority if max_priority else 0) + 1000
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"success": True, "message": f"Feature {feature_id} moved to end of queue"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to skip feature")
|
||||
raise HTTPException(status_code=500, detail="Failed to skip feature")
|
||||
239
server/routers/projects.py
Normal file
239
server/routers/projects.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Projects Router
|
||||
===============
|
||||
|
||||
API endpoints for project management.
|
||||
"""
|
||||
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from ..schemas import (
|
||||
ProjectCreate,
|
||||
ProjectSummary,
|
||||
ProjectDetail,
|
||||
ProjectPrompts,
|
||||
ProjectPromptsUpdate,
|
||||
ProjectStats,
|
||||
)
|
||||
|
||||
# Lazy imports to avoid sys.path manipulation at module level
|
||||
_imports_initialized = False
|
||||
_GENERATIONS_DIR = None
|
||||
_get_existing_projects = None
|
||||
_check_spec_exists = None
|
||||
_scaffold_project_prompts = None
|
||||
_get_project_prompts_dir = None
|
||||
_count_passing_tests = None
|
||||
|
||||
|
||||
def _init_imports():
|
||||
"""Lazy import of project-level modules."""
|
||||
global _imports_initialized, _GENERATIONS_DIR, _get_existing_projects
|
||||
global _check_spec_exists, _scaffold_project_prompts, _get_project_prompts_dir
|
||||
global _count_passing_tests
|
||||
|
||||
if _imports_initialized:
|
||||
return
|
||||
|
||||
import sys
|
||||
root = Path(__file__).parent.parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
|
||||
from start import GENERATIONS_DIR, get_existing_projects, check_spec_exists
|
||||
from prompts import scaffold_project_prompts, get_project_prompts_dir
|
||||
from progress import count_passing_tests
|
||||
|
||||
_GENERATIONS_DIR = GENERATIONS_DIR
|
||||
_get_existing_projects = get_existing_projects
|
||||
_check_spec_exists = check_spec_exists
|
||||
_scaffold_project_prompts = scaffold_project_prompts
|
||||
_get_project_prompts_dir = get_project_prompts_dir
|
||||
_count_passing_tests = count_passing_tests
|
||||
_imports_initialized = True
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/projects", tags=["projects"])
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> str:
|
||||
"""Validate and sanitize project name to prevent path traversal."""
|
||||
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid project name. Use only letters, numbers, hyphens, and underscores (1-50 chars)."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def get_project_stats(project_dir: Path) -> ProjectStats:
|
||||
"""Get statistics for a project."""
|
||||
_init_imports()
|
||||
passing, total = _count_passing_tests(project_dir)
|
||||
percentage = (passing / total * 100) if total > 0 else 0.0
|
||||
return ProjectStats(passing=passing, total=total, percentage=round(percentage, 1))
|
||||
|
||||
|
||||
@router.get("", response_model=list[ProjectSummary])
|
||||
async def list_projects():
|
||||
"""List all projects in the generations directory."""
|
||||
_init_imports()
|
||||
projects = _get_existing_projects()
|
||||
result = []
|
||||
|
||||
for name in projects:
|
||||
project_dir = _GENERATIONS_DIR / name
|
||||
has_spec = _check_spec_exists(project_dir)
|
||||
stats = get_project_stats(project_dir)
|
||||
|
||||
result.append(ProjectSummary(
|
||||
name=name,
|
||||
has_spec=has_spec,
|
||||
stats=stats,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("", response_model=ProjectSummary)
|
||||
async def create_project(project: ProjectCreate):
|
||||
"""Create a new project with scaffolded prompts."""
|
||||
_init_imports()
|
||||
name = validate_project_name(project.name)
|
||||
|
||||
project_dir = _GENERATIONS_DIR / name
|
||||
|
||||
if project_dir.exists():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Project '{name}' already exists"
|
||||
)
|
||||
|
||||
# Create project directory
|
||||
project_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Scaffold prompts
|
||||
_scaffold_project_prompts(project_dir)
|
||||
|
||||
return ProjectSummary(
|
||||
name=name,
|
||||
has_spec=False, # Just created, no spec yet
|
||||
stats=ProjectStats(passing=0, total=0, percentage=0.0),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{name}", response_model=ProjectDetail)
|
||||
async def get_project(name: str):
|
||||
"""Get detailed information about a project."""
|
||||
_init_imports()
|
||||
name = validate_project_name(name)
|
||||
project_dir = _GENERATIONS_DIR / name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||
|
||||
has_spec = _check_spec_exists(project_dir)
|
||||
stats = get_project_stats(project_dir)
|
||||
prompts_dir = _get_project_prompts_dir(project_dir)
|
||||
|
||||
return ProjectDetail(
|
||||
name=name,
|
||||
has_spec=has_spec,
|
||||
stats=stats,
|
||||
prompts_dir=str(prompts_dir),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{name}")
|
||||
async def delete_project(name: str):
|
||||
"""Delete a project and all its files."""
|
||||
_init_imports()
|
||||
name = validate_project_name(name)
|
||||
project_dir = _GENERATIONS_DIR / name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||
|
||||
# Check if agent is running
|
||||
lock_file = project_dir / ".agent.lock"
|
||||
if lock_file.exists():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot delete project while agent is running. Stop the agent first."
|
||||
)
|
||||
|
||||
try:
|
||||
shutil.rmtree(project_dir)
|
||||
return {"success": True, "message": f"Project '{name}' deleted"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete project: {e}")
|
||||
|
||||
|
||||
@router.get("/{name}/prompts", response_model=ProjectPrompts)
|
||||
async def get_project_prompts(name: str):
|
||||
"""Get the content of project prompt files."""
|
||||
_init_imports()
|
||||
name = validate_project_name(name)
|
||||
project_dir = _GENERATIONS_DIR / name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||
|
||||
prompts_dir = _get_project_prompts_dir(project_dir)
|
||||
|
||||
def read_file(filename: str) -> str:
|
||||
filepath = prompts_dir / filename
|
||||
if filepath.exists():
|
||||
try:
|
||||
return filepath.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
return ProjectPrompts(
|
||||
app_spec=read_file("app_spec.txt"),
|
||||
initializer_prompt=read_file("initializer_prompt.md"),
|
||||
coding_prompt=read_file("coding_prompt.md"),
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{name}/prompts")
|
||||
async def update_project_prompts(name: str, prompts: ProjectPromptsUpdate):
|
||||
"""Update project prompt files."""
|
||||
_init_imports()
|
||||
name = validate_project_name(name)
|
||||
project_dir = _GENERATIONS_DIR / name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||
|
||||
prompts_dir = _get_project_prompts_dir(project_dir)
|
||||
prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def write_file(filename: str, content: str | None):
|
||||
if content is not None:
|
||||
filepath = prompts_dir / filename
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
write_file("app_spec.txt", prompts.app_spec)
|
||||
write_file("initializer_prompt.md", prompts.initializer_prompt)
|
||||
write_file("coding_prompt.md", prompts.coding_prompt)
|
||||
|
||||
return {"success": True, "message": "Prompts updated"}
|
||||
|
||||
|
||||
@router.get("/{name}/stats", response_model=ProjectStats)
|
||||
async def get_project_stats_endpoint(name: str):
|
||||
"""Get current progress statistics for a project."""
|
||||
_init_imports()
|
||||
name = validate_project_name(name)
|
||||
project_dir = _GENERATIONS_DIR / name
|
||||
|
||||
if not project_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||
|
||||
return get_project_stats(project_dir)
|
||||
152
server/schemas.py
Normal file
152
server/schemas.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Pydantic Schemas
|
||||
================
|
||||
|
||||
Request/Response models for the API endpoints.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Project Schemas
|
||||
# ============================================================================
|
||||
|
||||
class ProjectCreate(BaseModel):
|
||||
"""Request schema for creating a new project."""
|
||||
name: str = Field(..., min_length=1, max_length=50, pattern=r'^[a-zA-Z0-9_-]+$')
|
||||
spec_method: Literal["claude", "manual"] = "claude"
|
||||
|
||||
|
||||
class ProjectStats(BaseModel):
|
||||
"""Project statistics."""
|
||||
passing: int = 0
|
||||
total: int = 0
|
||||
percentage: float = 0.0
|
||||
|
||||
|
||||
class ProjectSummary(BaseModel):
|
||||
"""Summary of a project for list view."""
|
||||
name: str
|
||||
has_spec: bool
|
||||
stats: ProjectStats
|
||||
|
||||
|
||||
class ProjectDetail(BaseModel):
|
||||
"""Detailed project information."""
|
||||
name: str
|
||||
has_spec: bool
|
||||
stats: ProjectStats
|
||||
prompts_dir: str
|
||||
|
||||
|
||||
class ProjectPrompts(BaseModel):
|
||||
"""Project prompt files content."""
|
||||
app_spec: str = ""
|
||||
initializer_prompt: str = ""
|
||||
coding_prompt: str = ""
|
||||
|
||||
|
||||
class ProjectPromptsUpdate(BaseModel):
|
||||
"""Request schema for updating project prompts."""
|
||||
app_spec: str | None = None
|
||||
initializer_prompt: str | None = None
|
||||
coding_prompt: str | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Feature Schemas
|
||||
# ============================================================================
|
||||
|
||||
class FeatureBase(BaseModel):
|
||||
"""Base feature attributes."""
|
||||
category: str
|
||||
name: str
|
||||
description: str
|
||||
steps: list[str]
|
||||
|
||||
|
||||
class FeatureCreate(FeatureBase):
|
||||
"""Request schema for creating a new feature."""
|
||||
priority: int | None = None
|
||||
|
||||
|
||||
class FeatureResponse(FeatureBase):
|
||||
"""Response schema for a feature."""
|
||||
id: int
|
||||
priority: int
|
||||
passes: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class FeatureListResponse(BaseModel):
|
||||
"""Response containing list of features organized by status."""
|
||||
pending: list[FeatureResponse]
|
||||
in_progress: list[FeatureResponse]
|
||||
done: list[FeatureResponse]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Agent Schemas
|
||||
# ============================================================================
|
||||
|
||||
class AgentStatus(BaseModel):
|
||||
"""Current agent status."""
|
||||
status: Literal["stopped", "running", "paused", "crashed"]
|
||||
pid: int | None = None
|
||||
started_at: datetime | None = None
|
||||
|
||||
|
||||
class AgentActionResponse(BaseModel):
|
||||
"""Response for agent control actions."""
|
||||
success: bool
|
||||
status: str
|
||||
message: str = ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Setup Schemas
|
||||
# ============================================================================
|
||||
|
||||
class SetupStatus(BaseModel):
|
||||
"""System setup status."""
|
||||
claude_cli: bool
|
||||
credentials: bool
|
||||
node: bool
|
||||
npm: bool
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WebSocket Message Schemas
|
||||
# ============================================================================
|
||||
|
||||
class WSProgressMessage(BaseModel):
|
||||
"""WebSocket message for progress updates."""
|
||||
type: Literal["progress"] = "progress"
|
||||
passing: int
|
||||
total: int
|
||||
percentage: float
|
||||
|
||||
|
||||
class WSFeatureUpdateMessage(BaseModel):
|
||||
"""WebSocket message for feature status updates."""
|
||||
type: Literal["feature_update"] = "feature_update"
|
||||
feature_id: int
|
||||
passes: bool
|
||||
|
||||
|
||||
class WSLogMessage(BaseModel):
|
||||
"""WebSocket message for agent log output."""
|
||||
type: Literal["log"] = "log"
|
||||
line: str
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class WSAgentStatusMessage(BaseModel):
|
||||
"""WebSocket message for agent status changes."""
|
||||
type: Literal["agent_status"] = "agent_status"
|
||||
status: str
|
||||
10
server/services/__init__.py
Normal file
10
server/services/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Backend Services
|
||||
================
|
||||
|
||||
Business logic and process management services.
|
||||
"""
|
||||
|
||||
from .process_manager import AgentProcessManager
|
||||
|
||||
__all__ = ["AgentProcessManager"]
|
||||
403
server/services/process_manager.py
Normal file
403
server/services/process_manager.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
Agent Process Manager
|
||||
=====================
|
||||
|
||||
Manages the lifecycle of agent subprocesses per project.
|
||||
Provides start/stop/pause/resume functionality with cross-platform support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Literal, Callable, Awaitable, Set
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns for sensitive data that should be redacted from output
|
||||
SENSITIVE_PATTERNS = [
|
||||
r'sk-[a-zA-Z0-9]{20,}', # Anthropic API keys
|
||||
r'ANTHROPIC_API_KEY=[^\s]+',
|
||||
r'api[_-]?key[=:][^\s]+',
|
||||
r'token[=:][^\s]+',
|
||||
r'password[=:][^\s]+',
|
||||
r'secret[=:][^\s]+',
|
||||
r'ghp_[a-zA-Z0-9]{36,}', # GitHub personal access tokens
|
||||
r'gho_[a-zA-Z0-9]{36,}', # GitHub OAuth tokens
|
||||
r'ghs_[a-zA-Z0-9]{36,}', # GitHub server tokens
|
||||
r'ghr_[a-zA-Z0-9]{36,}', # GitHub refresh tokens
|
||||
r'aws[_-]?access[_-]?key[=:][^\s]+', # AWS keys
|
||||
r'aws[_-]?secret[=:][^\s]+',
|
||||
]
|
||||
|
||||
|
||||
def sanitize_output(line: str) -> str:
|
||||
"""Remove sensitive information from output lines."""
|
||||
for pattern in SENSITIVE_PATTERNS:
|
||||
line = re.sub(pattern, '[REDACTED]', line, flags=re.IGNORECASE)
|
||||
return line
|
||||
|
||||
|
||||
class AgentProcessManager:
|
||||
"""
|
||||
Manages agent subprocess lifecycle for a single project.
|
||||
|
||||
Provides start/stop/pause/resume with cross-platform support via psutil.
|
||||
Supports multiple output callbacks for WebSocket clients.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project_name: str,
|
||||
root_dir: Path,
|
||||
):
|
||||
"""
|
||||
Initialize the process manager.
|
||||
|
||||
Args:
|
||||
project_name: Name of the project
|
||||
root_dir: Root directory of the autonomous-coding-ui project
|
||||
"""
|
||||
self.project_name = project_name
|
||||
self.root_dir = root_dir
|
||||
self.process: subprocess.Popen | None = None
|
||||
self._status: Literal["stopped", "running", "paused", "crashed"] = "stopped"
|
||||
self.started_at: datetime | None = None
|
||||
self._output_task: asyncio.Task | None = None
|
||||
|
||||
# Support multiple callbacks (for multiple WebSocket clients)
|
||||
self._output_callbacks: Set[Callable[[str], Awaitable[None]]] = set()
|
||||
self._status_callbacks: Set[Callable[[str], Awaitable[None]]] = set()
|
||||
self._callbacks_lock = threading.Lock()
|
||||
|
||||
# Lock file to prevent multiple instances
|
||||
self.lock_file = self.root_dir / "generations" / project_name / ".agent.lock"
|
||||
|
||||
@property
|
||||
def status(self) -> Literal["stopped", "running", "paused", "crashed"]:
|
||||
return self._status
|
||||
|
||||
@status.setter
|
||||
def status(self, value: Literal["stopped", "running", "paused", "crashed"]):
|
||||
old_status = self._status
|
||||
self._status = value
|
||||
if old_status != value:
|
||||
self._notify_status_change(value)
|
||||
|
||||
def _notify_status_change(self, status: str) -> None:
|
||||
"""Notify all registered callbacks of status change."""
|
||||
with self._callbacks_lock:
|
||||
callbacks = list(self._status_callbacks)
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
# Schedule the callback in the event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._safe_callback(callback, status))
|
||||
except RuntimeError:
|
||||
# No running event loop
|
||||
pass
|
||||
|
||||
async def _safe_callback(self, callback: Callable, *args) -> None:
|
||||
"""Safely execute a callback, catching and logging any errors."""
|
||||
try:
|
||||
await callback(*args)
|
||||
except Exception as e:
|
||||
logger.warning(f"Callback error: {e}")
|
||||
|
||||
def add_output_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
|
||||
"""Add a callback for output lines."""
|
||||
with self._callbacks_lock:
|
||||
self._output_callbacks.add(callback)
|
||||
|
||||
def remove_output_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
|
||||
"""Remove an output callback."""
|
||||
with self._callbacks_lock:
|
||||
self._output_callbacks.discard(callback)
|
||||
|
||||
def add_status_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
|
||||
"""Add a callback for status changes."""
|
||||
with self._callbacks_lock:
|
||||
self._status_callbacks.add(callback)
|
||||
|
||||
def remove_status_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
|
||||
"""Remove a status callback."""
|
||||
with self._callbacks_lock:
|
||||
self._status_callbacks.discard(callback)
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
return self.process.pid if self.process else None
|
||||
|
||||
def _check_lock(self) -> bool:
|
||||
"""Check if another agent is already running for this project."""
|
||||
if not self.lock_file.exists():
|
||||
return True
|
||||
|
||||
try:
|
||||
pid = int(self.lock_file.read_text().strip())
|
||||
if psutil.pid_exists(pid):
|
||||
# Check if it's actually our agent process
|
||||
try:
|
||||
proc = psutil.Process(pid)
|
||||
cmdline = " ".join(proc.cmdline())
|
||||
if "autonomous_agent_demo.py" in cmdline:
|
||||
return False # Another agent is running
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
# Stale lock file
|
||||
self.lock_file.unlink(missing_ok=True)
|
||||
return True
|
||||
except (ValueError, OSError):
|
||||
self.lock_file.unlink(missing_ok=True)
|
||||
return True
|
||||
|
||||
def _create_lock(self) -> None:
|
||||
"""Create lock file with current process PID."""
|
||||
self.lock_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
if self.process:
|
||||
self.lock_file.write_text(str(self.process.pid))
|
||||
|
||||
def _remove_lock(self) -> None:
|
||||
"""Remove lock file."""
|
||||
self.lock_file.unlink(missing_ok=True)
|
||||
|
||||
async def _broadcast_output(self, line: str) -> None:
|
||||
"""Broadcast output line to all registered callbacks."""
|
||||
with self._callbacks_lock:
|
||||
callbacks = list(self._output_callbacks)
|
||||
|
||||
for callback in callbacks:
|
||||
await self._safe_callback(callback, line)
|
||||
|
||||
async def _stream_output(self) -> None:
|
||||
"""Stream process output to callbacks."""
|
||||
if not self.process or not self.process.stdout:
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
while True:
|
||||
# Use run_in_executor for blocking readline
|
||||
line = await loop.run_in_executor(
|
||||
None, self.process.stdout.readline
|
||||
)
|
||||
if not line:
|
||||
break
|
||||
|
||||
decoded = line.decode("utf-8", errors="replace").rstrip()
|
||||
sanitized = sanitize_output(decoded)
|
||||
|
||||
await self._broadcast_output(sanitized)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Output streaming error: {e}")
|
||||
finally:
|
||||
# Check if process ended
|
||||
if self.process and self.process.poll() is not None:
|
||||
exit_code = self.process.returncode
|
||||
if exit_code != 0 and self.status == "running":
|
||||
self.status = "crashed"
|
||||
elif self.status == "running":
|
||||
self.status = "stopped"
|
||||
self._remove_lock()
|
||||
|
||||
async def start(self) -> tuple[bool, str]:
|
||||
"""
|
||||
Start the agent as a subprocess.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message)
|
||||
"""
|
||||
if self.status in ("running", "paused"):
|
||||
return False, f"Agent is already {self.status}"
|
||||
|
||||
if not self._check_lock():
|
||||
return False, "Another agent instance is already running for this project"
|
||||
|
||||
# Build command
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(self.root_dir / "autonomous_agent_demo.py"),
|
||||
"--project-dir",
|
||||
self.project_name,
|
||||
]
|
||||
|
||||
try:
|
||||
# Start subprocess with piped stdout/stderr
|
||||
self.process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=str(self.root_dir),
|
||||
)
|
||||
|
||||
self._create_lock()
|
||||
self.started_at = datetime.now()
|
||||
self.status = "running"
|
||||
|
||||
# Start output streaming task
|
||||
self._output_task = asyncio.create_task(self._stream_output())
|
||||
|
||||
return True, f"Agent started with PID {self.process.pid}"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to start agent")
|
||||
return False, f"Failed to start agent: {e}"
|
||||
|
||||
async def stop(self) -> tuple[bool, str]:
|
||||
"""
|
||||
Stop the agent (SIGTERM then SIGKILL if needed).
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message)
|
||||
"""
|
||||
if not self.process or self.status == "stopped":
|
||||
return False, "Agent is not running"
|
||||
|
||||
try:
|
||||
# Cancel output streaming
|
||||
if self._output_task:
|
||||
self._output_task.cancel()
|
||||
try:
|
||||
await self._output_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Terminate gracefully first
|
||||
self.process.terminate()
|
||||
|
||||
# Wait up to 5 seconds for graceful shutdown
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
loop.run_in_executor(None, self.process.wait),
|
||||
timeout=5.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Force kill if still running
|
||||
self.process.kill()
|
||||
await loop.run_in_executor(None, self.process.wait)
|
||||
|
||||
self._remove_lock()
|
||||
self.status = "stopped"
|
||||
self.process = None
|
||||
self.started_at = None
|
||||
|
||||
return True, "Agent stopped"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to stop agent")
|
||||
return False, f"Failed to stop agent: {e}"
|
||||
|
||||
async def pause(self) -> tuple[bool, str]:
|
||||
"""
|
||||
Pause the agent using psutil for cross-platform support.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message)
|
||||
"""
|
||||
if not self.process or self.status != "running":
|
||||
return False, "Agent is not running"
|
||||
|
||||
try:
|
||||
proc = psutil.Process(self.process.pid)
|
||||
proc.suspend()
|
||||
self.status = "paused"
|
||||
return True, "Agent paused"
|
||||
except psutil.NoSuchProcess:
|
||||
self.status = "crashed"
|
||||
self._remove_lock()
|
||||
return False, "Agent process no longer exists"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to pause agent")
|
||||
return False, f"Failed to pause agent: {e}"
|
||||
|
||||
async def resume(self) -> tuple[bool, str]:
|
||||
"""
|
||||
Resume a paused agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message)
|
||||
"""
|
||||
if not self.process or self.status != "paused":
|
||||
return False, "Agent is not paused"
|
||||
|
||||
try:
|
||||
proc = psutil.Process(self.process.pid)
|
||||
proc.resume()
|
||||
self.status = "running"
|
||||
return True, "Agent resumed"
|
||||
except psutil.NoSuchProcess:
|
||||
self.status = "crashed"
|
||||
self._remove_lock()
|
||||
return False, "Agent process no longer exists"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to resume agent")
|
||||
return False, f"Failed to resume agent: {e}"
|
||||
|
||||
async def healthcheck(self) -> bool:
|
||||
"""
|
||||
Check if the agent process is still alive.
|
||||
|
||||
Updates status to 'crashed' if process has died unexpectedly.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not self.process:
|
||||
return self.status == "stopped"
|
||||
|
||||
poll = self.process.poll()
|
||||
if poll is not None:
|
||||
# Process has terminated
|
||||
if self.status in ("running", "paused"):
|
||||
self.status = "crashed"
|
||||
self._remove_lock()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_status_dict(self) -> dict:
|
||||
"""Get current status as a dictionary."""
|
||||
return {
|
||||
"status": self.status,
|
||||
"pid": self.pid,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
}
|
||||
|
||||
|
||||
# Global registry of process managers per project with thread safety
|
||||
_managers: dict[str, AgentProcessManager] = {}
|
||||
_managers_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_manager(project_name: str, root_dir: Path) -> AgentProcessManager:
|
||||
"""Get or create a process manager for a project (thread-safe)."""
|
||||
with _managers_lock:
|
||||
if project_name not in _managers:
|
||||
_managers[project_name] = AgentProcessManager(project_name, root_dir)
|
||||
return _managers[project_name]
|
||||
|
||||
|
||||
async def cleanup_all_managers() -> None:
|
||||
"""Stop all running agents. Called on server shutdown."""
|
||||
with _managers_lock:
|
||||
managers = list(_managers.values())
|
||||
|
||||
for manager in managers:
|
||||
try:
|
||||
if manager.status != "stopped":
|
||||
await manager.stop()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping manager for {manager.project_name}: {e}")
|
||||
|
||||
with _managers_lock:
|
||||
_managers.clear()
|
||||
248
server/websocket.py
Normal file
248
server/websocket.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
WebSocket Handlers
|
||||
==================
|
||||
|
||||
Real-time updates for project progress and agent output.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from .services.process_manager import get_manager
|
||||
|
||||
# Lazy imports
|
||||
_GENERATIONS_DIR = None
|
||||
_count_passing_tests = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_generations_dir():
|
||||
"""Lazy import of GENERATIONS_DIR."""
|
||||
global _GENERATIONS_DIR
|
||||
if _GENERATIONS_DIR is None:
|
||||
import sys
|
||||
root = Path(__file__).parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
from start import GENERATIONS_DIR
|
||||
_GENERATIONS_DIR = GENERATIONS_DIR
|
||||
return _GENERATIONS_DIR
|
||||
|
||||
|
||||
def _get_count_passing_tests():
|
||||
"""Lazy import of count_passing_tests."""
|
||||
global _count_passing_tests
|
||||
if _count_passing_tests is None:
|
||||
import sys
|
||||
root = Path(__file__).parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
from progress import count_passing_tests
|
||||
_count_passing_tests = count_passing_tests
|
||||
return _count_passing_tests
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections per project."""
|
||||
|
||||
def __init__(self):
|
||||
# project_name -> set of WebSocket connections
|
||||
self.active_connections: dict[str, Set[WebSocket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket, project_name: str):
|
||||
"""Accept a WebSocket connection for a project."""
|
||||
await websocket.accept()
|
||||
|
||||
async with self._lock:
|
||||
if project_name not in self.active_connections:
|
||||
self.active_connections[project_name] = set()
|
||||
self.active_connections[project_name].add(websocket)
|
||||
|
||||
async def disconnect(self, websocket: WebSocket, project_name: str):
|
||||
"""Remove a WebSocket connection."""
|
||||
async with self._lock:
|
||||
if project_name in self.active_connections:
|
||||
self.active_connections[project_name].discard(websocket)
|
||||
if not self.active_connections[project_name]:
|
||||
del self.active_connections[project_name]
|
||||
|
||||
async def broadcast_to_project(self, project_name: str, message: dict):
|
||||
"""Broadcast a message to all connections for a project."""
|
||||
async with self._lock:
|
||||
connections = list(self.active_connections.get(project_name, set()))
|
||||
|
||||
dead_connections = []
|
||||
|
||||
for connection in connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
dead_connections.append(connection)
|
||||
|
||||
# Clean up dead connections
|
||||
if dead_connections:
|
||||
async with self._lock:
|
||||
for connection in dead_connections:
|
||||
if project_name in self.active_connections:
|
||||
self.active_connections[project_name].discard(connection)
|
||||
|
||||
def get_connection_count(self, project_name: str) -> int:
|
||||
"""Get number of active connections for a project."""
|
||||
return len(self.active_connections.get(project_name, set()))
|
||||
|
||||
|
||||
# Global connection manager
|
||||
manager = ConnectionManager()
|
||||
|
||||
# Root directory
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> bool:
|
||||
"""Validate project name to prevent path traversal."""
|
||||
return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name))
|
||||
|
||||
|
||||
async def poll_progress(websocket: WebSocket, project_name: str):
|
||||
"""Poll database for progress changes and send updates."""
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
count_passing_tests = _get_count_passing_tests()
|
||||
last_passing = -1
|
||||
last_total = -1
|
||||
|
||||
while True:
|
||||
try:
|
||||
passing, total = count_passing_tests(project_dir)
|
||||
|
||||
# Only send if changed
|
||||
if passing != last_passing or total != last_total:
|
||||
last_passing = passing
|
||||
last_total = total
|
||||
percentage = (passing / total * 100) if total > 0 else 0
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "progress",
|
||||
"passing": passing,
|
||||
"total": total,
|
||||
"percentage": round(percentage, 1),
|
||||
})
|
||||
|
||||
await asyncio.sleep(2) # Poll every 2 seconds
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Progress polling error: {e}")
|
||||
break
|
||||
|
||||
|
||||
async def project_websocket(websocket: WebSocket, project_name: str):
|
||||
"""
|
||||
WebSocket endpoint for project updates.
|
||||
|
||||
Streams:
|
||||
- Progress updates (passing/total counts)
|
||||
- Agent status changes
|
||||
- Agent stdout/stderr lines
|
||||
"""
|
||||
if not validate_project_name(project_name):
|
||||
await websocket.close(code=4000, reason="Invalid project name")
|
||||
return
|
||||
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
if not project_dir.exists():
|
||||
await websocket.close(code=4004, reason="Project not found")
|
||||
return
|
||||
|
||||
await manager.connect(websocket, project_name)
|
||||
|
||||
# Get agent manager and register callbacks
|
||||
agent_manager = get_manager(project_name, ROOT_DIR)
|
||||
|
||||
async def on_output(line: str):
|
||||
"""Handle agent output - broadcast to this WebSocket."""
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
"line": line,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
except Exception:
|
||||
pass # Connection may be closed
|
||||
|
||||
async def on_status_change(status: str):
|
||||
"""Handle status change - broadcast to this WebSocket."""
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "agent_status",
|
||||
"status": status,
|
||||
})
|
||||
except Exception:
|
||||
pass # Connection may be closed
|
||||
|
||||
# Register callbacks
|
||||
agent_manager.add_output_callback(on_output)
|
||||
agent_manager.add_status_callback(on_status_change)
|
||||
|
||||
# Start progress polling task
|
||||
poll_task = asyncio.create_task(poll_progress(websocket, project_name))
|
||||
|
||||
try:
|
||||
# Send initial status
|
||||
await websocket.send_json({
|
||||
"type": "agent_status",
|
||||
"status": agent_manager.status,
|
||||
})
|
||||
|
||||
# Send initial progress
|
||||
count_passing_tests = _get_count_passing_tests()
|
||||
passing, total = count_passing_tests(project_dir)
|
||||
percentage = (passing / total * 100) if total > 0 else 0
|
||||
await websocket.send_json({
|
||||
"type": "progress",
|
||||
"passing": passing,
|
||||
"total": total,
|
||||
"percentage": round(percentage, 1),
|
||||
})
|
||||
|
||||
# Keep connection alive and handle incoming messages
|
||||
while True:
|
||||
try:
|
||||
# Wait for any incoming messages (ping/pong, commands, etc.)
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
# Handle ping
|
||||
if message.get("type") == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON from WebSocket: {data[:100] if data else 'empty'}")
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket error: {e}")
|
||||
break
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
poll_task.cancel()
|
||||
try:
|
||||
await poll_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Unregister callbacks
|
||||
agent_manager.remove_output_callback(on_output)
|
||||
agent_manager.remove_status_callback(on_status_change)
|
||||
|
||||
# Disconnect from manager
|
||||
await manager.disconnect(websocket, project_name)
|
||||
Reference in New Issue
Block a user