Merge pull request #124 from rudiheydra/fix/expand-session-mcp-and-scheduling

fix: expand session MCP wiring + scheduling infinite loop
This commit is contained in:
Leon van Zyl
2026-01-29 08:00:27 +02:00
committed by GitHub
2 changed files with 35 additions and 130 deletions

View File

@@ -300,15 +300,20 @@ def compute_scheduling_scores(features: list[dict]) -> dict[int, float]:
parents[f["id"]].append(dep_id) parents[f["id"]].append(dep_id)
# Calculate depths via BFS from roots # Calculate depths via BFS from roots
# Use visited set to prevent infinite loops from circular dependencies
depths: dict[int, int] = {} depths: dict[int, int] = {}
visited: set[int] = set()
roots = [f["id"] for f in features if not parents[f["id"]]] roots = [f["id"] for f in features if not parents[f["id"]]]
queue = [(root, 0) for root in roots] queue = [(root, 0) for root in roots]
while queue: while queue:
node_id, depth = queue.pop(0) node_id, depth = queue.pop(0)
if node_id not in depths or depth > depths[node_id]: if node_id in visited:
depths[node_id] = depth continue # Skip already visited nodes (handles cycles)
visited.add(node_id)
depths[node_id] = depth
for child_id in children[node_id]: for child_id in children[node_id]:
queue.append((child_id, depth + 1)) if child_id not in visited:
queue.append((child_id, depth + 1))
# Handle orphaned nodes (shouldn't happen but be safe) # Handle orphaned nodes (shouldn't happen but be safe)
for f in features: for f in features:

View File

