""" WebSocket Handlers ================== Real-time updates for project progress and agent output. """ import asyncio import json import logging import re from datetime import datetime from pathlib import Path from typing import Set from fastapi import WebSocket, WebSocketDisconnect from .services.process_manager import get_manager # Lazy imports _GENERATIONS_DIR = None _count_passing_tests = None logger = logging.getLogger(__name__) def _get_generations_dir(): """Lazy import of GENERATIONS_DIR.""" global _GENERATIONS_DIR if _GENERATIONS_DIR is None: import sys root = Path(__file__).parent.parent if str(root) not in sys.path: sys.path.insert(0, str(root)) from start import GENERATIONS_DIR _GENERATIONS_DIR = GENERATIONS_DIR return _GENERATIONS_DIR def _get_count_passing_tests(): """Lazy import of count_passing_tests.""" global _count_passing_tests if _count_passing_tests is None: import sys root = Path(__file__).parent.parent if str(root) not in sys.path: sys.path.insert(0, str(root)) from progress import count_passing_tests _count_passing_tests = count_passing_tests return _count_passing_tests class ConnectionManager: """Manages WebSocket connections per project.""" def __init__(self): # project_name -> set of WebSocket connections self.active_connections: dict[str, Set[WebSocket]] = {} self._lock = asyncio.Lock() async def connect(self, websocket: WebSocket, project_name: str): """Accept a WebSocket connection for a project.""" await websocket.accept() async with self._lock: if project_name not in self.active_connections: self.active_connections[project_name] = set() self.active_connections[project_name].add(websocket) async def disconnect(self, websocket: WebSocket, project_name: str): """Remove a WebSocket connection.""" async with self._lock: if project_name in self.active_connections: self.active_connections[project_name].discard(websocket) if not self.active_connections[project_name]: del self.active_connections[project_name] async def broadcast_to_project(self, project_name: str, message: dict): """Broadcast a message to all connections for a project.""" async with self._lock: connections = list(self.active_connections.get(project_name, set())) dead_connections = [] for connection in connections: try: await connection.send_json(message) except Exception: dead_connections.append(connection) # Clean up dead connections if dead_connections: async with self._lock: for connection in dead_connections: if project_name in self.active_connections: self.active_connections[project_name].discard(connection) def get_connection_count(self, project_name: str) -> int: """Get number of active connections for a project.""" return len(self.active_connections.get(project_name, set())) # Global connection manager manager = ConnectionManager() # Root directory ROOT_DIR = Path(__file__).parent.parent def validate_project_name(name: str) -> bool: """Validate project name to prevent path traversal.""" return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name)) async def poll_progress(websocket: WebSocket, project_name: str): """Poll database for progress changes and send updates.""" project_dir = _get_generations_dir() / project_name count_passing_tests = _get_count_passing_tests() last_passing = -1 last_total = -1 while True: try: passing, total = count_passing_tests(project_dir) # Only send if changed if passing != last_passing or total != last_total: last_passing = passing last_total = total percentage = (passing / total * 100) if total > 0 else 0 await websocket.send_json({ "type": "progress", "passing": passing, "total": total, "percentage": round(percentage, 1), }) await asyncio.sleep(2) # Poll every 2 seconds except asyncio.CancelledError: raise except Exception as e: logger.warning(f"Progress polling error: {e}") break async def project_websocket(websocket: WebSocket, project_name: str): """ WebSocket endpoint for project updates. Streams: - Progress updates (passing/total counts) - Agent status changes - Agent stdout/stderr lines """ if not validate_project_name(project_name): await websocket.close(code=4000, reason="Invalid project name") return project_dir = _get_generations_dir() / project_name if not project_dir.exists(): await websocket.close(code=4004, reason="Project not found") return await manager.connect(websocket, project_name) # Get agent manager and register callbacks agent_manager = get_manager(project_name, ROOT_DIR) async def on_output(line: str): """Handle agent output - broadcast to this WebSocket.""" try: await websocket.send_json({ "type": "log", "line": line, "timestamp": datetime.now().isoformat(), }) except Exception: pass # Connection may be closed async def on_status_change(status: str): """Handle status change - broadcast to this WebSocket.""" try: await websocket.send_json({ "type": "agent_status", "status": status, }) except Exception: pass # Connection may be closed # Register callbacks agent_manager.add_output_callback(on_output) agent_manager.add_status_callback(on_status_change) # Start progress polling task poll_task = asyncio.create_task(poll_progress(websocket, project_name)) try: # Send initial status await websocket.send_json({ "type": "agent_status", "status": agent_manager.status, }) # Send initial progress count_passing_tests = _get_count_passing_tests() passing, total = count_passing_tests(project_dir) percentage = (passing / total * 100) if total > 0 else 0 await websocket.send_json({ "type": "progress", "passing": passing, "total": total, "percentage": round(percentage, 1), }) # Keep connection alive and handle incoming messages while True: try: # Wait for any incoming messages (ping/pong, commands, etc.) data = await websocket.receive_text() message = json.loads(data) # Handle ping if message.get("type") == "ping": await websocket.send_json({"type": "pong"}) except WebSocketDisconnect: break except json.JSONDecodeError: logger.warning(f"Invalid JSON from WebSocket: {data[:100] if data else 'empty'}") except Exception as e: logger.warning(f"WebSocket error: {e}") break finally: # Clean up poll_task.cancel() try: await poll_task except asyncio.CancelledError: pass # Unregister callbacks agent_manager.remove_output_callback(on_output) agent_manager.remove_status_callback(on_status_change) # Disconnect from manager await manager.disconnect(websocket, project_name)