""" WebSocket Handlers ================== Real-time updates for project progress, agent output, and dev server output. """ import asyncio import json import logging import re from datetime import datetime from pathlib import Path from typing import Set from fastapi import WebSocket, WebSocketDisconnect from .schemas import AGENT_MASCOTS from .services.dev_server_manager import get_devserver_manager from .services.process_manager import get_manager # Lazy imports _count_passing_tests = None logger = logging.getLogger(__name__) # Pattern to extract feature ID from parallel orchestrator output (coding agents) FEATURE_ID_PATTERN = re.compile(r'\[Feature #(\d+)\]\s*(.*)') # Pattern to extract testing agent output TESTING_AGENT_PATTERN = re.compile(r'\[Testing\]\s*(.*)') # Patterns for detecting agent activity and thoughts THOUGHT_PATTERNS = [ # Claude's tool usage patterns (actual format: [Tool: name]) (re.compile(r'\[Tool:\s*Read\]', re.I), 'thinking'), (re.compile(r'\[Tool:\s*(?:Write|Edit|NotebookEdit)\]', re.I), 'working'), (re.compile(r'\[Tool:\s*Bash\]', re.I), 'testing'), (re.compile(r'\[Tool:\s*(?:Glob|Grep)\]', re.I), 'thinking'), (re.compile(r'\[Tool:\s*(\w+)\]', re.I), 'working'), # Fallback for other tools # Claude's internal thoughts (re.compile(r'(?:Reading|Analyzing|Checking|Looking at|Examining)\s+(.+)', re.I), 'thinking'), (re.compile(r'(?:Creating|Writing|Adding|Implementing|Building)\s+(.+)', re.I), 'working'), (re.compile(r'(?:Testing|Verifying|Running tests|Validating)\s+(.+)', re.I), 'testing'), (re.compile(r'(?:Error|Failed|Cannot|Unable to|Exception)\s+(.+)', re.I), 'struggling'), # Test results (re.compile(r'(?:PASS|passed|success)', re.I), 'success'), (re.compile(r'(?:FAIL|failed|error)', re.I), 'struggling'), ] class AgentTracker: """Tracks active agents and their states for multi-agent mode.""" # Use a special key for the testing agent since it doesn't have a fixed feature ID TESTING_AGENT_KEY = -1 def __init__(self): # feature_id -> {name, state, last_thought, agent_index, agent_type} # For testing agents, use TESTING_AGENT_KEY as the key self.active_agents: dict[int, dict] = {} self._next_agent_index = 0 self._lock = asyncio.Lock() async def process_line(self, line: str) -> dict | None: """ Process an output line and return an agent_update message if relevant. Returns None if no update should be emitted. """ # Check for testing agent output first testing_match = TESTING_AGENT_PATTERN.match(line) if testing_match: content = testing_match.group(1) return await self._process_testing_agent_line(content) # Check for feature-specific output (coding agents) match = FEATURE_ID_PATTERN.match(line) if not match: # Also check for orchestrator status messages if line.startswith("Started coding agent for feature #"): try: feature_id = int(re.search(r'#(\d+)', line).group(1)) return await self._handle_agent_start(feature_id, line, agent_type="coding") except (AttributeError, ValueError): pass elif line.startswith("Started testing agent"): return await self._handle_testing_agent_start(line) elif line.startswith("Feature #") and ("completed" in line or "failed" in line): try: feature_id = int(re.search(r'#(\d+)', line).group(1)) is_success = "completed" in line return await self._handle_agent_complete(feature_id, is_success) except (AttributeError, ValueError): pass elif line.startswith("Testing agent") and ("completed" in line or "failed" in line): # Format: "Testing agent (PID xxx) completed" or "Testing agent (PID xxx) failed" is_success = "completed" in line return await self._handle_testing_agent_complete(is_success) return None feature_id = int(match.group(1)) content = match.group(2) async with self._lock: # Ensure agent is tracked if feature_id not in self.active_agents: agent_index = self._next_agent_index self._next_agent_index += 1 self.active_agents[feature_id] = { 'name': AGENT_MASCOTS[agent_index % len(AGENT_MASCOTS)], 'agent_index': agent_index, 'agent_type': 'coding', 'state': 'thinking', 'feature_name': f'Feature #{feature_id}', 'last_thought': None, } agent = self.active_agents[feature_id] # Detect state and thought from content state = 'working' thought = None for pattern, detected_state in THOUGHT_PATTERNS: m = pattern.search(content) if m: state = detected_state thought = m.group(1) if m.lastindex else content[:100] break # Only emit update if state changed or we have a new thought if state != agent['state'] or thought != agent['last_thought']: agent['state'] = state if thought: agent['last_thought'] = thought return { 'type': 'agent_update', 'agentIndex': agent['agent_index'], 'agentName': agent['name'], 'agentType': agent['agent_type'], 'featureId': feature_id, 'featureName': agent['feature_name'], 'state': state, 'thought': thought, 'timestamp': datetime.now().isoformat(), } return None async def _process_testing_agent_line(self, content: str) -> dict | None: """Process output from a testing agent.""" async with self._lock: # Ensure testing agent is tracked if self.TESTING_AGENT_KEY not in self.active_agents: agent_index = self._next_agent_index self._next_agent_index += 1 self.active_agents[self.TESTING_AGENT_KEY] = { 'name': AGENT_MASCOTS[agent_index % len(AGENT_MASCOTS)], 'agent_index': agent_index, 'agent_type': 'testing', 'state': 'testing', 'feature_name': 'Regression Testing', 'last_thought': None, } agent = self.active_agents[self.TESTING_AGENT_KEY] # Detect state and thought from content state = 'testing' thought = None for pattern, detected_state in THOUGHT_PATTERNS: m = pattern.search(content) if m: state = detected_state thought = m.group(1) if m.lastindex else content[:100] break # Only emit update if state changed or we have a new thought if state != agent['state'] or thought != agent['last_thought']: agent['state'] = state if thought: agent['last_thought'] = thought return { 'type': 'agent_update', 'agentIndex': agent['agent_index'], 'agentName': agent['name'], 'agentType': 'testing', 'featureId': 0, # Testing agents work on random features 'featureName': agent['feature_name'], 'state': state, 'thought': thought, 'timestamp': datetime.now().isoformat(), } return None async def _handle_testing_agent_start(self, line: str) -> dict | None: """Handle testing agent start message from orchestrator.""" async with self._lock: agent_index = self._next_agent_index self._next_agent_index += 1 self.active_agents[self.TESTING_AGENT_KEY] = { 'name': AGENT_MASCOTS[agent_index % len(AGENT_MASCOTS)], 'agent_index': agent_index, 'agent_type': 'testing', 'state': 'testing', 'feature_name': 'Regression Testing', 'last_thought': 'Starting regression tests...', } return { 'type': 'agent_update', 'agentIndex': agent_index, 'agentName': AGENT_MASCOTS[agent_index % len(AGENT_MASCOTS)], 'agentType': 'testing', 'featureId': 0, 'featureName': 'Regression Testing', 'state': 'testing', 'thought': 'Starting regression tests...', 'timestamp': datetime.now().isoformat(), } async def _handle_testing_agent_complete(self, is_success: bool) -> dict | None: """Handle testing agent completion.""" async with self._lock: if self.TESTING_AGENT_KEY not in self.active_agents: return None agent = self.active_agents[self.TESTING_AGENT_KEY] state = 'success' if is_success else 'error' result = { 'type': 'agent_update', 'agentIndex': agent['agent_index'], 'agentName': agent['name'], 'agentType': 'testing', 'featureId': 0, 'featureName': agent['feature_name'], 'state': state, 'thought': 'Tests passed!' if is_success else 'Found regressions', 'timestamp': datetime.now().isoformat(), } # Remove from active agents del self.active_agents[self.TESTING_AGENT_KEY] return result def get_agent_info(self, feature_id: int) -> tuple[int | None, str | None]: """Get agent index and name for a feature ID. Returns: Tuple of (agentIndex, agentName) or (None, None) if not tracked. """ agent = self.active_agents.get(feature_id) if agent: return agent['agent_index'], agent['name'] return None, None async def _handle_agent_start(self, feature_id: int, line: str, agent_type: str = "coding") -> dict | None: """Handle agent start message from orchestrator.""" async with self._lock: agent_index = self._next_agent_index self._next_agent_index += 1 # Try to extract feature name from line feature_name = f'Feature #{feature_id}' name_match = re.search(r'#\d+:\s*(.+)$', line) if name_match: feature_name = name_match.group(1) self.active_agents[feature_id] = { 'name': AGENT_MASCOTS[agent_index % len(AGENT_MASCOTS)], 'agent_index': agent_index, 'agent_type': agent_type, 'state': 'thinking', 'feature_name': feature_name, 'last_thought': 'Starting work...', } return { 'type': 'agent_update', 'agentIndex': agent_index, 'agentName': AGENT_MASCOTS[agent_index % len(AGENT_MASCOTS)], 'agentType': agent_type, 'featureId': feature_id, 'featureName': feature_name, 'state': 'thinking', 'thought': 'Starting work...', 'timestamp': datetime.now().isoformat(), } async def _handle_agent_complete(self, feature_id: int, is_success: bool) -> dict | None: """Handle agent completion message from orchestrator.""" async with self._lock: if feature_id not in self.active_agents: return None agent = self.active_agents[feature_id] state = 'success' if is_success else 'error' agent_type = agent.get('agent_type', 'coding') result = { 'type': 'agent_update', 'agentIndex': agent['agent_index'], 'agentName': agent['name'], 'agentType': agent_type, 'featureId': feature_id, 'featureName': agent['feature_name'], 'state': state, 'thought': 'Completed successfully!' if is_success else 'Failed to complete', 'timestamp': datetime.now().isoformat(), } # Remove from active agents del self.active_agents[feature_id] return result def _get_project_path(project_name: str) -> Path: """Get project path from registry.""" import sys root = Path(__file__).parent.parent if str(root) not in sys.path: sys.path.insert(0, str(root)) from registry import get_project_path return get_project_path(project_name) def _get_count_passing_tests(): """Lazy import of count_passing_tests.""" global _count_passing_tests if _count_passing_tests is None: import sys root = Path(__file__).parent.parent if str(root) not in sys.path: sys.path.insert(0, str(root)) from progress import count_passing_tests _count_passing_tests = count_passing_tests return _count_passing_tests class ConnectionManager: """Manages WebSocket connections per project.""" def __init__(self): # project_name -> set of WebSocket connections self.active_connections: dict[str, Set[WebSocket]] = {} self._lock = asyncio.Lock() async def connect(self, websocket: WebSocket, project_name: str): """Accept a WebSocket connection for a project.""" await websocket.accept() async with self._lock: if project_name not in self.active_connections: self.active_connections[project_name] = set() self.active_connections[project_name].add(websocket) async def disconnect(self, websocket: WebSocket, project_name: str): """Remove a WebSocket connection.""" async with self._lock: if project_name in self.active_connections: self.active_connections[project_name].discard(websocket) if not self.active_connections[project_name]: del self.active_connections[project_name] async def broadcast_to_project(self, project_name: str, message: dict): """Broadcast a message to all connections for a project.""" async with self._lock: connections = list(self.active_connections.get(project_name, set())) dead_connections = [] for connection in connections: try: await connection.send_json(message) except Exception: dead_connections.append(connection) # Clean up dead connections if dead_connections: async with self._lock: for connection in dead_connections: if project_name in self.active_connections: self.active_connections[project_name].discard(connection) def get_connection_count(self, project_name: str) -> int: """Get number of active connections for a project.""" return len(self.active_connections.get(project_name, set())) # Global connection manager manager = ConnectionManager() # Root directory ROOT_DIR = Path(__file__).parent.parent def validate_project_name(name: str) -> bool: """Validate project name to prevent path traversal.""" return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name)) async def poll_progress(websocket: WebSocket, project_name: str, project_dir: Path): """Poll database for progress changes and send updates.""" count_passing_tests = _get_count_passing_tests() last_passing = -1 last_in_progress = -1 last_total = -1 while True: try: passing, in_progress, total = count_passing_tests(project_dir) # Only send if changed if passing != last_passing or in_progress != last_in_progress or total != last_total: last_passing = passing last_in_progress = in_progress last_total = total percentage = (passing / total * 100) if total > 0 else 0 await websocket.send_json({ "type": "progress", "passing": passing, "in_progress": in_progress, "total": total, "percentage": round(percentage, 1), }) await asyncio.sleep(2) # Poll every 2 seconds except asyncio.CancelledError: raise except Exception as e: logger.warning(f"Progress polling error: {e}") break async def project_websocket(websocket: WebSocket, project_name: str): """ WebSocket endpoint for project updates. Streams: - Progress updates (passing/total counts) - Agent status changes - Agent stdout/stderr lines """ if not validate_project_name(project_name): await websocket.close(code=4000, reason="Invalid project name") return project_dir = _get_project_path(project_name) if not project_dir: await websocket.close(code=4004, reason="Project not found in registry") return if not project_dir.exists(): await websocket.close(code=4004, reason="Project directory not found") return await manager.connect(websocket, project_name) # Get agent manager and register callbacks agent_manager = get_manager(project_name, project_dir, ROOT_DIR) # Create agent tracker for multi-agent mode agent_tracker = AgentTracker() async def on_output(line: str): """Handle agent output - broadcast to this WebSocket.""" try: # Extract feature ID from line if present feature_id = None agent_index = None match = FEATURE_ID_PATTERN.match(line) if match: feature_id = int(match.group(1)) agent_index, _ = agent_tracker.get_agent_info(feature_id) # Send the raw log line with optional feature/agent attribution log_msg = { "type": "log", "line": line, "timestamp": datetime.now().isoformat(), } if feature_id is not None: log_msg["featureId"] = feature_id if agent_index is not None: log_msg["agentIndex"] = agent_index await websocket.send_json(log_msg) # Check if this line indicates agent activity (parallel mode) # and emit agent_update messages if so agent_update = await agent_tracker.process_line(line) if agent_update: await websocket.send_json(agent_update) except Exception: pass # Connection may be closed async def on_status_change(status: str): """Handle status change - broadcast to this WebSocket.""" try: await websocket.send_json({ "type": "agent_status", "status": status, }) except Exception: pass # Connection may be closed # Register callbacks agent_manager.add_output_callback(on_output) agent_manager.add_status_callback(on_status_change) # Get dev server manager and register callbacks devserver_manager = get_devserver_manager(project_name, project_dir) async def on_dev_output(line: str): """Handle dev server output - broadcast to this WebSocket.""" try: await websocket.send_json({ "type": "dev_log", "line": line, "timestamp": datetime.now().isoformat(), }) except Exception: pass # Connection may be closed async def on_dev_status_change(status: str): """Handle dev server status change - broadcast to this WebSocket.""" try: await websocket.send_json({ "type": "dev_server_status", "status": status, "url": devserver_manager.detected_url, }) except Exception: pass # Connection may be closed # Register dev server callbacks devserver_manager.add_output_callback(on_dev_output) devserver_manager.add_status_callback(on_dev_status_change) # Start progress polling task poll_task = asyncio.create_task(poll_progress(websocket, project_name, project_dir)) try: # Send initial agent status await websocket.send_json({ "type": "agent_status", "status": agent_manager.status, }) # Send initial dev server status await websocket.send_json({ "type": "dev_server_status", "status": devserver_manager.status, "url": devserver_manager.detected_url, }) # Send initial progress count_passing_tests = _get_count_passing_tests() passing, in_progress, total = count_passing_tests(project_dir) percentage = (passing / total * 100) if total > 0 else 0 await websocket.send_json({ "type": "progress", "passing": passing, "in_progress": in_progress, "total": total, "percentage": round(percentage, 1), }) # Keep connection alive and handle incoming messages while True: try: # Wait for any incoming messages (ping/pong, commands, etc.) data = await websocket.receive_text() message = json.loads(data) # Handle ping if message.get("type") == "ping": await websocket.send_json({"type": "pong"}) except WebSocketDisconnect: break except json.JSONDecodeError: logger.warning(f"Invalid JSON from WebSocket: {data[:100] if data else 'empty'}") except Exception as e: logger.warning(f"WebSocket error: {e}") break finally: # Clean up poll_task.cancel() try: await poll_task except asyncio.CancelledError: pass # Unregister agent callbacks agent_manager.remove_output_callback(on_output) agent_manager.remove_status_callback(on_status_change) # Unregister dev server callbacks devserver_manager.remove_output_callback(on_dev_output) devserver_manager.remove_status_callback(on_dev_status_change) # Disconnect from manager await manager.disconnect(websocket, project_name)