mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-01-30 06:12:06 +00:00
Merge pull request #124 from rudiheydra/fix/expand-session-mcp-and-scheduling
fix: expand session MCP wiring + scheduling infinite loop
This commit is contained in:
@@ -300,15 +300,20 @@ def compute_scheduling_scores(features: list[dict]) -> dict[int, float]:
|
|||||||
parents[f["id"]].append(dep_id)
|
parents[f["id"]].append(dep_id)
|
||||||
|
|
||||||
# Calculate depths via BFS from roots
|
# Calculate depths via BFS from roots
|
||||||
|
# Use visited set to prevent infinite loops from circular dependencies
|
||||||
depths: dict[int, int] = {}
|
depths: dict[int, int] = {}
|
||||||
|
visited: set[int] = set()
|
||||||
roots = [f["id"] for f in features if not parents[f["id"]]]
|
roots = [f["id"] for f in features if not parents[f["id"]]]
|
||||||
queue = [(root, 0) for root in roots]
|
queue = [(root, 0) for root in roots]
|
||||||
while queue:
|
while queue:
|
||||||
node_id, depth = queue.pop(0)
|
node_id, depth = queue.pop(0)
|
||||||
if node_id not in depths or depth > depths[node_id]:
|
if node_id in visited:
|
||||||
depths[node_id] = depth
|
continue # Skip already visited nodes (handles cycles)
|
||||||
|
visited.add(node_id)
|
||||||
|
depths[node_id] = depth
|
||||||
for child_id in children[node_id]:
|
for child_id in children[node_id]:
|
||||||
queue.append((child_id, depth + 1))
|
if child_id not in visited:
|
||||||
|
queue.append((child_id, depth + 1))
|
||||||
|
|
||||||
# Handle orphaned nodes (shouldn't happen but be safe)
|
# Handle orphaned nodes (shouldn't happen but be safe)
|
||||||
for f in features:
|
for f in features:
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -38,6 +38,13 @@ API_ENV_VARS = [
|
|||||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL",
|
"ANTHROPIC_DEFAULT_HAIKU_MODEL",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Feature MCP tools needed for expand session
|
||||||
|
EXPAND_FEATURE_TOOLS = [
|
||||||
|
"mcp__features__feature_create",
|
||||||
|
"mcp__features__feature_create_bulk",
|
||||||
|
"mcp__features__feature_get_stats",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def _make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]:
|
async def _make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]:
|
||||||
"""
|
"""
|
||||||
@@ -61,9 +68,8 @@ class ExpandChatSession:
|
|||||||
|
|
||||||
Unlike SpecChatSession which writes spec files, this session:
|
Unlike SpecChatSession which writes spec files, this session:
|
||||||
1. Reads existing app_spec.txt for context
|
1. Reads existing app_spec.txt for context
|
||||||
2. Parses feature definitions from Claude's output
|
2. Chats with the user to define new features
|
||||||
3. Creates features via REST API
|
3. Claude creates features via the feature_create_bulk MCP tool
|
||||||
4. Tracks which features were created during the session
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, project_name: str, project_dir: Path):
|
def __init__(self, project_name: str, project_dir: Path):
|
||||||
@@ -171,6 +177,18 @@ class ExpandChatSession:
|
|||||||
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
|
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
|
||||||
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
||||||
|
|
||||||
|
# Build MCP servers config for feature creation
|
||||||
|
mcp_servers = {
|
||||||
|
"features": {
|
||||||
|
"command": sys.executable,
|
||||||
|
"args": ["-m", "mcp_server.feature_mcp"],
|
||||||
|
"env": {
|
||||||
|
"PROJECT_DIR": str(self.project_dir.resolve()),
|
||||||
|
"PYTHONPATH": str(ROOT_DIR.resolve()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# Create Claude SDK client
|
# Create Claude SDK client
|
||||||
try:
|
try:
|
||||||
self.client = ClaudeSDKClient(
|
self.client = ClaudeSDKClient(
|
||||||
@@ -181,8 +199,10 @@ class ExpandChatSession:
|
|||||||
allowed_tools=[
|
allowed_tools=[
|
||||||
"Read",
|
"Read",
|
||||||
"Glob",
|
"Glob",
|
||||||
|
*EXPAND_FEATURE_TOOLS,
|
||||||
],
|
],
|
||||||
permission_mode="acceptEdits",
|
mcp_servers=mcp_servers,
|
||||||
|
permission_mode="bypassPermissions",
|
||||||
max_turns=100,
|
max_turns=100,
|
||||||
cwd=str(self.project_dir.resolve()),
|
cwd=str(self.project_dir.resolve()),
|
||||||
settings=str(settings_file.resolve()),
|
settings=str(settings_file.resolve()),
|
||||||
@@ -267,7 +287,8 @@ class ExpandChatSession:
|
|||||||
"""
|
"""
|
||||||
Internal method to query Claude and stream responses.
|
Internal method to query Claude and stream responses.
|
||||||
|
|
||||||
Handles text responses and detects feature creation blocks.
|
Feature creation is handled by Claude calling the feature_create_bulk
|
||||||
|
MCP tool directly -- no text parsing needed.
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
if not self.client:
|
||||||
return
|
return
|
||||||
@@ -291,9 +312,6 @@ class ExpandChatSession:
|
|||||||
else:
|
else:
|
||||||
await self.client.query(message)
|
await self.client.query(message)
|
||||||
|
|
||||||
# Accumulate full response to detect feature blocks
|
|
||||||
full_response = ""
|
|
||||||
|
|
||||||
# Stream the response
|
# Stream the response
|
||||||
async for msg in self.client.receive_response():
|
async for msg in self.client.receive_response():
|
||||||
msg_type = type(msg).__name__
|
msg_type = type(msg).__name__
|
||||||
@@ -305,7 +323,6 @@ class ExpandChatSession:
|
|||||||
if block_type == "TextBlock" and hasattr(block, "text"):
|
if block_type == "TextBlock" and hasattr(block, "text"):
|
||||||
text = block.text
|
text = block.text
|
||||||
if text:
|
if text:
|
||||||
full_response += text
|
|
||||||
yield {"type": "text", "content": text}
|
yield {"type": "text", "content": text}
|
||||||
|
|
||||||
self.messages.append({
|
self.messages.append({
|
||||||
@@ -314,123 +331,6 @@ class ExpandChatSession:
|
|||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
})
|
})
|
||||||
|
|
||||||
# Check for feature creation blocks in full response (handle multiple blocks)
|
|
||||||
features_matches = re.findall(
|
|
||||||
r'<features_to_create>\s*(\[[\s\S]*?\])\s*</features_to_create>',
|
|
||||||
full_response
|
|
||||||
)
|
|
||||||
|
|
||||||
if features_matches:
|
|
||||||
# Collect all features from all blocks, deduplicating by name
|
|
||||||
all_features: list[dict] = []
|
|
||||||
seen_names: set[str] = set()
|
|
||||||
|
|
||||||
for features_json in features_matches:
|
|
||||||
try:
|
|
||||||
features_data = json.loads(features_json)
|
|
||||||
|
|
||||||
if features_data and isinstance(features_data, list):
|
|
||||||
for feature in features_data:
|
|
||||||
name = feature.get("name", "")
|
|
||||||
if name and name not in seen_names:
|
|
||||||
seen_names.add(name)
|
|
||||||
all_features.append(feature)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"Failed to parse features JSON block: {e}")
|
|
||||||
# Continue processing other blocks
|
|
||||||
|
|
||||||
if all_features:
|
|
||||||
try:
|
|
||||||
# Create all deduplicated features
|
|
||||||
created = await self._create_features_bulk(all_features)
|
|
||||||
|
|
||||||
if created:
|
|
||||||
self.features_created += len(created)
|
|
||||||
self.created_feature_ids.extend([f["id"] for f in created])
|
|
||||||
|
|
||||||
yield {
|
|
||||||
"type": "features_created",
|
|
||||||
"count": len(created),
|
|
||||||
"features": created
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Created {len(created)} features for {self.project_name}")
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to create features")
|
|
||||||
yield {
|
|
||||||
"type": "error",
|
|
||||||
"content": "Failed to create features"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _create_features_bulk(self, features: list[dict]) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Create features directly in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features: List of feature dictionaries with category, name, description, steps
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of created feature dictionaries with IDs
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Uses flush() to get IDs immediately without re-querying by priority range,
|
|
||||||
which could pick up rows from concurrent writers.
|
|
||||||
"""
|
|
||||||
# Import database classes
|
|
||||||
import sys
|
|
||||||
root = Path(__file__).parent.parent.parent
|
|
||||||
if str(root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(root))
|
|
||||||
|
|
||||||
from api.database import Feature, create_database
|
|
||||||
|
|
||||||
# Get database session
|
|
||||||
_, SessionLocal = create_database(self.project_dir)
|
|
||||||
session = SessionLocal()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Determine starting priority
|
|
||||||
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
|
|
||||||
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
|
|
||||||
|
|
||||||
created_rows: list = []
|
|
||||||
|
|
||||||
for f in features:
|
|
||||||
db_feature = Feature(
|
|
||||||
priority=current_priority,
|
|
||||||
category=f.get("category", "functional"),
|
|
||||||
name=f.get("name", "Unnamed feature"),
|
|
||||||
description=f.get("description", ""),
|
|
||||||
steps=f.get("steps", []),
|
|
||||||
passes=False,
|
|
||||||
in_progress=False,
|
|
||||||
)
|
|
||||||
session.add(db_feature)
|
|
||||||
created_rows.append(db_feature)
|
|
||||||
current_priority += 1
|
|
||||||
|
|
||||||
# Flush to get IDs without relying on priority range query
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Build result from the flushed objects (IDs are now populated)
|
|
||||||
created_features = [
|
|
||||||
{
|
|
||||||
"id": db_feature.id,
|
|
||||||
"name": db_feature.name,
|
|
||||||
"category": db_feature.category,
|
|
||||||
}
|
|
||||||
for db_feature in created_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return created_features
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
session.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
def get_features_created(self) -> int:
|
def get_features_created(self) -> int:
|
||||||
"""Get the total number of features created in this session."""
|
"""Get the total number of features created in this session."""
|
||||||
return self.features_created
|
return self.features_created
|
||||||
|
|||||||
Reference in New Issue
Block a user