From d8a8c83447f28da5be4cb04b438cd97d4af733d7 Mon Sep 17 00:00:00 2001 From: Auto Date: Sun, 1 Feb 2026 09:45:20 +0200 Subject: [PATCH] fix: prevent SQLite corruption in parallel mode with atomic operations Replace ineffective threading.Lock() with atomic SQL operations for cross-process safety. Key changes: - Add SQLAlchemy event hooks (do_connect/do_begin) for BEGIN IMMEDIATE transactions in api/database.py - Add atomic_transaction() context manager for multi-statement ops - Convert all feature MCP write operations to atomic UPDATE...WHERE with compare-and-swap patterns (feature_claim, mark_passing, etc.) - Add WHERE passes=0 state guard to feature_mark_passing - Add WAL checkpoint on shutdown and idempotent cleanup() in parallel_orchestrator.py with async-safe signal handling - Wrap SQLite connections with contextlib.closing() in progress.py - Add thread-safe engine cache with double-checked locking in assistant_database.py - Migrate to SQLAlchemy 2.0 DeclarativeBase across all modules Inspired by PR #108 (cabana8471-arch), with fixes for nested BEGIN EXCLUSIVE bug and missing state guards. Closes #106 Co-Authored-By: Claude Opus 4.5 --- api/database.py | 125 +++++++-- mcp_server/feature_mcp.py | 362 ++++++++++++++------------ parallel_orchestrator.py | 105 +++++++- progress.py | 102 ++++---- registry.py | 7 +- server/services/assistant_database.py | 43 ++- 6 files changed, 489 insertions(+), 255 deletions(-) diff --git a/api/database.py b/api/database.py index 6dd4676..2a732fe 100644 --- a/api/database.py +++ b/api/database.py @@ -8,7 +8,7 @@ SQLite database schema for feature storage using SQLAlchemy. import sys from datetime import datetime, timezone from pathlib import Path -from typing import Optional +from typing import Generator, Optional def _utc_now() -> datetime: @@ -26,13 +26,16 @@ from sqlalchemy import ( String, Text, create_engine, + event, text, ) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, relationship, sessionmaker +from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker from sqlalchemy.types import JSON -Base = declarative_base() + +class Base(DeclarativeBase): + """SQLAlchemy 2.0 style declarative base.""" + pass class Feature(Base): @@ -307,11 +310,11 @@ def _migrate_add_schedules_tables(engine) -> None: # Create schedules table if missing if "schedules" not in existing_tables: - Schedule.__table__.create(bind=engine) + Schedule.__table__.create(bind=engine) # type: ignore[attr-defined] # Create schedule_overrides table if missing if "schedule_overrides" not in existing_tables: - ScheduleOverride.__table__.create(bind=engine) + ScheduleOverride.__table__.create(bind=engine) # type: ignore[attr-defined] # Add crash_count column if missing (for upgrades) if "schedules" in existing_tables: @@ -332,6 +335,35 @@ def _migrate_add_schedules_tables(engine) -> None: conn.commit() +def _configure_sqlite_immediate_transactions(engine) -> None: + """Configure engine for IMMEDIATE transactions via event hooks. + + Per SQLAlchemy docs: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html + + This replaces fragile pysqlite implicit transaction handling with explicit + BEGIN IMMEDIATE at transaction start. 
Benefits: + - Acquires write lock immediately, preventing stale reads + - Works correctly regardless of prior ORM operations + - Future-proof: won't break when pysqlite legacy mode is removed in Python 3.16 + """ + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # Disable pysqlite's implicit transaction handling + dbapi_connection.isolation_level = None + + # Set busy_timeout on raw connection before any transactions + cursor = dbapi_connection.cursor() + try: + cursor.execute("PRAGMA busy_timeout=30000") + finally: + cursor.close() + + @event.listens_for(engine, "begin") + def do_begin(conn): + # Use IMMEDIATE for all transactions to prevent stale reads + conn.exec_driver_sql("BEGIN IMMEDIATE") + + def create_database(project_dir: Path) -> tuple: """ Create database and return engine + session maker. @@ -351,21 +383,37 @@ def create_database(project_dir: Path) -> tuple: return _engine_cache[cache_key] db_url = get_database_url(project_dir) - engine = create_engine(db_url, connect_args={ - "check_same_thread": False, - "timeout": 30 # Wait up to 30s for locks - }) - Base.metadata.create_all(bind=engine) # Choose journal mode based on filesystem type # WAL mode doesn't work reliably on network filesystems and can cause corruption is_network = _is_network_path(project_dir) journal_mode = "DELETE" if is_network else "WAL" + engine = create_engine(db_url, connect_args={ + "check_same_thread": False, + "timeout": 30 # Wait up to 30s for locks + }) + + # Set journal mode BEFORE configuring event hooks + # PRAGMA journal_mode must run outside of a transaction, and our event hooks + # start a transaction with BEGIN IMMEDIATE on every operation with engine.connect() as conn: - conn.execute(text(f"PRAGMA journal_mode={journal_mode}")) - conn.execute(text("PRAGMA busy_timeout=30000")) - conn.commit() + # Get raw DBAPI connection to execute PRAGMA outside transaction + raw_conn = conn.connection.dbapi_connection + if raw_conn is None: + raise RuntimeError("Failed to get raw DBAPI connection") + cursor = raw_conn.cursor() + try: + cursor.execute(f"PRAGMA journal_mode={journal_mode}") + cursor.execute("PRAGMA busy_timeout=30000") + finally: + cursor.close() + + # Configure IMMEDIATE transactions via event hooks AFTER setting PRAGMAs + # This must happen before create_all() and migrations run + _configure_sqlite_immediate_transactions(engine) + + Base.metadata.create_all(bind=engine) # Migrate existing databases _migrate_add_in_progress_column(engine) @@ -417,7 +465,7 @@ def set_session_maker(session_maker: sessionmaker) -> None: _session_maker = session_maker -def get_db() -> Session: +def get_db() -> Generator[Session, None, None]: """ Dependency for FastAPI to get database session. @@ -434,3 +482,50 @@ def get_db() -> Session: raise finally: db.close() + + +# ============================================================================= +# Atomic Transaction Helpers for Parallel Mode +# ============================================================================= +# These helpers prevent database corruption when multiple processes access the +# same SQLite database concurrently. They use IMMEDIATE transactions which +# acquire write locks at the start (preventing stale reads) and atomic +# UPDATE ... WHERE clauses (preventing check-then-modify races). + + +from contextlib import contextmanager + + +@contextmanager +def atomic_transaction(session_maker): + """Context manager for atomic SQLite transactions. 
+ + Acquires a write lock immediately via BEGIN IMMEDIATE (configured by + engine event hooks), preventing stale reads in read-modify-write patterns. + This is essential for preventing race conditions in parallel mode. + + Args: + session_maker: SQLAlchemy sessionmaker + + Yields: + SQLAlchemy session with automatic commit/rollback + + Example: + with atomic_transaction(session_maker) as session: + # All reads in this block are protected by write lock + feature = session.query(Feature).filter(...).first() + feature.priority = new_priority + # Commit happens automatically on exit + """ + session = session_maker() + try: + yield session + session.commit() + except Exception: + try: + session.rollback() + except Exception: + pass # Don't let rollback failure mask original error + raise + finally: + session.close() diff --git a/mcp_server/feature_mcp.py b/mcp_server/feature_mcp.py index a394f1e..a7f2691 100755 --- a/mcp_server/feature_mcp.py +++ b/mcp_server/feature_mcp.py @@ -30,18 +30,18 @@ orchestrator, not by agents. Agents receive pre-assigned feature IDs. import json import os import sys -import threading from contextlib import asynccontextmanager from pathlib import Path from typing import Annotated from mcp.server.fastmcp import FastMCP from pydantic import BaseModel, Field +from sqlalchemy import text # Add parent directory to path so we can import from api module sys.path.insert(0, str(Path(__file__).parent.parent)) -from api.database import Feature, create_database +from api.database import Feature, atomic_transaction, create_database from api.dependency_resolver import ( MAX_DEPENDENCIES_PER_FEATURE, compute_scheduling_scores, @@ -96,8 +96,9 @@ class BulkCreateInput(BaseModel): _session_maker = None _engine = None -# Lock for priority assignment to prevent race conditions -_priority_lock = threading.Lock() +# NOTE: The old threading.Lock() was removed because it only worked per-process, +# not cross-process. In parallel mode, multiple MCP servers run in separate +# processes, so the lock was useless. We now use atomic SQL operations instead. 
@asynccontextmanager @@ -243,15 +244,25 @@ def feature_mark_passing( """ session = get_session() try: - feature = session.query(Feature).filter(Feature.id == feature_id).first() - - if feature is None: - return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - - feature.passes = True - feature.in_progress = False + # Atomic update with state guard - prevents double-pass in parallel mode + result = session.execute(text(""" + UPDATE features + SET passes = 1, in_progress = 0 + WHERE id = :id AND passes = 0 + """), {"id": feature_id}) session.commit() + if result.rowcount == 0: + # Check why the update didn't match + feature = session.query(Feature).filter(Feature.id == feature_id).first() + if feature is None: + return json.dumps({"error": f"Feature with ID {feature_id} not found"}) + if feature.passes: + return json.dumps({"error": f"Feature with ID {feature_id} is already passing"}) + return json.dumps({"error": "Failed to mark feature passing for unknown reason"}) + + # Get the feature name for the response + feature = session.query(Feature).filter(Feature.id == feature_id).first() return json.dumps({"success": True, "feature_id": feature_id, "name": feature.name}) except Exception as e: session.rollback() @@ -284,14 +295,20 @@ def feature_mark_failing( """ session = get_session() try: + # Check if feature exists first feature = session.query(Feature).filter(Feature.id == feature_id).first() - if feature is None: return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - feature.passes = False - feature.in_progress = False + # Atomic update for parallel safety + session.execute(text(""" + UPDATE features + SET passes = 0, in_progress = 0 + WHERE id = :id + """), {"id": feature_id}) session.commit() + + # Refresh to get updated state session.refresh(feature) return json.dumps({ @@ -337,25 +354,28 @@ def feature_skip( return json.dumps({"error": "Cannot skip a feature that is already passing"}) old_priority = feature.priority + name = feature.name - # Use lock to prevent race condition in priority assignment - with _priority_lock: - # Get max priority and set this feature to max + 1 - max_priority_result = session.query(Feature.priority).order_by(Feature.priority.desc()).first() - new_priority = (max_priority_result[0] + 1) if max_priority_result else 1 - - feature.priority = new_priority - feature.in_progress = False - session.commit() + # Atomic update: set priority to max+1 in a single statement + # This prevents race conditions where two features get the same priority + session.execute(text(""" + UPDATE features + SET priority = (SELECT COALESCE(MAX(priority), 0) + 1 FROM features), + in_progress = 0 + WHERE id = :id + """), {"id": feature_id}) + session.commit() + # Refresh to get new priority session.refresh(feature) + new_priority = feature.priority return json.dumps({ - "id": feature.id, - "name": feature.name, + "id": feature_id, + "name": name, "old_priority": old_priority, "new_priority": new_priority, - "message": f"Feature '{feature.name}' moved to end of queue" + "message": f"Feature '{name}' moved to end of queue" }) except Exception as e: session.rollback() @@ -381,21 +401,27 @@ def feature_mark_in_progress( """ session = get_session() try: - feature = session.query(Feature).filter(Feature.id == feature_id).first() - - if feature is None: - return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - - if feature.passes: - return json.dumps({"error": f"Feature with ID {feature_id} is already passing"}) - - if feature.in_progress: 
- return json.dumps({"error": f"Feature with ID {feature_id} is already in-progress"}) - - feature.in_progress = True + # Atomic claim: only succeeds if feature is not already claimed or passing + result = session.execute(text(""" + UPDATE features + SET in_progress = 1 + WHERE id = :id AND passes = 0 AND in_progress = 0 + """), {"id": feature_id}) session.commit() - session.refresh(feature) + if result.rowcount == 0: + # Check why the claim failed + feature = session.query(Feature).filter(Feature.id == feature_id).first() + if feature is None: + return json.dumps({"error": f"Feature with ID {feature_id} not found"}) + if feature.passes: + return json.dumps({"error": f"Feature with ID {feature_id} is already passing"}) + if feature.in_progress: + return json.dumps({"error": f"Feature with ID {feature_id} is already in-progress"}) + return json.dumps({"error": "Failed to mark feature in-progress for unknown reason"}) + + # Fetch the claimed feature + feature = session.query(Feature).filter(Feature.id == feature_id).first() return json.dumps(feature.to_dict()) except Exception as e: session.rollback() @@ -421,24 +447,35 @@ def feature_claim_and_get( """ session = get_session() try: + # First check if feature exists feature = session.query(Feature).filter(Feature.id == feature_id).first() - if feature is None: return json.dumps({"error": f"Feature with ID {feature_id} not found"}) if feature.passes: return json.dumps({"error": f"Feature with ID {feature_id} is already passing"}) - # Idempotent: if already in-progress, just return details - already_claimed = feature.in_progress - if not already_claimed: - feature.in_progress = True - session.commit() - session.refresh(feature) + # Try atomic claim: only succeeds if not already claimed + result = session.execute(text(""" + UPDATE features + SET in_progress = 1 + WHERE id = :id AND passes = 0 AND in_progress = 0 + """), {"id": feature_id}) + session.commit() - result = feature.to_dict() - result["already_claimed"] = already_claimed - return json.dumps(result) + # Determine if we claimed it or it was already claimed + already_claimed = result.rowcount == 0 + if already_claimed: + # Verify it's in_progress (not some other failure condition) + session.refresh(feature) + if not feature.in_progress: + return json.dumps({"error": f"Failed to claim feature {feature_id} for unknown reason"}) + + # Refresh to get current state + session.refresh(feature) + result_dict = feature.to_dict() + result_dict["already_claimed"] = already_claimed + return json.dumps(result_dict) except Exception as e: session.rollback() return json.dumps({"error": f"Failed to claim feature: {str(e)}"}) @@ -463,15 +500,20 @@ def feature_clear_in_progress( """ session = get_session() try: + # Check if feature exists feature = session.query(Feature).filter(Feature.id == feature_id).first() - if feature is None: return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - feature.in_progress = False + # Atomic update - idempotent, safe in parallel mode + session.execute(text(""" + UPDATE features + SET in_progress = 0 + WHERE id = :id + """), {"id": feature_id}) session.commit() - session.refresh(feature) + session.refresh(feature) return json.dumps(feature.to_dict()) except Exception as e: session.rollback() @@ -506,13 +548,14 @@ def feature_create_bulk( Returns: JSON with: created (int) - number of features created, with_dependencies (int) """ - session = get_session() try: - # Use lock to prevent race condition in priority assignment - with _priority_lock: - # Get the 
starting priority - max_priority_result = session.query(Feature.priority).order_by(Feature.priority.desc()).first() - start_priority = (max_priority_result[0] + 1) if max_priority_result else 1 + # Use atomic transaction for bulk inserts to prevent priority conflicts + with atomic_transaction(_session_maker) as session: + # Get the starting priority atomically within the transaction + result = session.execute(text(""" + SELECT COALESCE(MAX(priority), 0) FROM features + """)).fetchone() + start_priority = (result[0] or 0) + 1 # First pass: validate all features and their index-based dependencies for i, feature_data in enumerate(features): @@ -546,7 +589,7 @@ def feature_create_bulk( "error": f"Feature at index {i} cannot depend on feature at index {idx} (forward reference not allowed)" }) - # Second pass: create all features + # Second pass: create all features with reserved priorities created_features: list[Feature] = [] for i, feature_data in enumerate(features): db_feature = Feature( @@ -574,17 +617,13 @@ def feature_create_bulk( created_features[i].dependencies = sorted(dep_ids) deps_count += 1 - session.commit() - - return json.dumps({ - "created": len(created_features), - "with_dependencies": deps_count - }) + # Commit happens automatically on context manager exit + return json.dumps({ + "created": len(created_features), + "with_dependencies": deps_count + }) except Exception as e: - session.rollback() return json.dumps({"error": str(e)}) - finally: - session.close() @mcp.tool() @@ -608,13 +647,14 @@ def feature_create( Returns: JSON with the created feature details including its ID """ - session = get_session() try: - # Use lock to prevent race condition in priority assignment - with _priority_lock: - # Get the next priority - max_priority_result = session.query(Feature.priority).order_by(Feature.priority.desc()).first() - next_priority = (max_priority_result[0] + 1) if max_priority_result else 1 + # Use atomic transaction to prevent priority collisions + with atomic_transaction(_session_maker) as session: + # Get the next priority atomically within the transaction + result = session.execute(text(""" + SELECT COALESCE(MAX(priority), 0) + 1 FROM features + """)).fetchone() + next_priority = result[0] db_feature = Feature( priority=next_priority, @@ -626,20 +666,18 @@ def feature_create( in_progress=False, ) session.add(db_feature) - session.commit() + session.flush() # Get the ID - session.refresh(db_feature) + feature_dict = db_feature.to_dict() + # Commit happens automatically on context manager exit return json.dumps({ "success": True, "message": f"Created feature: {name}", - "feature": db_feature.to_dict() + "feature": feature_dict }) except Exception as e: - session.rollback() return json.dumps({"error": str(e)}) - finally: - session.close() @mcp.tool() @@ -659,52 +697,49 @@ def feature_add_dependency( Returns: JSON with success status and updated dependencies list, or error message """ - session = get_session() try: - # Security: Self-reference check + # Security: Self-reference check (can do before transaction) if feature_id == dependency_id: return json.dumps({"error": "A feature cannot depend on itself"}) - feature = session.query(Feature).filter(Feature.id == feature_id).first() - dependency = session.query(Feature).filter(Feature.id == dependency_id).first() + # Use atomic transaction for consistent cycle detection + with atomic_transaction(_session_maker) as session: + feature = session.query(Feature).filter(Feature.id == feature_id).first() + dependency = 
session.query(Feature).filter(Feature.id == dependency_id).first() - if not feature: - return json.dumps({"error": f"Feature {feature_id} not found"}) - if not dependency: - return json.dumps({"error": f"Dependency feature {dependency_id} not found"}) + if not feature: + return json.dumps({"error": f"Feature {feature_id} not found"}) + if not dependency: + return json.dumps({"error": f"Dependency feature {dependency_id} not found"}) - current_deps = feature.dependencies or [] + current_deps = feature.dependencies or [] - # Security: Max dependencies limit - if len(current_deps) >= MAX_DEPENDENCIES_PER_FEATURE: - return json.dumps({"error": f"Maximum {MAX_DEPENDENCIES_PER_FEATURE} dependencies allowed per feature"}) + # Security: Max dependencies limit + if len(current_deps) >= MAX_DEPENDENCIES_PER_FEATURE: + return json.dumps({"error": f"Maximum {MAX_DEPENDENCIES_PER_FEATURE} dependencies allowed per feature"}) - # Check if already exists - if dependency_id in current_deps: - return json.dumps({"error": "Dependency already exists"}) + # Check if already exists + if dependency_id in current_deps: + return json.dumps({"error": "Dependency already exists"}) - # Security: Circular dependency check - # would_create_circular_dependency(features, source_id, target_id) - # source_id = feature gaining the dependency, target_id = feature being depended upon - all_features = [f.to_dict() for f in session.query(Feature).all()] - if would_create_circular_dependency(all_features, feature_id, dependency_id): - return json.dumps({"error": "Cannot add: would create circular dependency"}) + # Security: Circular dependency check + # Within IMMEDIATE transaction, snapshot is protected by write lock + all_features = [f.to_dict() for f in session.query(Feature).all()] + if would_create_circular_dependency(all_features, feature_id, dependency_id): + return json.dumps({"error": "Cannot add: would create circular dependency"}) - # Add dependency - current_deps.append(dependency_id) - feature.dependencies = sorted(current_deps) - session.commit() + # Add dependency atomically + new_deps = sorted(current_deps + [dependency_id]) + feature.dependencies = new_deps + # Commit happens automatically on context manager exit - return json.dumps({ - "success": True, - "feature_id": feature_id, - "dependencies": feature.dependencies - }) + return json.dumps({ + "success": True, + "feature_id": feature_id, + "dependencies": new_deps + }) except Exception as e: - session.rollback() return json.dumps({"error": f"Failed to add dependency: {str(e)}"}) - finally: - session.close() @mcp.tool() @@ -721,30 +756,29 @@ def feature_remove_dependency( Returns: JSON with success status and updated dependencies list, or error message """ - session = get_session() try: - feature = session.query(Feature).filter(Feature.id == feature_id).first() - if not feature: - return json.dumps({"error": f"Feature {feature_id} not found"}) + # Use atomic transaction for consistent read-modify-write + with atomic_transaction(_session_maker) as session: + feature = session.query(Feature).filter(Feature.id == feature_id).first() + if not feature: + return json.dumps({"error": f"Feature {feature_id} not found"}) - current_deps = feature.dependencies or [] - if dependency_id not in current_deps: - return json.dumps({"error": "Dependency does not exist"}) + current_deps = feature.dependencies or [] + if dependency_id not in current_deps: + return json.dumps({"error": "Dependency does not exist"}) - current_deps.remove(dependency_id) - feature.dependencies = 
current_deps if current_deps else None - session.commit() + # Remove dependency atomically + new_deps = [d for d in current_deps if d != dependency_id] + feature.dependencies = new_deps if new_deps else None + # Commit happens automatically on context manager exit - return json.dumps({ - "success": True, - "feature_id": feature_id, - "dependencies": feature.dependencies or [] - }) + return json.dumps({ + "success": True, + "feature_id": feature_id, + "dependencies": new_deps + }) except Exception as e: - session.rollback() return json.dumps({"error": f"Failed to remove dependency: {str(e)}"}) - finally: - session.close() @mcp.tool() @@ -897,9 +931,8 @@ def feature_set_dependencies( Returns: JSON with success status and updated dependencies list, or error message """ - session = get_session() try: - # Security: Self-reference check + # Security: Self-reference check (can do before transaction) if feature_id in dependency_ids: return json.dumps({"error": "A feature cannot depend on itself"}) @@ -911,45 +944,44 @@ def feature_set_dependencies( if len(dependency_ids) != len(set(dependency_ids)): return json.dumps({"error": "Duplicate dependencies not allowed"}) - feature = session.query(Feature).filter(Feature.id == feature_id).first() - if not feature: - return json.dumps({"error": f"Feature {feature_id} not found"}) + # Use atomic transaction for consistent cycle detection + with atomic_transaction(_session_maker) as session: + feature = session.query(Feature).filter(Feature.id == feature_id).first() + if not feature: + return json.dumps({"error": f"Feature {feature_id} not found"}) - # Validate all dependencies exist - all_feature_ids = {f.id for f in session.query(Feature).all()} - missing = [d for d in dependency_ids if d not in all_feature_ids] - if missing: - return json.dumps({"error": f"Dependencies not found: {missing}"}) + # Validate all dependencies exist + all_feature_ids = {f.id for f in session.query(Feature).all()} + missing = [d for d in dependency_ids if d not in all_feature_ids] + if missing: + return json.dumps({"error": f"Dependencies not found: {missing}"}) - # Check for circular dependencies - all_features = [f.to_dict() for f in session.query(Feature).all()] - # Temporarily update the feature's dependencies for cycle check - test_features = [] - for f in all_features: - if f["id"] == feature_id: - test_features.append({**f, "dependencies": dependency_ids}) - else: - test_features.append(f) + # Check for circular dependencies + # Within IMMEDIATE transaction, snapshot is protected by write lock + all_features = [f.to_dict() for f in session.query(Feature).all()] + test_features = [] + for f in all_features: + if f["id"] == feature_id: + test_features.append({**f, "dependencies": dependency_ids}) + else: + test_features.append(f) - for dep_id in dependency_ids: - # source_id = feature_id (gaining dep), target_id = dep_id (being depended upon) - if would_create_circular_dependency(test_features, feature_id, dep_id): - return json.dumps({"error": f"Cannot add dependency {dep_id}: would create circular dependency"}) + for dep_id in dependency_ids: + if would_create_circular_dependency(test_features, feature_id, dep_id): + return json.dumps({"error": f"Cannot add dependency {dep_id}: would create circular dependency"}) - # Set dependencies - feature.dependencies = sorted(dependency_ids) if dependency_ids else None - session.commit() + # Set dependencies atomically + sorted_deps = sorted(dependency_ids) if dependency_ids else None + feature.dependencies = sorted_deps + # 
Commit happens automatically on context manager exit - return json.dumps({ - "success": True, - "feature_id": feature_id, - "dependencies": feature.dependencies or [] - }) + return json.dumps({ + "success": True, + "feature_id": feature_id, + "dependencies": sorted_deps or [] + }) except Exception as e: - session.rollback() return json.dumps({"error": f"Failed to set dependencies: {str(e)}"}) - finally: - session.close() if __name__ == "__main__": diff --git a/parallel_orchestrator.py b/parallel_orchestrator.py index 574cbd2..6e8bb54 100644 --- a/parallel_orchestrator.py +++ b/parallel_orchestrator.py @@ -19,7 +19,9 @@ Usage: """ import asyncio +import atexit import os +import signal import subprocess import sys import threading @@ -27,6 +29,8 @@ from datetime import datetime, timezone from pathlib import Path from typing import Callable, Literal +from sqlalchemy import text + from api.database import Feature, create_database from api.dependency_resolver import are_dependencies_satisfied, compute_scheduling_scores from progress import has_features @@ -139,11 +143,11 @@ class ParallelOrchestrator: self, project_dir: Path, max_concurrency: int = DEFAULT_CONCURRENCY, - model: str = None, + model: str | None = None, yolo_mode: bool = False, testing_agent_ratio: int = 1, - on_output: Callable[[int, str], None] = None, - on_status: Callable[[int, str], None] = None, + on_output: Callable[[int, str], None] | None = None, + on_status: Callable[[int, str], None] | None = None, ): """Initialize the orchestrator. @@ -182,14 +186,18 @@ class ParallelOrchestrator: # Track feature failures to prevent infinite retry loops self._failure_counts: dict[int, int] = {} + # Shutdown flag for async-safe signal handling + # Signal handlers only set this flag; cleanup happens in the main loop + self._shutdown_requested = False + # Session tracking for logging/debugging - self.session_start_time: datetime = None + self.session_start_time: datetime | None = None # Event signaled when any agent completes, allowing the main loop to wake # immediately instead of waiting for the full POLL_INTERVAL timeout. # This reduces latency when spawning the next feature after completion. 
- self._agent_completed_event: asyncio.Event = None # Created in run_loop - self._event_loop: asyncio.AbstractEventLoop = None # Stored for thread-safe signaling + self._agent_completed_event: asyncio.Event | None = None # Created in run_loop + self._event_loop: asyncio.AbstractEventLoop | None = None # Stored for thread-safe signaling # Database session for this orchestrator self._engine, self._session_maker = create_database(project_dir) @@ -375,7 +383,8 @@ class ParallelOrchestrator: session = self.get_session() try: session.expire_all() - return session.query(Feature).filter(Feature.passes == True).count() + count: int = session.query(Feature).filter(Feature.passes == True).count() + return count finally: session.close() @@ -511,11 +520,14 @@ class ParallelOrchestrator: try: # CREATE_NO_WINDOW on Windows prevents console window pop-ups # stdin=DEVNULL prevents blocking on stdin reads + # encoding="utf-8" and errors="replace" fix Windows CP1252 issues popen_kwargs = { "stdin": subprocess.DEVNULL, "stdout": subprocess.PIPE, "stderr": subprocess.STDOUT, "text": True, + "encoding": "utf-8", + "errors": "replace", "cwd": str(AUTOCODER_ROOT), # Run from autocoder root for proper imports "env": {**os.environ, "PYTHONUNBUFFERED": "1"}, } @@ -546,7 +558,7 @@ class ParallelOrchestrator: daemon=True ).start() - if self.on_status: + if self.on_status is not None: self.on_status(feature_id, "running") print(f"Started coding agent for feature #{feature_id}", flush=True) @@ -600,11 +612,14 @@ class ParallelOrchestrator: try: # CREATE_NO_WINDOW on Windows prevents console window pop-ups # stdin=DEVNULL prevents blocking on stdin reads + # encoding="utf-8" and errors="replace" fix Windows CP1252 issues popen_kwargs = { "stdin": subprocess.DEVNULL, "stdout": subprocess.PIPE, "stderr": subprocess.STDOUT, "text": True, + "encoding": "utf-8", + "errors": "replace", "cwd": str(AUTOCODER_ROOT), "env": {**os.environ, "PYTHONUNBUFFERED": "1"}, } @@ -658,11 +673,14 @@ class ParallelOrchestrator: # CREATE_NO_WINDOW on Windows prevents console window pop-ups # stdin=DEVNULL prevents blocking on stdin reads + # encoding="utf-8" and errors="replace" fix Windows CP1252 issues popen_kwargs = { "stdin": subprocess.DEVNULL, "stdout": subprocess.PIPE, "stderr": subprocess.STDOUT, "text": True, + "encoding": "utf-8", + "errors": "replace", "cwd": str(AUTOCODER_ROOT), "env": {**os.environ, "PYTHONUNBUFFERED": "1"}, } @@ -682,7 +700,7 @@ class ParallelOrchestrator: if not line: break print(line.rstrip(), flush=True) - if self.on_output: + if self.on_output is not None: self.on_output(0, line.rstrip()) # Use 0 as feature_id for initializer proc.wait() @@ -716,11 +734,14 @@ class ParallelOrchestrator: ): """Read output from subprocess and emit events.""" try: + if proc.stdout is None: + proc.wait() + return for line in proc.stdout: if abort.is_set(): break line = line.rstrip() - if self.on_output: + if self.on_output is not None: self.on_output(feature_id or 0, line) else: # Both coding and testing agents now use [Feature #X] format @@ -814,6 +835,9 @@ class ParallelOrchestrator: self._signal_agent_completed() return + # feature_id is required for coding agents (always passed from start_feature) + assert feature_id is not None, "feature_id must not be None for coding agents" + # Coding agent completion debug_log.log("COMPLETE", f"Coding agent for feature #{feature_id} finished", return_code=return_code, @@ -855,7 +879,7 @@ class ParallelOrchestrator: failure_count=failure_count) status = "completed" if return_code == 0 else 
"failed" - if self.on_status: + if self.on_status is not None: self.on_status(feature_id, status) # CRITICAL: This print triggers the WebSocket to emit agent_update with state='error' or 'success' print(f"Feature #{feature_id} {status}", flush=True) @@ -1014,7 +1038,7 @@ class ParallelOrchestrator: debug_log.section("FEATURE LOOP STARTING") loop_iteration = 0 - while self.is_running: + while self.is_running and not self._shutdown_requested: loop_iteration += 1 if loop_iteration <= 3: print(f"[DEBUG] === Loop iteration {loop_iteration} ===", flush=True) @@ -1163,11 +1187,40 @@ class ParallelOrchestrator: "yolo_mode": self.yolo_mode, } + def cleanup(self) -> None: + """Clean up database resources. Safe to call multiple times. + + Forces WAL checkpoint to flush pending writes to main database file, + then disposes engine to close all connections. Prevents stale cache + issues when the orchestrator restarts. + """ + # Atomically grab and clear the engine reference to prevent re-entry + engine = self._engine + self._engine = None + + if engine is None: + return # Already cleaned up + + try: + debug_log.log("CLEANUP", "Forcing WAL checkpoint before dispose") + with engine.connect() as conn: + conn.execute(text("PRAGMA wal_checkpoint(FULL)")) + conn.commit() + debug_log.log("CLEANUP", "WAL checkpoint completed, disposing engine") + except Exception as e: + debug_log.log("CLEANUP", f"WAL checkpoint failed (non-fatal): {e}") + + try: + engine.dispose() + debug_log.log("CLEANUP", "Engine disposed successfully") + except Exception as e: + debug_log.log("CLEANUP", f"Engine dispose failed: {e}") + async def run_parallel_orchestrator( project_dir: Path, max_concurrency: int = DEFAULT_CONCURRENCY, - model: str = None, + model: str | None = None, yolo_mode: bool = False, testing_agent_ratio: int = 1, ) -> None: @@ -1189,11 +1242,37 @@ async def run_parallel_orchestrator( testing_agent_ratio=testing_agent_ratio, ) + # Set up cleanup to run on exit (handles normal exit, exceptions) + def cleanup_handler(): + debug_log.log("CLEANUP", "atexit cleanup handler invoked") + orchestrator.cleanup() + + atexit.register(cleanup_handler) + + # Set up async-safe signal handler for graceful shutdown + # Only sets flags - everything else is unsafe in signal context + def signal_handler(signum, frame): + orchestrator._shutdown_requested = True + orchestrator.is_running = False + + # Register SIGTERM handler for process termination signals + # Note: On Windows, SIGTERM handlers only fire from os.kill() calls within Python. + # External termination (Task Manager, taskkill, Popen.terminate()) uses + # TerminateProcess() which bypasses signal handlers entirely. + signal.signal(signal.SIGTERM, signal_handler) + + # Note: We intentionally do NOT register SIGINT handler + # Let Python raise KeyboardInterrupt naturally so the except block works + try: await orchestrator.run_loop() except KeyboardInterrupt: print("\n\nInterrupted by user. 
Stopping agents...", flush=True) orchestrator.stop_all() + finally: + # CRITICAL: Always clean up database resources on exit + # This forces WAL checkpoint and disposes connections + orchestrator.cleanup() def main(): diff --git a/progress.py b/progress.py index 0821c90..1f17ae6 100644 --- a/progress.py +++ b/progress.py @@ -10,12 +10,21 @@ import json import os import sqlite3 import urllib.request +from contextlib import closing from datetime import datetime, timezone from pathlib import Path WEBHOOK_URL = os.environ.get("PROGRESS_N8N_WEBHOOK_URL") PROGRESS_CACHE_FILE = ".progress_cache" +# SQLite connection settings for parallel mode safety +SQLITE_TIMEOUT = 30 # seconds to wait for locks + + +def _get_connection(db_file: Path) -> sqlite3.Connection: + """Get a SQLite connection with proper timeout settings for parallel mode.""" + return sqlite3.connect(db_file, timeout=SQLITE_TIMEOUT) + def has_features(project_dir: Path) -> bool: """ @@ -31,8 +40,6 @@ def has_features(project_dir: Path) -> bool: Returns False if no features exist (initializer needs to run). """ - import sqlite3 - # Check legacy JSON file first json_file = project_dir / "feature_list.json" if json_file.exists(): @@ -44,12 +51,11 @@ def has_features(project_dir: Path) -> bool: return False try: - conn = sqlite3.connect(db_file) - cursor = conn.cursor() - cursor.execute("SELECT COUNT(*) FROM features") - count = cursor.fetchone()[0] - conn.close() - return count > 0 + with closing(_get_connection(db_file)) as conn: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM features") + count: int = cursor.fetchone()[0] + return bool(count > 0) except Exception: # Database exists but can't be read or has no features table return False @@ -70,36 +76,35 @@ def count_passing_tests(project_dir: Path) -> tuple[int, int, int]: return 0, 0, 0 try: - conn = sqlite3.connect(db_file) - cursor = conn.cursor() - # Single aggregate query instead of 3 separate COUNT queries - # Handle case where in_progress column doesn't exist yet (legacy DBs) - try: - cursor.execute(""" - SELECT - COUNT(*) as total, - SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing, - SUM(CASE WHEN in_progress = 1 THEN 1 ELSE 0 END) as in_progress - FROM features - """) - row = cursor.fetchone() - total = row[0] or 0 - passing = row[1] or 0 - in_progress = row[2] or 0 - except sqlite3.OperationalError: - # Fallback for databases without in_progress column - cursor.execute(""" - SELECT - COUNT(*) as total, - SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing - FROM features - """) - row = cursor.fetchone() - total = row[0] or 0 - passing = row[1] or 0 - in_progress = 0 - conn.close() - return passing, in_progress, total + with closing(_get_connection(db_file)) as conn: + cursor = conn.cursor() + # Single aggregate query instead of 3 separate COUNT queries + # Handle case where in_progress column doesn't exist yet (legacy DBs) + try: + cursor.execute(""" + SELECT + COUNT(*) as total, + SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing, + SUM(CASE WHEN in_progress = 1 THEN 1 ELSE 0 END) as in_progress + FROM features + """) + row = cursor.fetchone() + total = row[0] or 0 + passing = row[1] or 0 + in_progress = row[2] or 0 + except sqlite3.OperationalError: + # Fallback for databases without in_progress column + cursor.execute(""" + SELECT + COUNT(*) as total, + SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing + FROM features + """) + row = cursor.fetchone() + total = row[0] or 0 + passing = row[1] or 0 + in_progress = 0 + return passing, 
in_progress, total except Exception as e: print(f"[Database error in count_passing_tests: {e}]") return 0, 0, 0 @@ -120,17 +125,16 @@ def get_all_passing_features(project_dir: Path) -> list[dict]: return [] try: - conn = sqlite3.connect(db_file) - cursor = conn.cursor() - cursor.execute( - "SELECT id, category, name FROM features WHERE passes = 1 ORDER BY priority ASC" - ) - features = [ - {"id": row[0], "category": row[1], "name": row[2]} - for row in cursor.fetchall() - ] - conn.close() - return features + with closing(_get_connection(db_file)) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id, category, name FROM features WHERE passes = 1 ORDER BY priority ASC" + ) + features = [ + {"id": row[0], "category": row[1], "name": row[2]} + for row in cursor.fetchall() + ] + return features except Exception: return [] diff --git a/registry.py b/registry.py index f84803e..7d0c2af 100644 --- a/registry.py +++ b/registry.py @@ -17,8 +17,7 @@ from pathlib import Path from typing import Any from sqlalchemy import Column, DateTime, Integer, String, create_engine, text -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import DeclarativeBase, sessionmaker # Module logger logger = logging.getLogger(__name__) @@ -75,7 +74,9 @@ class RegistryPermissionDenied(RegistryError): # SQLAlchemy Model # ============================================================================= -Base = declarative_base() +class Base(DeclarativeBase): + """SQLAlchemy 2.0 style declarative base.""" + pass class Project(Base): diff --git a/server/services/assistant_database.py b/server/services/assistant_database.py index f2ade75..0dbfdd3 100644 --- a/server/services/assistant_database.py +++ b/server/services/assistant_database.py @@ -7,21 +7,28 @@ Each project has its own assistant.db file in the project directory. """ import logging +import threading from datetime import datetime, timezone from pathlib import Path from typing import Optional from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, create_engine, func -from sqlalchemy.orm import declarative_base, relationship, sessionmaker +from sqlalchemy.orm import DeclarativeBase, relationship, sessionmaker logger = logging.getLogger(__name__) -Base = declarative_base() +class Base(DeclarativeBase): + """SQLAlchemy 2.0 style declarative base.""" + pass # Engine cache to avoid creating new engines for each request # Key: project directory path (as posix string), Value: SQLAlchemy engine _engine_cache: dict[str, object] = {} +# Lock for thread-safe access to the engine cache +# Prevents race conditions when multiple threads create engines simultaneously +_cache_lock = threading.Lock() + def _utc_now() -> datetime: """Return current UTC time. Replacement for deprecated datetime.utcnow().""" @@ -64,17 +71,33 @@ def get_engine(project_dir: Path): Uses a cache to avoid creating new engines for each request, which improves performance by reusing database connections. + + Thread-safe: Uses a lock to prevent race conditions when multiple threads + try to create engines simultaneously for the same project. 
""" cache_key = project_dir.as_posix() - if cache_key not in _engine_cache: - db_path = get_db_path(project_dir) - # Use as_posix() for cross-platform compatibility with SQLite connection strings - db_url = f"sqlite:///{db_path.as_posix()}" - engine = create_engine(db_url, echo=False) - Base.metadata.create_all(engine) - _engine_cache[cache_key] = engine - logger.debug(f"Created new database engine for {cache_key}") + # Double-checked locking for thread safety and performance + if cache_key in _engine_cache: + return _engine_cache[cache_key] + + with _cache_lock: + # Check again inside the lock in case another thread created it + if cache_key not in _engine_cache: + db_path = get_db_path(project_dir) + # Use as_posix() for cross-platform compatibility with SQLite connection strings + db_url = f"sqlite:///{db_path.as_posix()}" + engine = create_engine( + db_url, + echo=False, + connect_args={ + "check_same_thread": False, + "timeout": 30, # Wait up to 30s for locks + } + ) + Base.metadata.create_all(engine) + _engine_cache[cache_key] = engine + logger.debug(f"Created new database engine for {cache_key}") return _engine_cache[cache_key]
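
A minimal standalone sketch of the compare-and-swap claim pattern this patch relies on, using only the stdlib sqlite3 module rather than the project's SQLAlchemy layer. The table and column names (features, id, passes, in_progress) mirror the schema in the diff above, but the helper itself (claim_feature) and the database path are illustrative assumptions, not part of the patch.

import sqlite3

def claim_feature(db_path: str, feature_id: int) -> bool:
    """Return True if this call won the claim, False if another process already holds it."""
    # isolation_level=None puts the connection in autocommit mode so the
    # explicit BEGIN IMMEDIATE below controls the transaction, mirroring the
    # do_connect/do_begin event hooks added in api/database.py.
    conn = sqlite3.connect(db_path, timeout=30, isolation_level=None)
    try:
        # BEGIN IMMEDIATE acquires the write lock up front, so the guarded
        # UPDATE cannot interleave with another process's read-modify-write.
        conn.execute("BEGIN IMMEDIATE")
        cur = conn.execute(
            "UPDATE features SET in_progress = 1 "
            "WHERE id = ? AND passes = 0 AND in_progress = 0",
            (feature_id,),
        )
        conn.execute("COMMIT")
        # rowcount == 1 means the WHERE guard matched: this process owns the feature.
        return cur.rowcount == 1
    except Exception:
        try:
            conn.execute("ROLLBACK")
        except sqlite3.OperationalError:
            pass  # no transaction active; nothing to roll back
        raise
    finally:
        conn.close()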