fix: wire MCP server into ExpandChatSession for feature creation

Replace direct-DB feature creation with MCP tool path. The expand
session now configures the feature MCP server and allows
feature_create_bulk tool calls, matching how AssistantChatSession
already works. Removes duplicated _create_features_bulk() method
and <features_to_create> regex parsing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rudiheydra
2026-01-29 10:03:07 +11:00
parent d68d70c800
commit 3161c1260a

View File

@@ -10,8 +10,8 @@ import asyncio
import json import json
import logging import logging
import os import os
import re
import shutil import shutil
import sys
import threading import threading
import uuid import uuid
from datetime import datetime from datetime import datetime
@@ -38,6 +38,13 @@ API_ENV_VARS = [
"ANTHROPIC_DEFAULT_HAIKU_MODEL", "ANTHROPIC_DEFAULT_HAIKU_MODEL",
] ]
# Feature MCP tools needed for expand session
EXPAND_FEATURE_TOOLS = [
"mcp__features__feature_create",
"mcp__features__feature_create_bulk",
"mcp__features__feature_get_stats",
]
async def _make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]: async def _make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]:
""" """
@@ -61,9 +68,8 @@ class ExpandChatSession:
Unlike SpecChatSession which writes spec files, this session: Unlike SpecChatSession which writes spec files, this session:
1. Reads existing app_spec.txt for context 1. Reads existing app_spec.txt for context
2. Parses feature definitions from Claude's output 2. Chats with the user to define new features
3. Creates features via REST API 3. Claude creates features via the feature_create_bulk MCP tool
4. Tracks which features were created during the session
""" """
def __init__(self, project_name: str, project_dir: Path): def __init__(self, project_name: str, project_dir: Path):
@@ -171,6 +177,18 @@ class ExpandChatSession:
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names # This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101") model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
# Build MCP servers config for feature creation
mcp_servers = {
"features": {
"command": sys.executable,
"args": ["-m", "mcp_server.feature_mcp"],
"env": {
"PROJECT_DIR": str(self.project_dir.resolve()),
"PYTHONPATH": str(ROOT_DIR.resolve()),
},
},
}
# Create Claude SDK client # Create Claude SDK client
try: try:
self.client = ClaudeSDKClient( self.client = ClaudeSDKClient(
@@ -181,8 +199,10 @@ class ExpandChatSession:
allowed_tools=[ allowed_tools=[
"Read", "Read",
"Glob", "Glob",
*EXPAND_FEATURE_TOOLS,
], ],
permission_mode="acceptEdits", mcp_servers=mcp_servers,
permission_mode="bypassPermissions",
max_turns=100, max_turns=100,
cwd=str(self.project_dir.resolve()), cwd=str(self.project_dir.resolve()),
settings=str(settings_file.resolve()), settings=str(settings_file.resolve()),
@@ -267,7 +287,8 @@ class ExpandChatSession:
""" """
Internal method to query Claude and stream responses. Internal method to query Claude and stream responses.
Handles text responses and detects feature creation blocks. Feature creation is handled by Claude calling the feature_create_bulk
MCP tool directly -- no text parsing needed.
""" """
if not self.client: if not self.client:
return return
@@ -291,9 +312,6 @@ class ExpandChatSession:
else: else:
await self.client.query(message) await self.client.query(message)
# Accumulate full response to detect feature blocks
full_response = ""
# Stream the response # Stream the response
async for msg in self.client.receive_response(): async for msg in self.client.receive_response():
msg_type = type(msg).__name__ msg_type = type(msg).__name__
@@ -305,7 +323,6 @@ class ExpandChatSession:
if block_type == "TextBlock" and hasattr(block, "text"): if block_type == "TextBlock" and hasattr(block, "text"):
text = block.text text = block.text
if text: if text:
full_response += text
yield {"type": "text", "content": text} yield {"type": "text", "content": text}
self.messages.append({ self.messages.append({
@@ -314,123 +331,6 @@ class ExpandChatSession:
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
}) })
# Check for feature creation blocks in full response (handle multiple blocks)
features_matches = re.findall(
r'<features_to_create>\s*(\[[\s\S]*?\])\s*</features_to_create>',
full_response
)
if features_matches:
# Collect all features from all blocks, deduplicating by name
all_features: list[dict] = []
seen_names: set[str] = set()
for features_json in features_matches:
try:
features_data = json.loads(features_json)
if features_data and isinstance(features_data, list):
for feature in features_data:
name = feature.get("name", "")
if name and name not in seen_names:
seen_names.add(name)
all_features.append(feature)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse features JSON block: {e}")
# Continue processing other blocks
if all_features:
try:
# Create all deduplicated features
created = await self._create_features_bulk(all_features)
if created:
self.features_created += len(created)
self.created_feature_ids.extend([f["id"] for f in created])
yield {
"type": "features_created",
"count": len(created),
"features": created
}
logger.info(f"Created {len(created)} features for {self.project_name}")
except Exception:
logger.exception("Failed to create features")
yield {
"type": "error",
"content": "Failed to create features"
}
async def _create_features_bulk(self, features: list[dict]) -> list[dict]:
"""
Create features directly in the database.
Args:
features: List of feature dictionaries with category, name, description, steps
Returns:
List of created feature dictionaries with IDs
Note:
Uses flush() to get IDs immediately without re-querying by priority range,
which could pick up rows from concurrent writers.
"""
# Import database classes
import sys
root = Path(__file__).parent.parent.parent
if str(root) not in sys.path:
sys.path.insert(0, str(root))
from api.database import Feature, create_database
# Get database session
_, SessionLocal = create_database(self.project_dir)
session = SessionLocal()
try:
# Determine starting priority
max_priority_feature = session.query(Feature).order_by(Feature.priority.desc()).first()
current_priority = (max_priority_feature.priority + 1) if max_priority_feature else 1
created_rows: list = []
for f in features:
db_feature = Feature(
priority=current_priority,
category=f.get("category", "functional"),
name=f.get("name", "Unnamed feature"),
description=f.get("description", ""),
steps=f.get("steps", []),
passes=False,
in_progress=False,
)
session.add(db_feature)
created_rows.append(db_feature)
current_priority += 1
# Flush to get IDs without relying on priority range query
session.flush()
# Build result from the flushed objects (IDs are now populated)
created_features = [
{
"id": db_feature.id,
"name": db_feature.name,
"category": db_feature.category,
}
for db_feature in created_rows
]
session.commit()
return created_features
except Exception:
session.rollback()
raise
finally:
session.close()
def get_features_created(self) -> int: def get_features_created(self) -> int:
"""Get the total number of features created in this session.""" """Get the total number of features created in this session."""
return self.features_created return self.features_created