mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-03-17 19:03:09 +00:00
add claude spec generation
This commit is contained in:
@@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .routers import projects_router, features_router, agent_router
|
||||
from .routers import projects_router, features_router, agent_router, spec_creation_router
|
||||
from .websocket import project_websocket
|
||||
from .services.process_manager import cleanup_all_managers
|
||||
from .schemas import SetupStatus
|
||||
@@ -81,6 +81,7 @@ async def require_localhost(request: Request, call_next):
|
||||
app.include_router(projects_router)
|
||||
app.include_router(features_router)
|
||||
app.include_router(agent_router)
|
||||
app.include_router(spec_creation_router)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -8,5 +8,6 @@ FastAPI routers for different API endpoints.
|
||||
from .projects import router as projects_router
|
||||
from .features import router as features_router
|
||||
from .agent import router as agent_router
|
||||
from .spec_creation import router as spec_creation_router
|
||||
|
||||
__all__ = ["projects_router", "features_router", "agent_router"]
|
||||
__all__ = ["projects_router", "features_router", "agent_router", "spec_creation_router"]
|
||||
|
||||
233
server/routers/spec_creation.py
Normal file
233
server/routers/spec_creation.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Spec Creation Router
|
||||
====================
|
||||
|
||||
WebSocket and REST endpoints for interactive spec creation with Claude.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..services.spec_chat_session import (
|
||||
SpecChatSession,
|
||||
get_session,
|
||||
create_session,
|
||||
remove_session,
|
||||
list_sessions,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/spec", tags=["spec-creation"])
|
||||
|
||||
# Root directory
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
def validate_project_name(name: str) -> bool:
|
||||
"""Validate project name to prevent path traversal."""
|
||||
return bool(re.match(r'^[a-zA-Z0-9_-]{1,50}$', name))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# REST Endpoints
|
||||
# ============================================================================
|
||||
|
||||
class SpecSessionStatus(BaseModel):
|
||||
"""Status of a spec creation session."""
|
||||
project_name: str
|
||||
is_active: bool
|
||||
is_complete: bool
|
||||
message_count: int
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[str])
|
||||
async def list_spec_sessions():
|
||||
"""List all active spec creation sessions."""
|
||||
return list_sessions()
|
||||
|
||||
|
||||
@router.get("/sessions/{project_name}", response_model=SpecSessionStatus)
|
||||
async def get_session_status(project_name: str):
|
||||
"""Get status of a spec creation session."""
|
||||
if not validate_project_name(project_name):
|
||||
raise HTTPException(status_code=400, detail="Invalid project name")
|
||||
|
||||
session = get_session(project_name)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="No active session for this project")
|
||||
|
||||
return SpecSessionStatus(
|
||||
project_name=project_name,
|
||||
is_active=True,
|
||||
is_complete=session.is_complete(),
|
||||
message_count=len(session.get_messages()),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/sessions/{project_name}")
|
||||
async def cancel_session(project_name: str):
|
||||
"""Cancel and remove a spec creation session."""
|
||||
if not validate_project_name(project_name):
|
||||
raise HTTPException(status_code=400, detail="Invalid project name")
|
||||
|
||||
session = get_session(project_name)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="No active session for this project")
|
||||
|
||||
await remove_session(project_name)
|
||||
return {"success": True, "message": "Session cancelled"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WebSocket Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@router.websocket("/ws/{project_name}")
|
||||
async def spec_chat_websocket(websocket: WebSocket, project_name: str):
|
||||
"""
|
||||
WebSocket endpoint for interactive spec creation chat.
|
||||
|
||||
Message protocol:
|
||||
|
||||
Client -> Server:
|
||||
- {"type": "start"} - Start the spec creation session
|
||||
- {"type": "message", "content": "..."} - Send user message
|
||||
- {"type": "answer", "answers": {...}, "tool_id": "..."} - Answer structured question
|
||||
- {"type": "ping"} - Keep-alive ping
|
||||
|
||||
Server -> Client:
|
||||
- {"type": "text", "content": "..."} - Text chunk from Claude
|
||||
- {"type": "question", "questions": [...], "tool_id": "..."} - Structured question
|
||||
- {"type": "spec_complete", "path": "..."} - Spec file created
|
||||
- {"type": "file_written", "path": "..."} - Other file written
|
||||
- {"type": "complete"} - Session complete
|
||||
- {"type": "error", "content": "..."} - Error message
|
||||
- {"type": "pong"} - Keep-alive pong
|
||||
"""
|
||||
if not validate_project_name(project_name):
|
||||
await websocket.close(code=4000, reason="Invalid project name")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
session: Optional[SpecChatSession] = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
msg_type = message.get("type")
|
||||
|
||||
if msg_type == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
continue
|
||||
|
||||
elif msg_type == "start":
|
||||
# Create and start a new session
|
||||
session = await create_session(project_name)
|
||||
|
||||
# Stream the initial greeting
|
||||
async for chunk in session.start():
|
||||
await websocket.send_json(chunk)
|
||||
|
||||
# Check for completion
|
||||
if chunk.get("type") == "spec_complete":
|
||||
await websocket.send_json({"type": "complete"})
|
||||
|
||||
elif msg_type == "message":
|
||||
# User sent a message
|
||||
if not session:
|
||||
session = get_session(project_name)
|
||||
if not session:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"content": "No active session. Send 'start' first."
|
||||
})
|
||||
continue
|
||||
|
||||
user_content = message.get("content", "").strip()
|
||||
if not user_content:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"content": "Empty message"
|
||||
})
|
||||
continue
|
||||
|
||||
# Stream Claude's response
|
||||
async for chunk in session.send_message(user_content):
|
||||
await websocket.send_json(chunk)
|
||||
|
||||
# Check for completion
|
||||
if chunk.get("type") == "spec_complete":
|
||||
await websocket.send_json({"type": "complete"})
|
||||
|
||||
elif msg_type == "answer":
|
||||
# User answered a structured question
|
||||
if not session:
|
||||
session = get_session(project_name)
|
||||
if not session:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"content": "No active session"
|
||||
})
|
||||
continue
|
||||
|
||||
# Format the answers as a natural response
|
||||
answers = message.get("answers", {})
|
||||
if isinstance(answers, dict):
|
||||
# Convert structured answers to a message
|
||||
response_parts = []
|
||||
for question_idx, answer_value in answers.items():
|
||||
if isinstance(answer_value, list):
|
||||
response_parts.append(", ".join(answer_value))
|
||||
else:
|
||||
response_parts.append(str(answer_value))
|
||||
user_response = "; ".join(response_parts) if response_parts else "OK"
|
||||
else:
|
||||
user_response = str(answers)
|
||||
|
||||
# Stream Claude's response
|
||||
async for chunk in session.send_message(user_response):
|
||||
await websocket.send_json(chunk)
|
||||
|
||||
if chunk.get("type") == "spec_complete":
|
||||
await websocket.send_json({"type": "complete"})
|
||||
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"content": f"Unknown message type: {msg_type}"
|
||||
})
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"content": "Invalid JSON"
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"Spec chat WebSocket disconnected for {project_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Spec chat WebSocket error for {project_name}")
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"content": f"Server error: {str(e)}"
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Don't remove the session on disconnect - allow resume
|
||||
pass
|
||||
322
server/services/spec_chat_session.py
Normal file
322
server/services/spec_chat_session.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Spec Creation Chat Session
|
||||
==========================
|
||||
|
||||
Manages interactive spec creation conversation with Claude.
|
||||
Uses the create-spec.md skill to guide users through app spec creation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Root directory of the project
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class SpecChatSession:
|
||||
"""
|
||||
Manages a spec creation conversation for one project.
|
||||
|
||||
Uses the create-spec skill to guide users through:
|
||||
- Phase 1: Project Overview (name, description, audience)
|
||||
- Phase 2: Involvement Level (Quick vs Detailed mode)
|
||||
- Phase 3: Technology Preferences
|
||||
- Phase 4: Features (main exploration phase)
|
||||
- Phase 5: Technical Details (derived or discussed)
|
||||
- Phase 6-7: Success Criteria & Approval
|
||||
"""
|
||||
|
||||
def __init__(self, project_name: str):
|
||||
"""
|
||||
Initialize the session.
|
||||
|
||||
Args:
|
||||
project_name: Name of the project being created
|
||||
"""
|
||||
self.project_name = project_name
|
||||
self.project_dir = ROOT_DIR / "generations" / project_name
|
||||
self.client: Optional[ClaudeSDKClient] = None
|
||||
self.messages: list[dict] = []
|
||||
self.complete: bool = False
|
||||
self.created_at = datetime.now()
|
||||
self._conversation_id: Optional[str] = None
|
||||
self._client_entered: bool = False # Track if context manager is active
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources and close the Claude client."""
|
||||
if self.client and self._client_entered:
|
||||
try:
|
||||
await self.client.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing Claude client: {e}")
|
||||
finally:
|
||||
self._client_entered = False
|
||||
self.client = None
|
||||
|
||||
async def start(self) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Initialize session and get initial greeting from Claude.
|
||||
|
||||
Yields message chunks as they stream in.
|
||||
"""
|
||||
# Load the create-spec skill
|
||||
skill_path = ROOT_DIR / ".claude" / "commands" / "create-spec.md"
|
||||
|
||||
if not skill_path.exists():
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Spec creation skill not found at {skill_path}"
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
skill_content = skill_path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
skill_content = skill_path.read_text(encoding="utf-8", errors="replace")
|
||||
|
||||
# Replace $ARGUMENTS with the project path (use forward slashes for consistency)
|
||||
project_path = f"generations/{self.project_name}"
|
||||
system_prompt = skill_content.replace("$ARGUMENTS", project_path)
|
||||
|
||||
# Create Claude SDK client with limited tools for spec creation
|
||||
try:
|
||||
self.client = ClaudeSDKClient(
|
||||
options=ClaudeAgentOptions(
|
||||
model="claude-sonnet-4-20250514",
|
||||
system_prompt=system_prompt,
|
||||
allowed_tools=[
|
||||
"Read",
|
||||
"Write",
|
||||
"AskUserQuestion",
|
||||
],
|
||||
max_turns=100,
|
||||
cwd=str(ROOT_DIR.resolve()),
|
||||
)
|
||||
)
|
||||
# Enter the async context and track it
|
||||
await self.client.__aenter__()
|
||||
self._client_entered = True
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create Claude client")
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Failed to initialize Claude: {str(e)}"
|
||||
}
|
||||
return
|
||||
|
||||
# Start the conversation - Claude will send the Phase 1 greeting
|
||||
try:
|
||||
async for chunk in self._query_claude("Begin the spec creation process."):
|
||||
yield chunk
|
||||
# Signal that the response is complete (for UI to hide loading indicator)
|
||||
yield {"type": "response_done"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to start spec chat")
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Failed to start conversation: {str(e)}"
|
||||
}
|
||||
|
||||
async def send_message(self, user_message: str) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Send user message and stream Claude's response.
|
||||
|
||||
Args:
|
||||
user_message: The user's response
|
||||
|
||||
Yields:
|
||||
Message chunks of various types:
|
||||
- {"type": "text", "content": str}
|
||||
- {"type": "question", "questions": list}
|
||||
- {"type": "spec_complete", "path": str}
|
||||
- {"type": "error", "content": str}
|
||||
"""
|
||||
if not self.client:
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": "Session not initialized. Call start() first."
|
||||
}
|
||||
return
|
||||
|
||||
# Store the user message
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": user_message,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
try:
|
||||
async for chunk in self._query_claude(user_message):
|
||||
yield chunk
|
||||
# Signal that the response is complete (for UI to hide loading indicator)
|
||||
yield {"type": "response_done"}
|
||||
except Exception as e:
|
||||
logger.exception("Error during Claude query")
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Error: {str(e)}"
|
||||
}
|
||||
|
||||
async def _query_claude(self, message: str) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Internal method to query Claude and stream responses.
|
||||
|
||||
Handles tool calls (AskUserQuestion, Write) and text responses.
|
||||
"""
|
||||
if not self.client:
|
||||
return
|
||||
|
||||
# Send the message to Claude using the SDK's query method
|
||||
await self.client.query(message)
|
||||
|
||||
current_text = ""
|
||||
|
||||
# Stream the response using receive_response
|
||||
async for msg in self.client.receive_response():
|
||||
msg_type = type(msg).__name__
|
||||
|
||||
if msg_type == "AssistantMessage" and hasattr(msg, "content"):
|
||||
# Process content blocks in the assistant message
|
||||
for block in msg.content:
|
||||
block_type = type(block).__name__
|
||||
|
||||
if block_type == "TextBlock" and hasattr(block, "text"):
|
||||
# Accumulate text and yield it
|
||||
text = block.text
|
||||
if text:
|
||||
current_text += text
|
||||
yield {"type": "text", "content": text}
|
||||
|
||||
# Store in message history
|
||||
self.messages.append({
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
elif block_type == "ToolUseBlock" and hasattr(block, "name"):
|
||||
tool_name = block.name
|
||||
tool_input = getattr(block, "input", {})
|
||||
tool_id = getattr(block, "id", "")
|
||||
|
||||
if tool_name == "AskUserQuestion":
|
||||
# Convert AskUserQuestion to structured UI
|
||||
questions = tool_input.get("questions", [])
|
||||
yield {
|
||||
"type": "question",
|
||||
"questions": questions,
|
||||
"tool_id": tool_id
|
||||
}
|
||||
# The SDK handles tool results internally
|
||||
|
||||
elif tool_name == "Write":
|
||||
# File being written - the SDK handles this
|
||||
file_path = tool_input.get("file_path", "")
|
||||
|
||||
# Check if this is the app_spec.txt file
|
||||
if "app_spec.txt" in str(file_path):
|
||||
self.complete = True
|
||||
yield {
|
||||
"type": "spec_complete",
|
||||
"path": str(file_path)
|
||||
}
|
||||
|
||||
elif "initializer_prompt.md" in str(file_path):
|
||||
yield {
|
||||
"type": "file_written",
|
||||
"path": str(file_path)
|
||||
}
|
||||
|
||||
elif msg_type == "UserMessage" and hasattr(msg, "content"):
|
||||
# Tool results - the SDK handles these automatically
|
||||
# We just watch for any errors
|
||||
for block in msg.content:
|
||||
block_type = type(block).__name__
|
||||
if block_type == "ToolResultBlock":
|
||||
is_error = getattr(block, "is_error", False)
|
||||
if is_error:
|
||||
content = getattr(block, "content", "Unknown error")
|
||||
logger.warning(f"Tool error: {content}")
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if spec creation is complete."""
|
||||
return self.complete
|
||||
|
||||
def get_messages(self) -> list[dict]:
|
||||
"""Get all messages in the conversation."""
|
||||
return self.messages.copy()
|
||||
|
||||
|
||||
# Session registry with thread safety
|
||||
_sessions: dict[str, SpecChatSession] = {}
|
||||
_sessions_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_session(project_name: str) -> Optional[SpecChatSession]:
|
||||
"""Get an existing session for a project."""
|
||||
with _sessions_lock:
|
||||
return _sessions.get(project_name)
|
||||
|
||||
|
||||
async def create_session(project_name: str) -> SpecChatSession:
|
||||
"""Create a new session for a project, closing any existing one."""
|
||||
old_session: Optional[SpecChatSession] = None
|
||||
|
||||
with _sessions_lock:
|
||||
# Get existing session to close later (outside the lock)
|
||||
old_session = _sessions.pop(project_name, None)
|
||||
session = SpecChatSession(project_name)
|
||||
_sessions[project_name] = session
|
||||
|
||||
# Close old session outside the lock to avoid blocking
|
||||
if old_session:
|
||||
try:
|
||||
await old_session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing old session for {project_name}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def remove_session(project_name: str) -> None:
|
||||
"""Remove and close a session."""
|
||||
session: Optional[SpecChatSession] = None
|
||||
|
||||
with _sessions_lock:
|
||||
session = _sessions.pop(project_name, None)
|
||||
|
||||
# Close session outside the lock
|
||||
if session:
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing session for {project_name}: {e}")
|
||||
|
||||
|
||||
def list_sessions() -> list[str]:
|
||||
"""List all active session project names."""
|
||||
with _sessions_lock:
|
||||
return list(_sessions.keys())
|
||||
|
||||
|
||||
async def cleanup_all_sessions() -> None:
|
||||
"""Close all active sessions. Called on server shutdown."""
|
||||
sessions_to_close: list[SpecChatSession] = []
|
||||
|
||||
with _sessions_lock:
|
||||
sessions_to_close = list(_sessions.values())
|
||||
_sessions.clear()
|
||||
|
||||
for session in sessions_to_close:
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing session {session.project_name}: {e}")
|
||||
Reference in New Issue
Block a user