diff --git a/server/routers/expand_project.py b/server/routers/expand_project.py index 0d806d8..d894719 100644 --- a/server/routers/expand_project.py +++ b/server/routers/expand_project.py @@ -8,7 +8,6 @@ Allows adding multiple features to existing projects via natural language. import json import logging -import re from pathlib import Path from typing import Optional @@ -23,6 +22,7 @@ from ..services.expand_chat_session import ( list_expand_sessions, remove_expand_session, ) +from ..utils.validation import validate_project_name logger = logging.getLogger(__name__) @@ -43,9 +43,6 @@ def _get_project_path(project_name: str) -> Path: return get_project_path(project_name) -def validate_project_name(name: str) -> bool: - """Validate project name to prevent path traversal.""" - return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name)) # ============================================================================ @@ -70,8 +67,7 @@ async def list_expand_sessions_endpoint(): @router.get("/sessions/{project_name}", response_model=ExpandSessionStatus) async def get_expand_session_status(project_name: str): """Get status of an expansion session.""" - if not validate_project_name(project_name): - raise HTTPException(status_code=400, detail="Invalid project name") + project_name = validate_project_name(project_name) session = get_expand_session(project_name) if not session: @@ -89,8 +85,7 @@ async def get_expand_session_status(project_name: str): @router.delete("/sessions/{project_name}") async def cancel_expand_session(project_name: str): """Cancel and remove an expansion session.""" - if not validate_project_name(project_name): - raise HTTPException(status_code=400, detail="Invalid project name") + project_name = validate_project_name(project_name) session = get_expand_session(project_name) if not session: @@ -124,7 +119,9 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str): - {"type": "error", "content": "..."} - Error message - {"type": "pong"} - Keep-alive pong """ - if not validate_project_name(project_name): + try: + project_name = validate_project_name(project_name) + except HTTPException: await websocket.close(code=4000, reason="Invalid project name") return diff --git a/server/routers/features.py b/server/routers/features.py index 0a5849c..ce0f388 100644 --- a/server/routers/features.py +++ b/server/routers/features.py @@ -6,7 +6,6 @@ API endpoints for feature/test case management. """ import logging -import re from contextlib import contextmanager from pathlib import Path @@ -19,6 +18,7 @@ from ..schemas import ( FeatureListResponse, FeatureResponse, ) +from ..utils.validation import validate_project_name # Lazy imports to avoid circular dependencies _create_database = None @@ -56,16 +56,6 @@ def _get_db_classes(): router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"]) -def validate_project_name(name: str) -> str: - """Validate and sanitize project name to prevent path traversal.""" - if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name): - raise HTTPException( - status_code=400, - detail="Invalid project name" - ) - return name - - @contextmanager def get_db_session(project_dir: Path): """ diff --git a/server/services/expand_chat_session.py b/server/services/expand_chat_session.py index fdd90e9..a6825f6 100644 --- a/server/services/expand_chat_session.py +++ b/server/services/expand_chat_session.py @@ -6,11 +6,13 @@ Manages interactive project expansion conversation with Claude. Uses the expand-project.md skill to help users add features to existing projects. """ +import asyncio import json import logging import re import shutil import threading +import uuid from datetime import datetime from pathlib import Path from typing import AsyncGenerator, Optional @@ -68,6 +70,7 @@ class ExpandChatSession: self.features_created: int = 0 self.created_feature_ids: list[int] = [] self._settings_file: Optional[Path] = None + self._query_lock = asyncio.Lock() async def close(self) -> None: """Clean up resources and close the Claude client.""" @@ -117,7 +120,16 @@ class ExpandChatSession: except UnicodeDecodeError: skill_content = skill_path.read_text(encoding="utf-8", errors="replace") - # Create security settings file + # Find and validate Claude CLI before creating temp files + system_cli = shutil.which("claude") + if not system_cli: + yield { + "type": "error", + "content": "Claude CLI not found. Please install Claude Code." + } + return + + # Create temporary security settings file (unique per session to avoid conflicts) security_settings = { "sandbox": {"enabled": True}, "permissions": { @@ -128,23 +140,16 @@ class ExpandChatSession: ], }, } - settings_file = self.project_dir / ".claude_settings.json" + settings_file = self.project_dir / f".claude_settings.expand.{uuid.uuid4().hex}.json" self._settings_file = settings_file - with open(settings_file, "w") as f: + with open(settings_file, "w", encoding="utf-8") as f: json.dump(security_settings, f, indent=2) # Replace $ARGUMENTS with absolute project path project_path = str(self.project_dir.resolve()) system_prompt = skill_content.replace("$ARGUMENTS", project_path) - # Find and validate Claude CLI - system_cli = shutil.which("claude") - if not system_cli: - yield { - "type": "error", - "content": "Claude CLI not found. Please install Claude Code." - } - return + # Create Claude SDK client try: self.client = ClaudeSDKClient( options=ClaudeAgentOptions( @@ -167,20 +172,21 @@ class ExpandChatSession: logger.exception("Failed to create Claude client") yield { "type": "error", - "content": f"Failed to initialize Claude: {str(e)}" + "content": "Failed to initialize Claude" } return # Start the conversation try: - async for chunk in self._query_claude("Begin the project expansion process."): - yield chunk + async with self._query_lock: + async for chunk in self._query_claude("Begin the project expansion process."): + yield chunk yield {"type": "response_done"} except Exception as e: logger.exception("Failed to start expand chat") yield { "type": "error", - "content": f"Failed to start conversation: {str(e)}" + "content": "Failed to start conversation" } async def send_message( @@ -218,14 +224,16 @@ class ExpandChatSession: }) try: - async for chunk in self._query_claude(user_message, attachments): - yield chunk + # Use lock to prevent concurrent queries from corrupting the response stream + async with self._query_lock: + async for chunk in self._query_claude(user_message, attachments): + yield chunk yield {"type": "response_done"} except Exception as e: logger.exception("Error during Claude query") yield { "type": "error", - "content": f"Error: {str(e)}" + "content": "Error while processing message" } async def _query_claude( @@ -340,6 +348,10 @@ class ExpandChatSession: Returns: List of created feature dictionaries with IDs + + Note: + Uses flush() to get IDs immediately without re-querying by priority range, + which could pick up rows from concurrent writers. """ # Import database classes import sys @@ -358,7 +370,7 @@ class ExpandChatSession: max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first() current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1 - created_features = [] + created_rows: list = [] for f in features: db_feature = Feature( @@ -370,24 +382,28 @@ class ExpandChatSession: passes=False, ) session.add(db_feature) + created_rows.append(db_feature) current_priority += 1 - session.commit() + # Flush to get IDs without relying on priority range query + session.flush() - # Re-query to get the created features with IDs - start_priority = current_priority - len(features) - for db_feature in session.query(Feature).filter( - Feature.priority >= start_priority, - Feature.priority < current_priority - ).order_by(Feature.priority).all(): - created_features.append({ + # Build result from the flushed objects (IDs are now populated) + created_features = [ + { "id": db_feature.id, "name": db_feature.name, "category": db_feature.category, - }) + } + for db_feature in created_rows + ] + session.commit() return created_features + except Exception: + session.rollback() + raise finally: session.close() diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 0000000..8ed4d66 --- /dev/null +++ b/server/utils/__init__.py @@ -0,0 +1 @@ +# Server utilities diff --git a/server/utils/validation.py b/server/utils/validation.py new file mode 100644 index 0000000..9f1bf11 --- /dev/null +++ b/server/utils/validation.py @@ -0,0 +1,28 @@ +""" +Shared validation utilities for the server. +""" + +import re + +from fastapi import HTTPException + + +def validate_project_name(name: str) -> str: + """ + Validate and sanitize project name to prevent path traversal. + + Args: + name: Project name to validate + + Returns: + The validated project name + + Raises: + HTTPException: If name is invalid + """ + if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name): + raise HTTPException( + status_code=400, + detail="Invalid project name. Use only letters, numbers, hyphens, and underscores (1-50 chars)." + ) + return name diff --git a/ui/src/components/ExpandProjectChat.tsx b/ui/src/components/ExpandProjectChat.tsx index 1484933..1077a6d 100644 --- a/ui/src/components/ExpandProjectChat.tsx +++ b/ui/src/components/ExpandProjectChat.tsx @@ -34,6 +34,9 @@ export function ExpandProjectChat({ const inputRef = useRef(null) const fileInputRef = useRef(null) + // Memoize error handler to keep hook dependencies stable + const handleError = useCallback((err: string) => setError(err), []) + const { messages, isLoading, @@ -46,7 +49,7 @@ export function ExpandProjectChat({ } = useExpandChat({ projectName, onComplete, - onError: (err) => setError(err), + onError: handleError, }) // Start the chat session when component mounts diff --git a/ui/src/hooks/useExpandChat.ts b/ui/src/hooks/useExpandChat.ts index 6a7e73e..9150885 100644 --- a/ui/src/hooks/useExpandChat.ts +++ b/ui/src/hooks/useExpandChat.ts @@ -54,6 +54,7 @@ export function useExpandChat({ const pingIntervalRef = useRef(null) const reconnectTimeoutRef = useRef(null) const isCompleteRef = useRef(false) + const manuallyDisconnectedRef = useRef(false) // Keep isCompleteRef in sync with isComplete state useEffect(() => { @@ -76,6 +77,10 @@ export function useExpandChat({ }, []) const connect = useCallback(() => { + // Don't reconnect if manually disconnected + if (manuallyDisconnectedRef.current) { + return + } if (wsRef.current?.readyState === WebSocket.OPEN) { return } @@ -92,6 +97,7 @@ export function useExpandChat({ ws.onopen = () => { setConnectionStatus('connected') reconnectAttempts.current = 0 + manuallyDisconnectedRef.current = false // Start ping interval to keep connection alive pingIntervalRef.current = window.setInterval(() => { @@ -109,7 +115,11 @@ export function useExpandChat({ } // Attempt reconnection if not intentionally closed - if (reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current) { + if ( + !manuallyDisconnectedRef.current && + reconnectAttempts.current < maxReconnectAttempts && + !isCompleteRef.current + ) { reconnectAttempts.current++ const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000) reconnectTimeoutRef.current = window.setTimeout(connect, delay) @@ -244,18 +254,25 @@ export function useExpandChat({ const start = useCallback(() => { connect() - // Wait for connection then send start message + // Wait for connection then send start message (with timeout to prevent infinite loop) + let attempts = 0 + const maxAttempts = 50 // 5 seconds max (50 * 100ms) const checkAndSend = () => { if (wsRef.current?.readyState === WebSocket.OPEN) { setIsLoading(true) wsRef.current.send(JSON.stringify({ type: 'start' })) } else if (wsRef.current?.readyState === WebSocket.CONNECTING) { - setTimeout(checkAndSend, 100) + if (attempts++ < maxAttempts) { + setTimeout(checkAndSend, 100) + } else { + onError?.('Connection timeout') + setIsLoading(false) + } } } setTimeout(checkAndSend, 100) - }, [connect]) + }, [connect, onError]) const sendMessage = useCallback((content: string, attachments?: ImageAttachment[]) => { if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) { @@ -297,6 +314,7 @@ export function useExpandChat({ }, [onError]) const disconnect = useCallback(() => { + manuallyDisconnectedRef.current = true reconnectAttempts.current = maxReconnectAttempts // Prevent reconnection if (pingIntervalRef.current) { clearInterval(pingIntervalRef.current)