@@ -10,8 +10,8 @@ import asyncio
import json import json
import logging import logging
import os import os
import re
import shutil import shutil
import sys
import threading import threading
import uuid import uuid
from datetime import datetime from datetime import datetime
@@ -38,6 +38,13 @@ API_ENV_VARS = [
"ANTHROPIC_DEFAULT_HAIKU_MODEL", "ANTHROPIC_DEFAULT_HAIKU_MODEL",
] ]
# MCP tool names the expand session must be allowed to invoke so Claude can
# create features directly instead of emitting text blocks for us to parse.
EXPAND_FEATURE_TOOLS = [
    f"mcp__features__{tool}"
    for tool in ("feature_create", "feature_create_bulk", "feature_get_stats")
]
async def _make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]: async def _make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]:
""" """
@@ -61,9 +68,8 @@ class ExpandChatSession:
Unlike SpecChatSession which writes spec files, this session: Unlike SpecChatSession which writes spec files, this session:
1. Reads existing app_spec.txt for context 1. Reads existing app_spec.txt for context
2. Parses feature definitions from Claude's output 2. Chats with the user to define new features
3. Creates features via REST API 3. Claude creates features via the feature_create_bulk MCP tool
4. Tracks which features were created during the session
""" """
def __init__(self, project_name: str, project_dir: Path): def __init__(self, project_name: str, project_dir: Path):
@@ -171,6 +177,18 @@ class ExpandChatSession:
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names # This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101") model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
# Build MCP servers config for feature creation
mcp_servers = {
"features": {
"command": sys.executable,
"args": ["-m", "mcp_server.feature_mcp"],
"env": {
"PROJECT_DIR": str(self.project_dir.resolve()),
"PYTHONPATH": str(ROOT_DIR.resolve()),
},
},
}
# Create Claude SDK client # Create Claude SDK client
try: try:
self.client = ClaudeSDKClient( self.client = ClaudeSDKClient(
@@ -181,8 +199,10 @@ class ExpandChatSession:
allowed_tools=[ allowed_tools=[
"Read", "Read",
"Glob", "Glob",
*EXPAND_FEATURE_TOOLS,
], ],
permission_mode="acceptEdits", mcp_servers=mcp_servers,
permission_mode="bypassPermissions",
max_turns=100, max_turns=100,
cwd=str(self.project_dir.resolve()), cwd=str(self.project_dir.resolve()),
settings=str(settings_file.resolve()), settings=str(settings_file.resolve()),
@@ -267,7 +287,8 @@ class ExpandChatSession:
""" """
Internal method to query Claude and stream responses. Internal method to query Claude and stream responses.
Handles text responses and detects feature creation blocks. Feature creation is handled by Claude calling the feature_create_bulk
MCP tool directly -- no text parsing needed.
""" """
if not self.client: if not self.client:
return return
@@ -291,9 +312,6 @@ class ExpandChatSession:
else: else:
await self.client.query(message) await self.client.query(message)
# Accumulate full response to detect feature blocks
full_response = ""
# Stream the response # Stream the response
async for msg in self.client.receive_response(): async for msg in self.client.receive_response():
msg_type = type(msg).__name__ msg_type = type(msg).__name__
@@ -305,7 +323,6 @@ class ExpandChatSession:
if block_type == "TextBlock" and hasattr(block, "text"): if block_type == "TextBlock" and hasattr(block, "text"):
text = block.text text = block.text
if text: if text:
full_response += text
yield {"type": "text", "content": text} yield {"type": "text", "content": text}
self.messages.append({ self.messages.append({
@@ -314,123 +331,6 @@ class ExpandChatSession:
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
}) })
# Check for feature creation blocks in full response (handle multiple blocks)
features_matches = re.findall(
r'<features_to_create>\s*(\[[\s\S]*?\])\s*</features_to_create>',
full_response
)
if features_matches:
# Collect all features from all blocks, deduplicating by name
all_features: list[dict] = []
seen_names: set[str] = set()
for features_json in features_matches:
try:
features_data = json.loads(features_json)
if features_data and isinstance(features_data, list):
for feature in features_data:
name = feature.get("name", "")
if name and name not in seen_names:
seen_names.add(name)
all_features.append(feature)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse features JSON block: {e}")
# Continue processing other blocks
if all_features:
try:
# Create all deduplicated features
created = await self._create_features_bulk(all_features)
if created:
self.features_created += len(created)
self.created_feature_ids.extend([f["id"] for f in created])
yield {
"type": "features_created",
"count": len(created),
"features": created
}
logger.info(f"Created {len(created)} features for {self.project_name}")
except Exception:
logger.exception("Failed to create features")
yield {
"type": "error",
"content": "Failed to create features"
}
async def _create_features_bulk(self, features: list[dict]) -> list[dict]:
    """
    Persist a batch of feature definitions directly to the project database.

    Args:
        features: Feature dictionaries carrying ``category``, ``name``,
            ``description`` and ``steps`` keys (all optional; defaults apply).

    Returns:
        A list of ``{"id", "name", "category"}`` dicts for each created row.

    Note:
        IDs are obtained via ``session.flush()`` on the objects we just added,
        rather than re-querying by priority range — a range query could pick up
        rows written concurrently by another process.
    """
    # Deferred import: make the repo root importable, then pull in the ORM.
    import sys
    repo_root = Path(__file__).parent.parent.parent
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from api.database import Feature, create_database

    # Open a session against this project's database.
    _, SessionLocal = create_database(self.project_dir)
    session = SessionLocal()
    try:
        # New rows go after the current highest priority (or start at 1).
        top_row = session.query(Feature).order_by(Feature.priority.desc()).first()
        base_priority = top_row.priority + 1 if top_row else 1

        rows: list = []
        for offset, spec in enumerate(features):
            row = Feature(
                priority=base_priority + offset,
                category=spec.get("category", "functional"),
                name=spec.get("name", "Unnamed feature"),
                description=spec.get("description", ""),
                steps=spec.get("steps", []),
                passes=False,
                in_progress=False,
            )
            session.add(row)
            rows.append(row)

        # Flush assigns primary keys without committing yet.
        session.flush()
        result = [
            {"id": row.id, "name": row.name, "category": row.category}
            for row in rows
        ]
        session.commit()
        return result
    except Exception:
        # Keep the database untouched on any failure; re-raise for the caller.
        session.rollback()
        raise
    finally:
        session.close()
def get_features_created(self) -> int: def get_features_created(self) -> int:
"""Get the total number of features created in this session.""" """Get the total number of features created in this session."""
return self.features_created return self.features_created