mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-01-30 14:22:04 +00:00
fix: address second round of code review feedback
Backend improvements: - Create shared validation utility for project name validation - Add asyncio.Lock to prevent concurrent _query_claude calls - Fix _create_features_bulk: use flush() for IDs, add rollback on error - Use unique temp settings file instead of overwriting .claude_settings.json - Remove exception details from error messages (security) Frontend improvements: - Memoize onError callback in ExpandProjectChat for stable dependencies - Add timeout to start() checkAndSend loop to prevent infinite retries - Add manuallyDisconnectedRef to prevent reconnection after explicit disconnect - Clear pending reconnect timeout in disconnect() Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -8,7 +8,6 @@ Allows adding multiple features to existing projects via natural language.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -23,6 +22,7 @@ from ..services.expand_chat_session import (
|
||||
list_expand_sessions,
|
||||
remove_expand_session,
|
||||
)
|
||||
from ..utils.validation import validate_project_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,9 +43,6 @@ def _get_project_path(project_name: str) -> Path:
|
||||
return get_project_path(project_name)
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> bool:
|
||||
"""Validate project name to prevent path traversal."""
|
||||
return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -70,8 +67,7 @@ async def list_expand_sessions_endpoint():
|
||||
@router.get("/sessions/{project_name}", response_model=ExpandSessionStatus)
|
||||
async def get_expand_session_status(project_name: str):
|
||||
"""Get status of an expansion session."""
|
||||
if not validate_project_name(project_name):
|
||||
raise HTTPException(status_code=400, detail="Invalid project name")
|
||||
project_name = validate_project_name(project_name)
|
||||
|
||||
session = get_expand_session(project_name)
|
||||
if not session:
|
||||
@@ -89,8 +85,7 @@ async def get_expand_session_status(project_name: str):
|
||||
@router.delete("/sessions/{project_name}")
|
||||
async def cancel_expand_session(project_name: str):
|
||||
"""Cancel and remove an expansion session."""
|
||||
if not validate_project_name(project_name):
|
||||
raise HTTPException(status_code=400, detail="Invalid project name")
|
||||
project_name = validate_project_name(project_name)
|
||||
|
||||
session = get_expand_session(project_name)
|
||||
if not session:
|
||||
@@ -124,7 +119,9 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str):
|
||||
- {"type": "error", "content": "..."} - Error message
|
||||
- {"type": "pong"} - Keep-alive pong
|
||||
"""
|
||||
if not validate_project_name(project_name):
|
||||
try:
|
||||
project_name = validate_project_name(project_name)
|
||||
except HTTPException:
|
||||
await websocket.close(code=4000, reason="Invalid project name")
|
||||
return
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ API endpoints for feature/test case management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
@@ -19,6 +18,7 @@ from ..schemas import (
|
||||
FeatureListResponse,
|
||||
FeatureResponse,
|
||||
)
|
||||
from ..utils.validation import validate_project_name
|
||||
|
||||
# Lazy imports to avoid circular dependencies
|
||||
_create_database = None
|
||||
@@ -56,16 +56,6 @@ def _get_db_classes():
|
||||
router = APIRouter(prefix="/api/projects/{project_name}/features", tags=["features"])
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> str:
|
||||
"""Validate and sanitize project name to prevent path traversal."""
|
||||
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid project name"
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session(project_dir: Path):
|
||||
"""
|
||||
|
||||
@@ -6,11 +6,13 @@ Manages interactive project expansion conversation with Claude.
|
||||
Uses the expand-project.md skill to help users add features to existing projects.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Optional
|
||||
@@ -68,6 +70,7 @@ class ExpandChatSession:
|
||||
self.features_created: int = 0
|
||||
self.created_feature_ids: list[int] = []
|
||||
self._settings_file: Optional[Path] = None
|
||||
self._query_lock = asyncio.Lock()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources and close the Claude client."""
|
||||
@@ -117,7 +120,16 @@ class ExpandChatSession:
|
||||
except UnicodeDecodeError:
|
||||
skill_content = skill_path.read_text(encoding="utf-8", errors="replace")
|
||||
|
||||
# Create security settings file
|
||||
# Find and validate Claude CLI before creating temp files
|
||||
system_cli = shutil.which("claude")
|
||||
if not system_cli:
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": "Claude CLI not found. Please install Claude Code."
|
||||
}
|
||||
return
|
||||
|
||||
# Create temporary security settings file (unique per session to avoid conflicts)
|
||||
security_settings = {
|
||||
"sandbox": {"enabled": True},
|
||||
"permissions": {
|
||||
@@ -128,23 +140,16 @@ class ExpandChatSession:
|
||||
],
|
||||
},
|
||||
}
|
||||
settings_file = self.project_dir / ".claude_settings.json"
|
||||
settings_file = self.project_dir / f".claude_settings.expand.{uuid.uuid4().hex}.json"
|
||||
self._settings_file = settings_file
|
||||
with open(settings_file, "w") as f:
|
||||
with open(settings_file, "w", encoding="utf-8") as f:
|
||||
json.dump(security_settings, f, indent=2)
|
||||
|
||||
# Replace $ARGUMENTS with absolute project path
|
||||
project_path = str(self.project_dir.resolve())
|
||||
system_prompt = skill_content.replace("$ARGUMENTS", project_path)
|
||||
|
||||
# Find and validate Claude CLI
|
||||
system_cli = shutil.which("claude")
|
||||
if not system_cli:
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": "Claude CLI not found. Please install Claude Code."
|
||||
}
|
||||
return
|
||||
# Create Claude SDK client
|
||||
try:
|
||||
self.client = ClaudeSDKClient(
|
||||
options=ClaudeAgentOptions(
|
||||
@@ -167,20 +172,21 @@ class ExpandChatSession:
|
||||
logger.exception("Failed to create Claude client")
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Failed to initialize Claude: {str(e)}"
|
||||
"content": "Failed to initialize Claude"
|
||||
}
|
||||
return
|
||||
|
||||
# Start the conversation
|
||||
try:
|
||||
async for chunk in self._query_claude("Begin the project expansion process."):
|
||||
yield chunk
|
||||
async with self._query_lock:
|
||||
async for chunk in self._query_claude("Begin the project expansion process."):
|
||||
yield chunk
|
||||
yield {"type": "response_done"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to start expand chat")
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Failed to start conversation: {str(e)}"
|
||||
"content": "Failed to start conversation"
|
||||
}
|
||||
|
||||
async def send_message(
|
||||
@@ -218,14 +224,16 @@ class ExpandChatSession:
|
||||
})
|
||||
|
||||
try:
|
||||
async for chunk in self._query_claude(user_message, attachments):
|
||||
yield chunk
|
||||
# Use lock to prevent concurrent queries from corrupting the response stream
|
||||
async with self._query_lock:
|
||||
async for chunk in self._query_claude(user_message, attachments):
|
||||
yield chunk
|
||||
yield {"type": "response_done"}
|
||||
except Exception as e:
|
||||
logger.exception("Error during Claude query")
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Error: {str(e)}"
|
||||
"content": "Error while processing message"
|
||||
}
|
||||
|
||||
async def _query_claude(
|
||||
@@ -340,6 +348,10 @@ class ExpandChatSession:
|
||||
|
||||
Returns:
|
||||
List of created feature dictionaries with IDs
|
||||
|
||||
Note:
|
||||
Uses flush() to get IDs immediately without re-querying by priority range,
|
||||
which could pick up rows from concurrent writers.
|
||||
"""
|
||||
# Import database classes
|
||||
import sys
|
||||
@@ -358,7 +370,7 @@ class ExpandChatSession:
|
||||
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
|
||||
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
|
||||
|
||||
created_features = []
|
||||
created_rows: list = []
|
||||
|
||||
for f in features:
|
||||
db_feature = Feature(
|
||||
@@ -370,24 +382,28 @@ class ExpandChatSession:
|
||||
passes=False,
|
||||
)
|
||||
session.add(db_feature)
|
||||
created_rows.append(db_feature)
|
||||
current_priority += 1
|
||||
|
||||
session.commit()
|
||||
# Flush to get IDs without relying on priority range query
|
||||
session.flush()
|
||||
|
||||
# Re-query to get the created features with IDs
|
||||
start_priority = current_priority - len(features)
|
||||
for db_feature in session.query(Feature).filter(
|
||||
Feature.priority >= start_priority,
|
||||
Feature.priority < current_priority
|
||||
).order_by(Feature.priority).all():
|
||||
created_features.append({
|
||||
# Build result from the flushed objects (IDs are now populated)
|
||||
created_features = [
|
||||
{
|
||||
"id": db_feature.id,
|
||||
"name": db_feature.name,
|
||||
"category": db_feature.category,
|
||||
})
|
||||
}
|
||||
for db_feature in created_rows
|
||||
]
|
||||
|
||||
session.commit()
|
||||
return created_features
|
||||
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
1
server/utils/__init__.py
Normal file
1
server/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Server utilities
|
||||
28
server/utils/validation.py
Normal file
28
server/utils/validation.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Shared validation utilities for the server.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> str:
|
||||
"""
|
||||
Validate and sanitize project name to prevent path traversal.
|
||||
|
||||
Args:
|
||||
name: Project name to validate
|
||||
|
||||
Returns:
|
||||
The validated project name
|
||||
|
||||
Raises:
|
||||
HTTPException: If name is invalid
|
||||
"""
|
||||
if not re.match(r'^[a-zA-Z0-9_-]{1,50}$', name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid project name. Use only letters, numbers, hyphens, and underscores (1-50 chars)."
|
||||
)
|
||||
return name
|
||||
@@ -34,6 +34,9 @@ export function ExpandProjectChat({
|
||||
const inputRef = useRef<HTMLInputElement>(null)
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
// Memoize error handler to keep hook dependencies stable
|
||||
const handleError = useCallback((err: string) => setError(err), [])
|
||||
|
||||
const {
|
||||
messages,
|
||||
isLoading,
|
||||
@@ -46,7 +49,7 @@ export function ExpandProjectChat({
|
||||
} = useExpandChat({
|
||||
projectName,
|
||||
onComplete,
|
||||
onError: (err) => setError(err),
|
||||
onError: handleError,
|
||||
})
|
||||
|
||||
// Start the chat session when component mounts
|
||||
|
||||
@@ -54,6 +54,7 @@ export function useExpandChat({
|
||||
const pingIntervalRef = useRef<number | null>(null)
|
||||
const reconnectTimeoutRef = useRef<number | null>(null)
|
||||
const isCompleteRef = useRef(false)
|
||||
const manuallyDisconnectedRef = useRef(false)
|
||||
|
||||
// Keep isCompleteRef in sync with isComplete state
|
||||
useEffect(() => {
|
||||
@@ -76,6 +77,10 @@ export function useExpandChat({
|
||||
}, [])
|
||||
|
||||
const connect = useCallback(() => {
|
||||
// Don't reconnect if manually disconnected
|
||||
if (manuallyDisconnectedRef.current) {
|
||||
return
|
||||
}
|
||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||
return
|
||||
}
|
||||
@@ -92,6 +97,7 @@ export function useExpandChat({
|
||||
ws.onopen = () => {
|
||||
setConnectionStatus('connected')
|
||||
reconnectAttempts.current = 0
|
||||
manuallyDisconnectedRef.current = false
|
||||
|
||||
// Start ping interval to keep connection alive
|
||||
pingIntervalRef.current = window.setInterval(() => {
|
||||
@@ -109,7 +115,11 @@ export function useExpandChat({
|
||||
}
|
||||
|
||||
// Attempt reconnection if not intentionally closed
|
||||
if (reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current) {
|
||||
if (
|
||||
!manuallyDisconnectedRef.current &&
|
||||
reconnectAttempts.current < maxReconnectAttempts &&
|
||||
!isCompleteRef.current
|
||||
) {
|
||||
reconnectAttempts.current++
|
||||
const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000)
|
||||
reconnectTimeoutRef.current = window.setTimeout(connect, delay)
|
||||
@@ -244,18 +254,25 @@ export function useExpandChat({
|
||||
const start = useCallback(() => {
|
||||
connect()
|
||||
|
||||
// Wait for connection then send start message
|
||||
// Wait for connection then send start message (with timeout to prevent infinite loop)
|
||||
let attempts = 0
|
||||
const maxAttempts = 50 // 5 seconds max (50 * 100ms)
|
||||
const checkAndSend = () => {
|
||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||
setIsLoading(true)
|
||||
wsRef.current.send(JSON.stringify({ type: 'start' }))
|
||||
} else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
|
||||
setTimeout(checkAndSend, 100)
|
||||
if (attempts++ < maxAttempts) {
|
||||
setTimeout(checkAndSend, 100)
|
||||
} else {
|
||||
onError?.('Connection timeout')
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setTimeout(checkAndSend, 100)
|
||||
}, [connect])
|
||||
}, [connect, onError])
|
||||
|
||||
const sendMessage = useCallback((content: string, attachments?: ImageAttachment[]) => {
|
||||
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
|
||||
@@ -297,6 +314,7 @@ export function useExpandChat({
|
||||
}, [onError])
|
||||
|
||||
const disconnect = useCallback(() => {
|
||||
manuallyDisconnectedRef.current = true
|
||||
reconnectAttempts.current = maxReconnectAttempts // Prevent reconnection
|
||||
if (pingIntervalRef.current) {
|
||||
clearInterval(pingIntervalRef.current)
|
||||
|
||||
Reference in New Issue
Block a user