fix: address second round of code review feedback

Backend improvements:
- Create shared validation utility for project name validation
- Add asyncio.Lock to prevent concurrent _query_claude calls
- Fix _create_features_bulk: use flush() for IDs, add rollback on error
- Use unique temp settings file instead of overwriting .claude_settings.json
- Remove exception details from error messages (security)

Frontend improvements:
- Memoize onError callback in ExpandProjectChat for stable dependencies
- Add timeout to start() checkAndSend loop to prevent infinite retries
- Add manuallyDisconnectedRef to prevent reconnection after explicit disconnect
- Clear pending reconnect timeout in disconnect()

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dan Gentry
2026-01-09 23:57:50 -05:00
parent 75f2bf2a10
commit cdcbd11272
7 changed files with 106 additions and 53 deletions

View File

@@ -8,7 +8,6 @@ Allows adding multiple features to existing projects via natural language.
import json import json
import logging import logging
import re
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -23,6 +22,7 @@ from ..services.expand_chat_session import (
list_expand_sessions, list_expand_sessions,
remove_expand_session, remove_expand_session,
) )
from ..utils.validation import validate_project_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,9 +43,6 @@ def _get_project_path(project_name: str) -> Path:
return get_project_path(project_name) return get_project_path(project_name)
def validate_project_name(name: str) -> bool:
"""Validate project name to prevent path traversal."""
return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name))
# ============================================================================ # ============================================================================
@@ -70,8 +67,7 @@ async def list_expand_sessions_endpoint():
@router.get("/sessions/{project_name}", response_model=ExpandSessionStatus) @router.get("/sessions/{project_name}", response_model=ExpandSessionStatus)
async def get_expand_session_status(project_name: str): async def get_expand_session_status(project_name: str):
"""Get status of an expansion session.""" """Get status of an expansion session."""
if not validate_project_name(project_name): project_name = validate_project_name(project_name)
raise HTTPException(status_code=400, detail="Invalid project name")
session = get_expand_session(project_name) session = get_expand_session(project_name)
if not session: if not session:
@@ -89,8 +85,7 @@ async def get_expand_session_status(project_name: str):
@router.delete("/sessions/{project_name}") @router.delete("/sessions/{project_name}")
async def cancel_expand_session(project_name: str): async def cancel_expand_session(project_name: str):
"""Cancel and remove an expansion session.""" """Cancel and remove an expansion session."""
if not validate_project_name(project_name): project_name = validate_project_name(project_name)
raise HTTPException(status_code=400, detail="Invalid project name")
session = get_expand_session(project_name) session = get_expand_session(project_name)
if not session: if not session:
@@ -124,7 +119,9 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str):
- {"type": "error", "content": "..."} - Error message - {"type": "error", "content": "..."} - Error message
- {"type": "pong"} - Keep-alive pong - {"type": "pong"} - Keep-alive pong
""" """
if not validate_project_name(project_name): try:
project_name = validate_project_name(project_name)
except HTTPException:
await websocket.close(code=4000, reason="Invalid project name") await websocket.close(code=4000, reason="Invalid project name")
return return

View File

@@ -6,7 +6,6 @@ API endpoints for feature/test case management.
""" """
import logging import logging
import re
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
@@ -19,6 +18,7 @@ from ..schemas import (
FeatureListResponse, FeatureListResponse,
FeatureResponse, FeatureResponse,
) )
from ..utils.validation import validate_project_name
# Lazy imports to avoid circular dependencies # Lazy imports to avoid circular dependencies
_create_database = None _create_database = None
@@ -56,16 +56,6 @@ def _get_db_classes():
router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"]) router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"])
def validate_project_name(name: str) -> str:
"""Validate and sanitize project name to prevent path traversal."""
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
raise HTTPException(
status_code=400,
detail="Invalid project name"
)
return name
@contextmanager @contextmanager
def get_db_session(project_dir: Path): def get_db_session(project_dir: Path):
""" """

View File

