fix: address second round of code review feedback

Backend improvements:
- Create shared validation utility for project name validation
- Add asyncio.Lock to prevent concurrent _query_claude calls
- Fix _create_features_bulk: use flush() for IDs, add rollback on error
- Use unique temp settings file instead of overwriting .claude_settings.json
- Remove exception details from error messages (security)

Frontend improvements:
- Memoize onError callback in ExpandProjectChat for stable dependencies
- Add timeout to start() checkAndSend loop to prevent infinite retries
- Add manuallyDisconnectedRef to prevent reconnection after explicit disconnect
- Clear pending reconnect timeout in disconnect()

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dan Gentry
2026-01-09 23:57:50 -05:00
parent 75f2bf2a10
commit cdcbd11272
7 changed files with 106 additions and 53 deletions

View File

@@ -8,7 +8,6 @@ Allows adding multiple features to existing projects via natural language.
import json
import logging
import re
from pathlib import Path
from typing import Optional
@@ -23,6 +22,7 @@ from ..services.expand_chat_session import (
list_expand_sessions,
remove_expand_session,
)
from ..utils.validation import validate_project_name
logger = logging.getLogger(__name__)
@@ -43,9 +43,6 @@ def _get_project_path(project_name: str) -> Path:
return get_project_path(project_name)
def validate_project_name(name: str) -> bool:
"""Validate project name to prevent path traversal."""
return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name))
# ============================================================================
@@ -70,8 +67,7 @@ async def list_expand_sessions_endpoint():
@router.get("/sessions/{project_name}", response_model=ExpandSessionStatus)
async def get_expand_session_status(project_name: str):
"""Get status of an expansion session."""
if not validate_project_name(project_name):
raise HTTPException(status_code=400, detail="Invalid project name")
project_name = validate_project_name(project_name)
session = get_expand_session(project_name)
if not session:
@@ -89,8 +85,7 @@ async def get_expand_session_status(project_name: str):
@router.delete("/sessions/{project_name}")
async def cancel_expand_session(project_name: str):
"""Cancel and remove an expansion session."""
if not validate_project_name(project_name):
raise HTTPException(status_code=400, detail="Invalid project name")
project_name = validate_project_name(project_name)
session = get_expand_session(project_name)
if not session:
@@ -124,7 +119,9 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str):
- {"type": "error", "content": "..."} - Error message
- {"type": "pong"} - Keep-alive pong
"""
if not validate_project_name(project_name):
try:
project_name = validate_project_name(project_name)
except HTTPException:
await websocket.close(code=4000, reason="Invalid project name")
return

View File

@@ -6,7 +6,6 @@ API endpoints for feature/test case management.
"""
import logging
import re
from contextlib import contextmanager
from pathlib import Path
@@ -19,6 +18,7 @@ from ..schemas import (
FeatureListResponse,
FeatureResponse,
)
from ..utils.validation import validate_project_name
# Lazy imports to avoid circular dependencies
_create_database = None
@@ -56,16 +56,6 @@ def _get_db_classes():
router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"])
def validate_project_name(name: str) -> str:
"""Validate and sanitize project name to prevent path traversal."""
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
raise HTTPException(
status_code=400,
detail="Invalid project name"
)
return name
@contextmanager
def get_db_session(project_dir: Path):
"""

View File

@@ -6,11 +6,13 @@ Manages interactive project expansion conversation with Claude.
Uses the expand-project.md skill to help users add features to existing projects.
"""
import asyncio
import json
import logging
import re
import shutil
import threading
import uuid
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Optional
@@ -68,6 +70,7 @@ class ExpandChatSession:
self.features_created: int = 0
self.created_feature_ids: list[int] = []
self._settings_file: Optional[Path] = None
self._query_lock = asyncio.Lock()
async def close(self) -> None:
"""Clean up resources and close the Claude client."""
@@ -117,7 +120,16 @@ class ExpandChatSession:
except UnicodeDecodeError:
skill_content = skill_path.read_text(encoding="utf-8", errors="replace")
# Create security settings file
# Find and validate Claude CLI before creating temp files
system_cli = shutil.which("claude")
if not system_cli:
yield {
"type": "error",
"content": "Claude CLI not found. Please install Claude Code."
}
return
# Create temporary security settings file (unique per session to avoid conflicts)
security_settings = {
"sandbox": {"enabled": True},
"permissions": {
@@ -128,23 +140,16 @@ class ExpandChatSession:
],
},
}
settings_file = self.project_dir / ".claude_settings.json"
settings_file = self.project_dir / f".claude_settings.expand.{uuid.uuid4().hex}.json"
self._settings_file = settings_file
with open(settings_file, "w") as f:
with open(settings_file, "w", encoding="utf-8") as f:
json.dump(security_settings, f, indent=2)
# Replace $ARGUMENTS with absolute project path
project_path = str(self.project_dir.resolve())
system_prompt = skill_content.replace("$ARGUMENTS", project_path)
# Find and validate Claude CLI
system_cli = shutil.which("claude")
if not system_cli:
yield {
"type": "error",
"content": "Claude CLI not found. Please install Claude Code."
}
return
# Create Claude SDK client
try:
self.client = ClaudeSDKClient(
options=ClaudeAgentOptions(
@@ -167,20 +172,21 @@ class ExpandChatSession:
logger.exception("Failed to create Claude client")
yield {
"type": "error",
"content": f"Failed to initialize Claude: {str(e)}"
"content": "Failed to initialize Claude"
}
return
# Start the conversation
try:
async for chunk in self._query_claude("Begin the project expansion process."):
yield chunk
async with self._query_lock:
async for chunk in self._query_claude("Begin the project expansion process."):
yield chunk
yield {"type": "response_done"}
except Exception as e:
logger.exception("Failed to start expand chat")
yield {
"type": "error",
"content": f"Failed to start conversation: {str(e)}"
"content": "Failed to start conversation"
}
async def send_message(
@@ -218,14 +224,16 @@ class ExpandChatSession:
})
try:
async for chunk in self._query_claude(user_message, attachments):
yield chunk
# Use lock to prevent concurrent queries from corrupting the response stream
async with self._query_lock:
async for chunk in self._query_claude(user_message, attachments):
yield chunk
yield {"type": "response_done"}
except Exception as e:
logger.exception("Error during Claude query")
yield {
"type": "error",
"content": f"Error: {str(e)}"
"content": "Error while processing message"
}
async def _query_claude(
@@ -340,6 +348,10 @@ class ExpandChatSession:
Returns:
List of created feature dictionaries with IDs
Note:
Uses flush() to get IDs immediately without re-querying by priority range,
which could pick up rows from concurrent writers.
"""
# Import database classes
import sys
@@ -358,7 +370,7 @@ class ExpandChatSession:
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
created_features = []
created_rows: list = []
for f in features:
db_feature = Feature(
@@ -370,24 +382,28 @@ class ExpandChatSession:
passes=False,
)
session.add(db_feature)
created_rows.append(db_feature)
current_priority += 1
session.commit()
# Flush to get IDs without relying on priority range query
session.flush()
# Re-query to get the created features with IDs
start_priority = current_priority - len(features)
for db_feature in session.query(Feature).filter(
Feature.priority >= start_priority,
Feature.priority < current_priority
).order_by(Feature.priority).all():
created_features.append({
# Build result from the flushed objects (IDs are now populated)
created_features = [
{
"id": db_feature.id,
"name": db_feature.name,
"category": db_feature.category,
})
}
for db_feature in created_rows
]
session.commit()
return created_features
except Exception:
session.rollback()
raise
finally:
session.close()

1
server/utils/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Server utilities

View File

@@ -0,0 +1,28 @@
"""
Shared validation utilities for the server.
"""
import re
from fastapi import HTTPException
def validate_project_name(name: str) -> str:
"""
Validate and sanitize project name to prevent path traversal.
Args:
name: Project name to validate
Returns:
The validated project name
Raises:
HTTPException: If name is invalid
"""
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
raise HTTPException(
status_code=400,
detail="Invalid project name. Use only letters, numbers, hyphens, and underscores (1-50 chars)."
)
return name

View File

@@ -34,6 +34,9 @@ export function ExpandProjectChat({
const inputRef = useRef<HTMLInputElement>(null)
const fileInputRef = useRef<HTMLInputElement>(null)
// Memoize error handler to keep hook dependencies stable
const handleError = useCallback((err: string) => setError(err), [])
const {
messages,
isLoading,
@@ -46,7 +49,7 @@ export function ExpandProjectChat({
} = useExpandChat({
projectName,
onComplete,
onError: (err) => setError(err),
onError: handleError,
})
// Start the chat session when component mounts

View File

@@ -54,6 +54,7 @@ export function useExpandChat({
const pingIntervalRef = useRef<number | null>(null)
const reconnectTimeoutRef = useRef<number | null>(null)
const isCompleteRef = useRef(false)
const manuallyDisconnectedRef = useRef(false)
// Keep isCompleteRef in sync with isComplete state
useEffect(() => {
@@ -76,6 +77,10 @@ export function useExpandChat({
}, [])
const connect = useCallback(() => {
// Don't reconnect if manually disconnected
if (manuallyDisconnectedRef.current) {
return
}
if (wsRef.current?.readyState === WebSocket.OPEN) {
return
}
@@ -92,6 +97,7 @@ export function useExpandChat({
ws.onopen = () => {
setConnectionStatus('connected')
reconnectAttempts.current = 0
manuallyDisconnectedRef.current = false
// Start ping interval to keep connection alive
pingIntervalRef.current = window.setInterval(() => {
@@ -109,7 +115,11 @@ export function useExpandChat({
}
// Attempt reconnection if not intentionally closed
if (reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current) {
if (
!manuallyDisconnectedRef.current &&
reconnectAttempts.current < maxReconnectAttempts &&
!isCompleteRef.current
) {
reconnectAttempts.current++
const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000)
reconnectTimeoutRef.current = window.setTimeout(connect, delay)
@@ -244,18 +254,25 @@ export function useExpandChat({
const start = useCallback(() => {
connect()
// Wait for connection then send start message
// Wait for connection then send start message (with timeout to prevent infinite loop)
let attempts = 0
const maxAttempts = 50 // 5 seconds max (50 * 100ms)
const checkAndSend = () => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
setIsLoading(true)
wsRef.current.send(JSON.stringify({ type: 'start' }))
} else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
setTimeout(checkAndSend, 100)
if (attempts++ < maxAttempts) {
setTimeout(checkAndSend, 100)
} else {
onError?.('Connection timeout')
setIsLoading(false)
}
}
}
setTimeout(checkAndSend, 100)
}, [connect])
}, [connect, onError])
const sendMessage = useCallback((content: string, attachments?: ImageAttachment[]) => {
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
@@ -297,6 +314,7 @@ export function useExpandChat({
}, [onError])
const disconnect = useCallback(() => {
manuallyDisconnectedRef.current = true
reconnectAttempts.current = maxReconnectAttempts // Prevent reconnection
if (pingIntervalRef.current) {
clearInterval(pingIntervalRef.current)