mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-01-30 06:12:06 +00:00
basic ui
This commit is contained in:
248
server/websocket.py
Normal file
248
server/websocket.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
WebSocket Handlers
|
||||
==================
|
||||
|
||||
Real-time updates for project progress and agent output.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from .services.process_manager import get_manager
|
||||
|
||||
# Lazy imports
|
||||
_GENERATIONS_DIR = None
|
||||
_count_passing_tests = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_generations_dir():
|
||||
"""Lazy import of GENERATIONS_DIR."""
|
||||
global _GENERATIONS_DIR
|
||||
if _GENERATIONS_DIR is None:
|
||||
import sys
|
||||
root = Path(__file__).parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
from start import GENERATIONS_DIR
|
||||
_GENERATIONS_DIR = GENERATIONS_DIR
|
||||
return _GENERATIONS_DIR
|
||||
|
||||
|
||||
def _get_count_passing_tests():
|
||||
"""Lazy import of count_passing_tests."""
|
||||
global _count_passing_tests
|
||||
if _count_passing_tests is None:
|
||||
import sys
|
||||
root = Path(__file__).parent.parent
|
||||
if str(root) not in sys.path:
|
||||
sys.path.insert(0, str(root))
|
||||
from progress import count_passing_tests
|
||||
_count_passing_tests = count_passing_tests
|
||||
return _count_passing_tests
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections per project."""
|
||||
|
||||
def __init__(self):
|
||||
# project_name -> set of WebSocket connections
|
||||
self.active_connections: dict[str, Set[WebSocket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket, project_name: str):
|
||||
"""Accept a WebSocket connection for a project."""
|
||||
await websocket.accept()
|
||||
|
||||
async with self._lock:
|
||||
if project_name not in self.active_connections:
|
||||
self.active_connections[project_name] = set()
|
||||
self.active_connections[project_name].add(websocket)
|
||||
|
||||
async def disconnect(self, websocket: WebSocket, project_name: str):
|
||||
"""Remove a WebSocket connection."""
|
||||
async with self._lock:
|
||||
if project_name in self.active_connections:
|
||||
self.active_connections[project_name].discard(websocket)
|
||||
if not self.active_connections[project_name]:
|
||||
del self.active_connections[project_name]
|
||||
|
||||
async def broadcast_to_project(self, project_name: str, message: dict):
|
||||
"""Broadcast a message to all connections for a project."""
|
||||
async with self._lock:
|
||||
connections = list(self.active_connections.get(project_name, set()))
|
||||
|
||||
dead_connections = []
|
||||
|
||||
for connection in connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
dead_connections.append(connection)
|
||||
|
||||
# Clean up dead connections
|
||||
if dead_connections:
|
||||
async with self._lock:
|
||||
for connection in dead_connections:
|
||||
if project_name in self.active_connections:
|
||||
self.active_connections[project_name].discard(connection)
|
||||
|
||||
def get_connection_count(self, project_name: str) -> int:
|
||||
"""Get number of active connections for a project."""
|
||||
return len(self.active_connections.get(project_name, set()))
|
||||
|
||||
|
||||
# Global connection manager
|
||||
manager = ConnectionManager()
|
||||
|
||||
# Root directory
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> bool:
|
||||
"""Validate project name to prevent path traversal."""
|
||||
return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name))
|
||||
|
||||
|
||||
async def poll_progress(websocket: WebSocket, project_name: str):
|
||||
"""Poll database for progress changes and send updates."""
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
count_passing_tests = _get_count_passing_tests()
|
||||
last_passing = -1
|
||||
last_total = -1
|
||||
|
||||
while True:
|
||||
try:
|
||||
passing, total = count_passing_tests(project_dir)
|
||||
|
||||
# Only send if changed
|
||||
if passing != last_passing or total != last_total:
|
||||
last_passing = passing
|
||||
last_total = total
|
||||
percentage = (passing / total * 100) if total > 0 else 0
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "progress",
|
||||
"passing": passing,
|
||||
"total": total,
|
||||
"percentage": round(percentage, 1),
|
||||
})
|
||||
|
||||
await asyncio.sleep(2) # Poll every 2 seconds
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Progress polling error: {e}")
|
||||
break
|
||||
|
||||
|
||||
async def project_websocket(websocket: WebSocket, project_name: str):
|
||||
"""
|
||||
WebSocket endpoint for project updates.
|
||||
|
||||
Streams:
|
||||
- Progress updates (passing/total counts)
|
||||
- Agent status changes
|
||||
- Agent stdout/stderr lines
|
||||
"""
|
||||
if not validate_project_name(project_name):
|
||||
await websocket.close(code=4000, reason="Invalid project name")
|
||||
return
|
||||
|
||||
project_dir = _get_generations_dir() / project_name
|
||||
if not project_dir.exists():
|
||||
await websocket.close(code=4004, reason="Project not found")
|
||||
return
|
||||
|
||||
await manager.connect(websocket, project_name)
|
||||
|
||||
# Get agent manager and register callbacks
|
||||
agent_manager = get_manager(project_name, ROOT_DIR)
|
||||
|
||||
async def on_output(line: str):
|
||||
"""Handle agent output - broadcast to this WebSocket."""
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
"line": line,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
except Exception:
|
||||
pass # Connection may be closed
|
||||
|
||||
async def on_status_change(status: str):
|
||||
"""Handle status change - broadcast to this WebSocket."""
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "agent_status",
|
||||
"status": status,
|
||||
})
|
||||
except Exception:
|
||||
pass # Connection may be closed
|
||||
|
||||
# Register callbacks
|
||||
agent_manager.add_output_callback(on_output)
|
||||
agent_manager.add_status_callback(on_status_change)
|
||||
|
||||
# Start progress polling task
|
||||
poll_task = asyncio.create_task(poll_progress(websocket, project_name))
|
||||
|
||||
try:
|
||||
# Send initial status
|
||||
await websocket.send_json({
|
||||
"type": "agent_status",
|
||||
"status": agent_manager.status,
|
||||
})
|
||||
|
||||
# Send initial progress
|
||||
count_passing_tests = _get_count_passing_tests()
|
||||
passing, total = count_passing_tests(project_dir)
|
||||
percentage = (passing / total * 100) if total > 0 else 0
|
||||
await websocket.send_json({
|
||||
"type": "progress",
|
||||
"passing": passing,
|
||||
"total": total,
|
||||
"percentage": round(percentage, 1),
|
||||
})
|
||||
|
||||
# Keep connection alive and handle incoming messages
|
||||
while True:
|
||||
try:
|
||||
# Wait for any incoming messages (ping/pong, commands, etc.)
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
# Handle ping
|
||||
if message.get("type") == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON from WebSocket: {data[:100] if data else 'empty'}")
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket error: {e}")
|
||||
break
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
poll_task.cancel()
|
||||
try:
|
||||
await poll_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Unregister callbacks
|
||||
agent_manager.remove_output_callback(on_output)
|
||||
agent_manager.remove_status_callback(on_status_change)
|
||||
|
||||
# Disconnect from manager
|
||||
await manager.disconnect(websocket, project_name)
|
||||
Reference in New Issue
Block a user