@@ -6,11 +6,13 @@ Manages interactive project expansion conversation with Claude.
Uses the expand-project.md skill to help users add features to existing projects. Uses the expand-project.md skill to help users add features to existing projects.
""" """
import asyncio
import json import json
import logging import logging
import re import re
import shutil import shutil
import threading import threading
import uuid
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional
@@ -68,6 +70,7 @@ class ExpandChatSession:
self.features_created: int = 0 self.features_created: int = 0
self.created_feature_ids: list[int] = [] self.created_feature_ids: list[int] = []
self._settings_file: Optional[Path] = None self._settings_file: Optional[Path] = None
self._query_lock = asyncio.Lock()
async def close(self) -> None: async def close(self) -> None:
"""Clean up resources and close the Claude client.""" """Clean up resources and close the Claude client."""
@@ -117,7 +120,16 @@ class ExpandChatSession:
except UnicodeDecodeError: except UnicodeDecodeError:
skill_content = skill_path.read_text(encoding="utf-8", errors="replace") skill_content = skill_path.read_text(encoding="utf-8", errors="replace")
# Create security settings file # Find and validate Claude CLI before creating temp files
system_cli = shutil.which("claude")
if not system_cli:
yield {
"type": "error",
"content": "Claude CLI not found. Please install Claude Code."
}
return
# Create temporary security settings file (unique per session to avoid conflicts)
security_settings = { security_settings = {
"sandbox": {"enabled": True}, "sandbox": {"enabled": True},
"permissions": { "permissions": {
@@ -128,23 +140,16 @@ class ExpandChatSession:
], ],
}, },
} }
settings_file = self.project_dir / ".claude_settings.json" settings_file = self.project_dir / f".claude_settings.expand.{uuid.uuid4().hex}.json"
self._settings_file = settings_file self._settings_file = settings_file
with open(settings_file, "w") as f: with open(settings_file, "w", encoding="utf-8") as f:
json.dump(security_settings, f, indent=2) json.dump(security_settings, f, indent=2)
# Replace $ARGUMENTS with absolute project path # Replace $ARGUMENTS with absolute project path
project_path = str(self.project_dir.resolve()) project_path = str(self.project_dir.resolve())
system_prompt = skill_content.replace("$ARGUMENTS", project_path) system_prompt = skill_content.replace("$ARGUMENTS", project_path)
# Find and validate Claude CLI # Create Claude SDK client
system_cli = shutil.which("claude")
if not system_cli:
yield {
"type": "error",
"content": "Claude CLI not found. Please install Claude Code."
}
return
try: try:
self.client = ClaudeSDKClient( self.client = ClaudeSDKClient(
options=ClaudeAgentOptions( options=ClaudeAgentOptions(
@@ -167,20 +172,21 @@ class ExpandChatSession:
logger.exception("Failed to create Claude client") logger.exception("Failed to create Claude client")
yield { yield {
"type": "error", "type": "error",
"content": f"Failed to initialize Claude: {str(e)}" "content": "Failed to initialize Claude"
} }
return return
# Start the conversation # Start the conversation
try: try:
async for chunk in self._query_claude("Begin the project expansion process."): async with self._query_lock:
yield chunk async for chunk in self._query_claude("Begin the project expansion process."):
yield chunk
yield {"type": "response_done"} yield {"type": "response_done"}
except Exception as e: except Exception as e:
logger.exception("Failed to start expand chat") logger.exception("Failed to start expand chat")
yield { yield {
"type": "error", "type": "error",
"content": f"Failed to start conversation: {str(e)}" "content": "Failed to start conversation"
} }
async def send_message( async def send_message(
@@ -218,14 +224,16 @@ class ExpandChatSession:
}) })
try: try:
async for chunk in self._query_claude(user_message, attachments): # Use lock to prevent concurrent queries from corrupting the response stream
yield chunk async with self._query_lock:
async for chunk in self._query_claude(user_message, attachments):
yield chunk
yield {"type": "response_done"} yield {"type": "response_done"}
except Exception as e: except Exception as e:
logger.exception("Error during Claude query") logger.exception("Error during Claude query")
yield { yield {
"type": "error", "type": "error",
"content": f"Error: {str(e)}" "content": "Error while processing message"
} }
async def _query_claude( async def _query_claude(
@@ -340,6 +348,10 @@ class ExpandChatSession:
Returns: Returns:
List of created feature dictionaries with IDs List of created feature dictionaries with IDs
Note:
Uses flush() to get IDs immediately without re-querying by priority range,
which could pick up rows from concurrent writers.
""" """
# Import database classes # Import database classes
import sys import sys
@@ -358,7 +370,7 @@ class ExpandChatSession:
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first() max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1 current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
created_features = [] created_rows: list = []
for f in features: for f in features:
db_feature = Feature( db_feature = Feature(
@@ -370,24 +382,28 @@ class ExpandChatSession:
passes=False, passes=False,
) )
session.add(db_feature) session.add(db_feature)
created_rows.append(db_feature)
current_priority += 1 current_priority += 1
session.commit() # Flush to get IDs without relying on priority range query
session.flush()
# Re-query to get the created features with IDs # Build result from the flushed objects (IDs are now populated)
start_priority = current_priority - len(features) created_features = [
for db_feature in session.query(Feature).filter( {
Feature.priority >= start_priority,
Feature.priority < current_priority
).order_by(Feature.priority).all():
created_features.append({
"id": db_feature.id, "id": db_feature.id,
"name": db_feature.name, "name": db_feature.name,
"category": db_feature.category, "category": db_feature.category,
}) }
for db_feature in created_rows
]
session.commit()
return created_features return created_features
except Exception:
session.rollback()
raise
finally: finally:
session.close() session.close()

1
server/utils/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Server utilities

View File

@@ -0,0 +1,28 @@
"""
Shared validation utilities for the server.
"""
import re
from fastapi import HTTPException
def validate_project_name(name: str) -> str:
"""
Validate and sanitize project name to prevent path traversal.
Args:
name: Project name to validate
Returns:
The validated project name
Raises:
HTTPException: If name is invalid
"""
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
raise HTTPException(
status_code=400,
detail="Invalid project name. Use only letters, numbers, hyphens, and underscores (1-50 chars)."
)
return name

View File

@@ -34,6 +34,9 @@ export function ExpandProjectChat({
const inputRef = useRef<HTMLInputElement>(null) const inputRef = useRef<HTMLInputElement>(null)
const fileInputRef = useRef<HTMLInputElement>(null) const fileInputRef = useRef<HTMLInputElement>(null)
// Memoize error handler to keep hook dependencies stable
const handleError = useCallback((err: string) => setError(err), [])
const { const {
messages, messages,
isLoading, isLoading,
@@ -46,7 +49,7 @@ export function ExpandProjectChat({
} = useExpandChat({ } = useExpandChat({
projectName, projectName,
onComplete, onComplete,
onError: (err) => setError(err), onError: handleError,
}) })
// Start the chat session when component mounts // Start the chat session when component mounts

View File

@@ -54,6 +54,7 @@ export function useExpandChat({
const pingIntervalRef = useRef<number | null>(null) const pingIntervalRef = useRef<number | null>(null)
const reconnectTimeoutRef = useRef<number | null>(null) const reconnectTimeoutRef = useRef<number | null>(null)
const isCompleteRef = useRef(false) const isCompleteRef = useRef(false)
const manuallyDisconnectedRef = useRef(false)
// Keep isCompleteRef in sync with isComplete state // Keep isCompleteRef in sync with isComplete state
useEffect(() => { useEffect(() => {
@@ -76,6 +77,10 @@ export function useExpandChat({
}, []) }, [])
const connect = useCallback(() => { const connect = useCallback(() => {
// Don't reconnect if manually disconnected
if (manuallyDisconnectedRef.current) {
return
}
if (wsRef.current?.readyState === WebSocket.OPEN) { if (wsRef.current?.readyState === WebSocket.OPEN) {
return return
} }
@@ -92,6 +97,7 @@ export function useExpandChat({
ws.onopen = () => { ws.onopen = () => {
setConnectionStatus('connected') setConnectionStatus('connected')
reconnectAttempts.current = 0 reconnectAttempts.current = 0
manuallyDisconnectedRef.current = false
// Start ping interval to keep connection alive // Start ping interval to keep connection alive
pingIntervalRef.current = window.setInterval(() => { pingIntervalRef.current = window.setInterval(() => {
@@ -109,7 +115,11 @@ export function useExpandChat({
} }
// Attempt reconnection if not intentionally closed // Attempt reconnection if not intentionally closed
if (reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current) { if (
!manuallyDisconnectedRef.current &&
reconnectAttempts.current < maxReconnectAttempts &&
!isCompleteRef.current
) {
reconnectAttempts.current++ reconnectAttempts.current++
const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000) const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000)
reconnectTimeoutRef.current = window.setTimeout(connect, delay) reconnectTimeoutRef.current = window.setTimeout(connect, delay)
@@ -244,18 +254,25 @@ export function useExpandChat({
const start = useCallback(() => { const start = useCallback(() => {
connect() connect()
// Wait for connection then send start message // Wait for connection then send start message (with timeout to prevent infinite loop)
let attempts = 0
const maxAttempts = 50 // 5 seconds max (50 * 100ms)
const checkAndSend = () => { const checkAndSend = () => {
if (wsRef.current?.readyState === WebSocket.OPEN) { if (wsRef.current?.readyState === WebSocket.OPEN) {
setIsLoading(true) setIsLoading(true)
wsRef.current.send(JSON.stringify({ type: 'start' })) wsRef.current.send(JSON.stringify({ type: 'start' }))
} else if (wsRef.current?.readyState === WebSocket.CONNECTING) { } else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
setTimeout(checkAndSend, 100) if (attempts++ < maxAttempts) {
setTimeout(checkAndSend, 100)
} else {
onError?.('Connection timeout')
setIsLoading(false)
}
} }
} }
setTimeout(checkAndSend, 100) setTimeout(checkAndSend, 100)
}, [connect]) }, [connect, onError])
const sendMessage = useCallback((content: string, attachments?: ImageAttachment[]) => { const sendMessage = useCallback((content: string, attachments?: ImageAttachment[]) => {
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) { if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
@@ -297,6 +314,7 @@ export function useExpandChat({
}, [onError]) }, [onError])
const disconnect = useCallback(() => { const disconnect = useCallback(() => {
manuallyDisconnectedRef.current = true
reconnectAttempts.current = maxReconnectAttempts // Prevent reconnection reconnectAttempts.current = maxReconnectAttempts // Prevent reconnection
if (pingIntervalRef.current) { if (pingIntervalRef.current) {
clearInterval(pingIntervalRef.current) clearInterval(pingIntervalRef.current)