mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-01-31 14:43:35 +00:00
fix: address second round of code review feedback
Backend improvements: - Create shared validation utility for project name validation - Add asyncio.Lock to prevent concurrent _query_claude calls - Fix _create_features_bulk: use flush() for IDs, add rollback on error - Use unique temp settings file instead of overwriting .claude_settings.json - Remove exception details from error messages (security) Frontend improvements: - Memoize onError callback in ExpandProjectChat for stable dependencies - Add timeout to start() checkAndSend loop to prevent infinite retries - Add manuallyDisconnectedRef to prevent reconnection after explicit disconnect - Clear pending reconnect timeout in disconnect() Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -8,7 +8,6 @@ Allows adding multiple features to existing projects via natural language.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -23,6 +22,7 @@ from ..services.expand_chat_session import (
|
|||||||
list_expand_sessions,
|
list_expand_sessions,
|
||||||
remove_expand_session,
|
remove_expand_session,
|
||||||
)
|
)
|
||||||
|
from ..utils.validation import validate_project_name
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -43,9 +43,6 @@ def _get_project_path(project_name: str) -> Path:
|
|||||||
return get_project_path(project_name)
|
return get_project_path(project_name)
|
||||||
|
|
||||||
|
|
||||||
def validate_project_name(name: str) -> bool:
|
|
||||||
"""Validate project name to prevent path traversal."""
|
|
||||||
return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name))
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -70,8 +67,7 @@ async def list_expand_sessions_endpoint():
|
|||||||
@router.get("/sessions/{project_name}", response_model=ExpandSessionStatus)
|
@router.get("/sessions/{project_name}", response_model=ExpandSessionStatus)
|
||||||
async def get_expand_session_status(project_name: str):
|
async def get_expand_session_status(project_name: str):
|
||||||
"""Get status of an expansion session."""
|
"""Get status of an expansion session."""
|
||||||
if not validate_project_name(project_name):
|
project_name = validate_project_name(project_name)
|
||||||
raise HTTPException(status_code=400, detail="Invalid project name")
|
|
||||||
|
|
||||||
session = get_expand_session(project_name)
|
session = get_expand_session(project_name)
|
||||||
if not session:
|
if not session:
|
||||||
@@ -89,8 +85,7 @@ async def get_expand_session_status(project_name: str):
|
|||||||
@router.delete("/sessions/{project_name}")
|
@router.delete("/sessions/{project_name}")
|
||||||
async def cancel_expand_session(project_name: str):
|
async def cancel_expand_session(project_name: str):
|
||||||
"""Cancel and remove an expansion session."""
|
"""Cancel and remove an expansion session."""
|
||||||
if not validate_project_name(project_name):
|
project_name = validate_project_name(project_name)
|
||||||
raise HTTPException(status_code=400, detail="Invalid project name")
|
|
||||||
|
|
||||||
session = get_expand_session(project_name)
|
session = get_expand_session(project_name)
|
||||||
if not session:
|
if not session:
|
||||||
@@ -124,7 +119,9 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str):
|
|||||||
- {"type": "error", "content": "..."} - Error message
|
- {"type": "error", "content": "..."} - Error message
|
||||||
- {"type": "pong"} - Keep-alive pong
|
- {"type": "pong"} - Keep-alive pong
|
||||||
"""
|
"""
|
||||||
if not validate_project_name(project_name):
|
try:
|
||||||
|
project_name = validate_project_name(project_name)
|
||||||
|
except HTTPException:
|
||||||
await websocket.close(code=4000, reason="Invalid project name")
|
await websocket.close(code=4000, reason="Invalid project name")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ API endpoints for feature/test case management.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -19,6 +18,7 @@ from ..schemas import (
|
|||||||
FeatureListResponse,
|
FeatureListResponse,
|
||||||
FeatureResponse,
|
FeatureResponse,
|
||||||
)
|
)
|
||||||
|
from ..utils.validation import validate_project_name
|
||||||
|
|
||||||
# Lazy imports to avoid circular dependencies
|
# Lazy imports to avoid circular dependencies
|
||||||
_create_database = None
|
_create_database = None
|
||||||
@@ -56,16 +56,6 @@ def _get_db_classes():
|
|||||||
router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"])
|
router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"])
|
||||||
|
|
||||||
|
|
||||||
def validate_project_name(name: str) -> str:
|
|
||||||
"""Validate and sanitize project name to prevent path traversal."""
|
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Invalid project name"
|
|
||||||
)
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db_session(project_dir: Path):
|
def get_db_session(project_dir: Path):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ Manages interactive project expansion conversation with Claude.
|
|||||||
Uses the expand-project.md skill to help users add features to existing projects.
|
Uses the expand-project.md skill to help users add features to existing projects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import threading
|
import threading
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncGenerator, Optional
|
from typing import AsyncGenerator, Optional
|
||||||
@@ -68,6 +70,7 @@ class ExpandChatSession:
|
|||||||
self.features_created: int = 0
|
self.features_created: int = 0
|
||||||
self.created_feature_ids: list[int] = []
|
self.created_feature_ids: list[int] = []
|
||||||
self._settings_file: Optional[Path] = None
|
self._settings_file: Optional[Path] = None
|
||||||
|
self._query_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Clean up resources and close the Claude client."""
|
"""Clean up resources and close the Claude client."""
|
||||||
@@ -117,7 +120,16 @@ class ExpandChatSession:
|
|||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
skill_content = skill_path.read_text(encoding="utf-8", errors="replace")
|
skill_content = skill_path.read_text(encoding="utf-8", errors="replace")
|
||||||
|
|
||||||
# Create security settings file
|
# Find and validate Claude CLI before creating temp files
|
||||||
|
system_cli = shutil.which("claude")
|
||||||
|
if not system_cli:
|
||||||
|
yield {
|
||||||
|
"type": "error",
|
||||||
|
"content": "Claude CLI not found. Please install Claude Code."
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create temporary security settings file (unique per session to avoid conflicts)
|
||||||
security_settings = {
|
security_settings = {
|
||||||
"sandbox": {"enabled": True},
|
"sandbox": {"enabled": True},
|
||||||
"permissions": {
|
"permissions": {
|
||||||
@@ -128,23 +140,16 @@ class ExpandChatSession:
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
settings_file = self.project_dir / ".claude_settings.json"
|
settings_file = self.project_dir / f".claude_settings.expand.{uuid.uuid4().hex}.json"
|
||||||
self._settings_file = settings_file
|
self._settings_file = settings_file
|
||||||
with open(settings_file, "w") as f:
|
with open(settings_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(security_settings, f, indent=2)
|
json.dump(security_settings, f, indent=2)
|
||||||
|
|
||||||
# Replace $ARGUMENTS with absolute project path
|
# Replace $ARGUMENTS with absolute project path
|
||||||
project_path = str(self.project_dir.resolve())
|
project_path = str(self.project_dir.resolve())
|
||||||
system_prompt = skill_content.replace("$ARGUMENTS", project_path)
|
system_prompt = skill_content.replace("$ARGUMENTS", project_path)
|
||||||
|
|
||||||
# Find and validate Claude CLI
|
# Create Claude SDK client
|
||||||
system_cli = shutil.which("claude")
|
|
||||||
if not system_cli:
|
|
||||||
yield {
|
|
||||||
"type": "error",
|
|
||||||
"content": "Claude CLI not found. Please install Claude Code."
|
|
||||||
}
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
self.client = ClaudeSDKClient(
|
self.client = ClaudeSDKClient(
|
||||||
options=ClaudeAgentOptions(
|
options=ClaudeAgentOptions(
|
||||||
@@ -167,20 +172,21 @@ class ExpandChatSession:
|
|||||||
logger.exception("Failed to create Claude client")
|
logger.exception("Failed to create Claude client")
|
||||||
yield {
|
yield {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"content": f"Failed to initialize Claude: {str(e)}"
|
"content": "Failed to initialize Claude"
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
# Start the conversation
|
# Start the conversation
|
||||||
try:
|
try:
|
||||||
async for chunk in self._query_claude("Begin the project expansion process."):
|
async with self._query_lock:
|
||||||
yield chunk
|
async for chunk in self._query_claude("Begin the project expansion process."):
|
||||||
|
yield chunk
|
||||||
yield {"type": "response_done"}
|
yield {"type": "response_done"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to start expand chat")
|
logger.exception("Failed to start expand chat")
|
||||||
yield {
|
yield {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"content": f"Failed to start conversation: {str(e)}"
|
"content": "Failed to start conversation"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
@@ -218,14 +224,16 @@ class ExpandChatSession:
|
|||||||
})
|
})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in self._query_claude(user_message, attachments):
|
# Use lock to prevent concurrent queries from corrupting the response stream
|
||||||
yield chunk
|
async with self._query_lock:
|
||||||
|
async for chunk in self._query_claude(user_message, attachments):
|
||||||
|
yield chunk
|
||||||
yield {"type": "response_done"}
|
yield {"type": "response_done"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error during Claude query")
|
logger.exception("Error during Claude query")
|
||||||
yield {
|
yield {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"content": f"Error: {str(e)}"
|
"content": "Error while processing message"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _query_claude(
|
async def _query_claude(
|
||||||
@@ -340,6 +348,10 @@ class ExpandChatSession:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of created feature dictionaries with IDs
|
List of created feature dictionaries with IDs
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Uses flush() to get IDs immediately without re-querying by priority range,
|
||||||
|
which could pick up rows from concurrent writers.
|
||||||
"""
|
"""
|
||||||
# Import database classes
|
# Import database classes
|
||||||
import sys
|
import sys
|
||||||
@@ -358,7 +370,7 @@ class ExpandChatSession:
|
|||||||
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
|
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
|
||||||
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
|
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
|
||||||
|
|
||||||
created_features = []
|
created_rows: list = []
|
||||||
|
|
||||||
for f in features:
|
for f in features:
|
||||||
db_feature = Feature(
|
db_feature = Feature(
|
||||||
@@ -370,24 +382,28 @@ class ExpandChatSession:
|
|||||||
passes=False,
|
passes=False,
|
||||||
)
|
)
|
||||||
session.add(db_feature)
|
session.add(db_feature)
|
||||||
|
created_rows.append(db_feature)
|
||||||
current_priority += 1
|
current_priority += 1
|
||||||
|
|
||||||
session.commit()
|
# Flush to get IDs without relying on priority range query
|
||||||
|
session.flush()
|
||||||
|
|
||||||
# Re-query to get the created features with IDs
|
# Build result from the flushed objects (IDs are now populated)
|
||||||
start_priority = current_priority - len(features)
|
created_features = [
|
||||||
for db_feature in session.query(Feature).filter(
|
{
|
||||||
Feature.priority >= start_priority,
|
|
||||||
Feature.priority < current_priority
|
|
||||||
).order_by(Feature.priority).all():
|
|
||||||
created_features.append({
|
|
||||||
"id": db_feature.id,
|
"id": db_feature.id,
|
||||||
"name": db_feature.name,
|
"name": db_feature.name,
|
||||||
"category": db_feature.category,
|
"category": db_feature.category,
|
||||||
})
|
}
|
||||||
|
for db_feature in created_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
session.commit()
|
||||||
return created_features
|
return created_features
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|||||||
1
server/utils/__init__.py
Normal file
1
server/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Server utilities
|
||||||
28
server/utils/validation.py
Normal file
28
server/utils/validation.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
Shared validation utilities for the server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def validate_project_name(name: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate and sanitize project name to prevent path traversal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Project name to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The validated project name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If name is invalid
|
||||||
|
"""
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid project name. Use only letters, numbers, hyphens, and underscores (1-50 chars)."
|
||||||
|
)
|
||||||
|
return name
|
||||||
@@ -34,6 +34,9 @@ export function ExpandProjectChat({
|
|||||||
const inputRef = useRef<HTMLInputElement>(null)
|
const inputRef = useRef<HTMLInputElement>(null)
|
||||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||||
|
|
||||||
|
// Memoize error handler to keep hook dependencies stable
|
||||||
|
const handleError = useCallback((err: string) => setError(err), [])
|
||||||
|
|
||||||
const {
|
const {
|
||||||
messages,
|
messages,
|
||||||
isLoading,
|
isLoading,
|
||||||
@@ -46,7 +49,7 @@ export function ExpandProjectChat({
|
|||||||
} = useExpandChat({
|
} = useExpandChat({
|
||||||
projectName,
|
projectName,
|
||||||
onComplete,
|
onComplete,
|
||||||
onError: (err) => setError(err),
|
onError: handleError,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start the chat session when component mounts
|
// Start the chat session when component mounts
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ export function useExpandChat({
|
|||||||
const pingIntervalRef = useRef<number | null>(null)
|
const pingIntervalRef = useRef<number | null>(null)
|
||||||
const reconnectTimeoutRef = useRef<number | null>(null)
|
const reconnectTimeoutRef = useRef<number | null>(null)
|
||||||
const isCompleteRef = useRef(false)
|
const isCompleteRef = useRef(false)
|
||||||
|
const manuallyDisconnectedRef = useRef(false)
|
||||||
|
|
||||||
// Keep isCompleteRef in sync with isComplete state
|
// Keep isCompleteRef in sync with isComplete state
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -76,6 +77,10 @@ export function useExpandChat({
|
|||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const connect = useCallback(() => {
|
const connect = useCallback(() => {
|
||||||
|
// Don't reconnect if manually disconnected
|
||||||
|
if (manuallyDisconnectedRef.current) {
|
||||||
|
return
|
||||||
|
}
|
||||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -92,6 +97,7 @@ export function useExpandChat({
|
|||||||
ws.onopen = () => {
|
ws.onopen = () => {
|
||||||
setConnectionStatus('connected')
|
setConnectionStatus('connected')
|
||||||
reconnectAttempts.current = 0
|
reconnectAttempts.current = 0
|
||||||
|
manuallyDisconnectedRef.current = false
|
||||||
|
|
||||||
// Start ping interval to keep connection alive
|
// Start ping interval to keep connection alive
|
||||||
pingIntervalRef.current = window.setInterval(() => {
|
pingIntervalRef.current = window.setInterval(() => {
|
||||||
@@ -109,7 +115,11 @@ export function useExpandChat({
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Attempt reconnection if not intentionally closed
|
// Attempt reconnection if not intentionally closed
|
||||||
if (reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current) {
|
if (
|
||||||
|
!manuallyDisconnectedRef.current &&
|
||||||
|
reconnectAttempts.current < maxReconnectAttempts &&
|
||||||
|
!isCompleteRef.current
|
||||||
|
) {
|
||||||
reconnectAttempts.current++
|
reconnectAttempts.current++
|
||||||
const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000)
|
const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000)
|
||||||
reconnectTimeoutRef.current = window.setTimeout(connect, delay)
|
reconnectTimeoutRef.current = window.setTimeout(connect, delay)
|
||||||
@@ -244,18 +254,25 @@ export function useExpandChat({
|
|||||||
const start = useCallback(() => {
|
const start = useCallback(() => {
|
||||||
connect()
|
connect()
|
||||||
|
|
||||||
// Wait for connection then send start message
|
// Wait for connection then send start message (with timeout to prevent infinite loop)
|
||||||
|
let attempts = 0
|
||||||
|
const maxAttempts = 50 // 5 seconds max (50 * 100ms)
|
||||||
const checkAndSend = () => {
|
const checkAndSend = () => {
|
||||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||||
setIsLoading(true)
|
setIsLoading(true)
|
||||||
wsRef.current.send(JSON.stringify({ type: 'start' }))
|
wsRef.current.send(JSON.stringify({ type: 'start' }))
|
||||||
} else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
|
} else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
|
||||||
setTimeout(checkAndSend, 100)
|
if (attempts++ < maxAttempts) {
|
||||||
|
setTimeout(checkAndSend, 100)
|
||||||
|
} else {
|
||||||
|
onError?.('Connection timeout')
|
||||||
|
setIsLoading(false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
setTimeout(checkAndSend, 100)
|
setTimeout(checkAndSend, 100)
|
||||||
}, [connect])
|
}, [connect, onError])
|
||||||
|
|
||||||
const sendMessage = useCallback((content: string, attachments?: ImageAttachment[]) => {
|
const sendMessage = useCallback((content: string, attachments?: ImageAttachment[]) => {
|
||||||
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
|
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
|
||||||
@@ -297,6 +314,7 @@ export function useExpandChat({
|
|||||||
}, [onError])
|
}, [onError])
|
||||||
|
|
||||||
const disconnect = useCallback(() => {
|
const disconnect = useCallback(() => {
|
||||||
|
manuallyDisconnectedRef.current = true
|
||||||
reconnectAttempts.current = maxReconnectAttempts // Prevent reconnection
|
reconnectAttempts.current = maxReconnectAttempts // Prevent reconnection
|
||||||
if (pingIntervalRef.current) {
|
if (pingIntervalRef.current) {
|
||||||
clearInterval(pingIntervalRef.current)
|
clearInterval(pingIntervalRef.current)
|
||||||
|
|||||||
Reference in New Issue
Block a user