fix: wire MCP server into ExpandChatSession for feature creation

Replace direct-DB feature creation with MCP tool path. The expand
session now configures the feature MCP server and allows
feature_create_bulk tool calls, matching how AssistantChatSession
already works. Removes duplicated _create_features_bulk() method
and <features_to_create> regex parsing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rudiheydra
2026-01-29 10:03:07 +11:00
parent d68d70c800
commit 3161c1260a

View File

@@ -10,8 +10,8 @@ import asyncio
import json
import logging
import os
import re
import shutil
import sys
import threading
import uuid
from datetime import datetime
@@ -38,6 +38,13 @@ API_ENV_VARS = [
"ANTHROPIC_DEFAULT_HAIKU_MODEL",
]
# Feature MCP tools needed for expand session
EXPAND_FEATURE_TOOLS = [
"mcp__features__feature_create",
"mcp__features__feature_create_bulk",
"mcp__features__feature_get_stats",
]
async def _make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]:
"""
@@ -61,9 +68,8 @@ class ExpandChatSession:
Unlike SpecChatSession which writes spec files, this session:
1. Reads existing app_spec.txt for context
2. Parses feature definitions from Claude's output
3. Creates features via REST API
4. Tracks which features were created during the session
2. Chats with the user to define new features
3. Claude creates features via the feature_create_bulk MCP tool
"""
def __init__(self, project_name: str, project_dir: Path):
@@ -171,6 +177,18 @@ class ExpandChatSession:
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
# Build MCP servers config for feature creation
mcp_servers = {
"features": {
"command": sys.executable,
"args": ["-m", "mcp_server.feature_mcp"],
"env": {
"PROJECT_DIR": str(self.project_dir.resolve()),
"PYTHONPATH": str(ROOT_DIR.resolve()),
},
},
}
# Create Claude SDK client
try:
self.client = ClaudeSDKClient(
@@ -181,8 +199,10 @@ class ExpandChatSession:
allowed_tools=[
"Read",
"Glob",
*EXPAND_FEATURE_TOOLS,
],
permission_mode="acceptEdits",
mcp_servers=mcp_servers,
permission_mode="bypassPermissions",
max_turns=100,
cwd=str(self.project_dir.resolve()),
settings=str(settings_file.resolve()),
@@ -267,7 +287,8 @@ class ExpandChatSession:
"""
Internal method to query Claude and stream responses.
Handles text responses and detects feature creation blocks.
Feature creation is handled by Claude calling the feature_create_bulk
MCP tool directly -- no text parsing needed.
"""
if not self.client:
return
@@ -291,9 +312,6 @@ class ExpandChatSession:
else:
await self.client.query(message)
# Accumulate full response to detect feature blocks
full_response = ""
# Stream the response
async for msg in self.client.receive_response():
msg_type = type(msg).__name__
@@ -305,7 +323,6 @@ class ExpandChatSession:
if block_type == "TextBlock" and hasattr(block, "text"):
text = block.text
if text:
full_response += text
yield {"type": "text", "content": text}
self.messages.append({
@@ -314,123 +331,6 @@ class ExpandChatSession:
"timestamp": datetime.now().isoformat()
})
# Check for feature creation blocks in full response (handle multiple blocks)
features_matches = re.findall(
r'<features_to_create>\s*(\[[\s\S]*?\])\s*</features_to_create>',
full_response
)
if features_matches:
# Collect all features from all blocks, deduplicating by name
all_features: list[dict] = []
seen_names: set[str] = set()
for features_json in features_matches:
try:
features_data = json.loads(features_json)
if features_data and isinstance(features_data, list):
for feature in features_data:
name = feature.get("name", "")
if name and name not in seen_names:
seen_names.add(name)
all_features.append(feature)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse features JSON block: {e}")
# Continue processing other blocks
if all_features:
try:
# Create all deduplicated features
created = await self._create_features_bulk(all_features)
if created:
self.features_created += len(created)
self.created_feature_ids.extend([f["id"] for f in created])
yield {
"type": "features_created",
"count": len(created),
"features": created
}
logger.info(f"Created {len(created)} features for {self.project_name}")
except Exception:
logger.exception("Failed to create features")
yield {
"type": "error",
"content": "Failed to create features"
}
async def _create_features_bulk(self, features: list[dict]) -> list[dict]:
"""
Create features directly in the database.
Args:
features: List of feature dictionaries with category, name, description, steps
Returns:
List of created feature dictionaries with IDs
Note:
Uses flush() to get IDs immediately without re-querying by priority range,
which could pick up rows from concurrent writers.
"""
# Import database classes
import sys
root = Path(__file__).parent.parent.parent
if str(root) not in sys.path:
sys.path.insert(0, str(root))
from api.database import Feature, create_database
# Get database session
_, SessionLocal = create_database(self.project_dir)
session = SessionLocal()
try:
# Determine starting priority
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
created_rows: list = []
for f in features:
db_feature = Feature(
priority=current_priority,
category=f.get("category", "functional"),
name=f.get("name", "Unnamed feature"),
description=f.get("description", ""),
steps=f.get("steps", []),
passes=False,
in_progress=False,
)
session.add(db_feature)
created_rows.append(db_feature)
current_priority += 1
# Flush to get IDs without relying on priority range query
session.flush()
# Build result from the flushed objects (IDs are now populated)
created_features = [
{
"id": db_feature.id,
"name": db_feature.name,
"category": db_feature.category,
}
for db_feature in created_rows
]
session.commit()
return created_features
except Exception:
session.rollback()
raise
finally:
session.close()
def get_features_created(self) -> int:
"""Get the total number of features created in this session."""
return self.features_created