diff --git a/api/database.py b/api/database.py index e7bcf46..271b7cc 100644 --- a/api/database.py +++ b/api/database.py @@ -6,11 +6,27 @@ SQLite database schema for feature storage using SQLAlchemy. """ import sys -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from typing import Optional -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, create_engine, text + +def _utc_now() -> datetime: + """Return current UTC time. Replacement for deprecated _utc_now().""" + return datetime.now(timezone.utc) + +from sqlalchemy import ( + Boolean, + CheckConstraint, + Column, + DateTime, + ForeignKey, + Integer, + String, + Text, + create_engine, + text, +) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, relationship, sessionmaker from sqlalchemy.types import JSON @@ -65,6 +81,14 @@ class Schedule(Base): __tablename__ = "schedules" + # Database-level CHECK constraints for data integrity + __table_args__ = ( + CheckConstraint('duration_minutes >= 1 AND duration_minutes <= 1440', name='ck_schedule_duration'), + CheckConstraint('days_of_week >= 0 AND days_of_week <= 127', name='ck_schedule_days'), + CheckConstraint('max_concurrency >= 1 AND max_concurrency <= 5', name='ck_schedule_concurrency'), + CheckConstraint('crash_count >= 0', name='ck_schedule_crash_count'), + ) + id = Column(Integer, primary_key=True, index=True) project_name = Column(String(50), nullable=False, index=True) @@ -87,7 +111,7 @@ class Schedule(Base): crash_count = Column(Integer, nullable=False, default=0) # Resets at window start # Metadata - created_at = Column(DateTime, nullable=False, default=datetime.utcnow) + created_at = Column(DateTime, nullable=False, default=_utc_now) # Relationships overrides = relationship( @@ -131,7 +155,7 @@ class ScheduleOverride(Base): expires_at = Column(DateTime, nullable=False) # When this window ends (UTC) # Metadata - created_at = Column(DateTime, nullable=False, default=datetime.utcnow) + created_at = Column(DateTime, nullable=False, default=_utc_now) # Relationships schedule = relationship("Schedule", back_populates="overrides") diff --git a/progress.py b/progress.py index dfb700b..a4dda26 100644 --- a/progress.py +++ b/progress.py @@ -10,7 +10,7 @@ import json import os import sqlite3 import urllib.request -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path WEBHOOK_URL = os.environ.get("PROGRESS_N8N_WEBHOOK_URL") @@ -171,7 +171,7 @@ def send_progress_webhook(passing: int, total: int, project_dir: Path) -> None: "tests_completed_this_session": passing - previous, "completed_tests": completed_tests, "project": project_dir.name, - "timestamp": datetime.utcnow().isoformat() + "Z", + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), } try: diff --git a/server/routers/schedules.py b/server/routers/schedules.py index 6138824..7c6c4ed 100644 --- a/server/routers/schedules.py +++ b/server/routers/schedules.py @@ -8,10 +8,16 @@ Provides CRUD operations for time-based schedule configuration. import re import sys +from contextlib import contextmanager from datetime import datetime, timedelta, timezone from pathlib import Path +from typing import Generator, Tuple from fastapi import APIRouter, HTTPException +from sqlalchemy.orm import Session + +# Schedule limits to prevent resource exhaustion +MAX_SCHEDULES_PER_PROJECT = 50 from ..schemas import ( NextRunResponse, @@ -48,8 +54,15 @@ def validate_project_name(name: str) -> str: return name -def _get_db_session(project_name: str): - """Get database session for a project.""" +@contextmanager +def _get_db_session(project_name: str) -> Generator[Tuple[Session, Path], None, None]: + """Get database session for a project as a context manager. + + Usage: + with _get_db_session(project_name) as (db, project_path): + # ... use db ... + # db is automatically closed + """ from api.database import create_database project_name = validate_project_name(project_name) @@ -68,7 +81,11 @@ def _get_db_session(project_name: str): ) _, SessionLocal = create_database(project_path) - return SessionLocal(), project_path + db = SessionLocal() + try: + yield db, project_path + finally: + db.close() @router.get("", response_model=ScheduleListResponse) @@ -76,9 +93,7 @@ async def list_schedules(project_name: str): """Get all schedules for a project.""" from api.database import Schedule - db, _ = _get_db_session(project_name) - - try: + with _get_db_session(project_name) as (db, _): schedules = db.query(Schedule).filter( Schedule.project_name == project_name ).order_by(Schedule.start_time).all() @@ -100,8 +115,6 @@ async def list_schedules(project_name: str): for s in schedules ] ) - finally: - db.close() @router.post("", response_model=ScheduleResponse, status_code=201) @@ -111,9 +124,18 @@ async def create_schedule(project_name: str, data: ScheduleCreate): from ..services.scheduler_service import get_scheduler - db, project_path = _get_db_session(project_name) + with _get_db_session(project_name) as (db, project_path): + # Check schedule limit to prevent resource exhaustion + existing_count = db.query(Schedule).filter( + Schedule.project_name == project_name + ).count() + + if existing_count >= MAX_SCHEDULES_PER_PROJECT: + raise HTTPException( + status_code=400, + detail=f"Maximum schedules per project ({MAX_SCHEDULES_PER_PROJECT}) exceeded" + ) - try: # Create schedule record schedule = Schedule( project_name=project_name, @@ -178,9 +200,6 @@ async def create_schedule(project_name: str, data: ScheduleCreate): created_at=schedule.created_at, ) - finally: - db.close() - @router.get("/next", response_model=NextRunResponse) async def get_next_scheduled_run(project_name: str): @@ -189,9 +208,7 @@ async def get_next_scheduled_run(project_name: str): from ..services.scheduler_service import get_scheduler - db, _ = _get_db_session(project_name) - - try: + with _get_db_session(project_name) as (db, _): schedules = db.query(Schedule).filter( Schedule.project_name == project_name, Schedule.enabled == True, # noqa: E712 @@ -245,18 +262,13 @@ async def get_next_scheduled_run(project_name: str): active_schedule_count=active_count, ) - finally: - db.close() - @router.get("/{schedule_id}", response_model=ScheduleResponse) async def get_schedule(project_name: str, schedule_id: int): """Get a single schedule by ID.""" from api.database import Schedule - db, _ = _get_db_session(project_name) - - try: + with _get_db_session(project_name) as (db, _): schedule = db.query(Schedule).filter( Schedule.id == schedule_id, Schedule.project_name == project_name, @@ -278,9 +290,6 @@ async def get_schedule(project_name: str, schedule_id: int): created_at=schedule.created_at, ) - finally: - db.close() - @router.patch("/{schedule_id}", response_model=ScheduleResponse) async def update_schedule( @@ -293,9 +302,7 @@ async def update_schedule( from ..services.scheduler_service import get_scheduler - db, project_path = _get_db_session(project_name) - - try: + with _get_db_session(project_name) as (db, project_path): schedule = db.query(Schedule).filter( Schedule.id == schedule_id, Schedule.project_name == project_name, @@ -337,9 +344,6 @@ async def update_schedule( created_at=schedule.created_at, ) - finally: - db.close() - @router.delete("/{schedule_id}", status_code=204) async def delete_schedule(project_name: str, schedule_id: int): @@ -348,9 +352,7 @@ async def delete_schedule(project_name: str, schedule_id: int): from ..services.scheduler_service import get_scheduler - db, _ = _get_db_session(project_name) - - try: + with _get_db_session(project_name) as (db, _): schedule = db.query(Schedule).filter( Schedule.id == schedule_id, Schedule.project_name == project_name, @@ -367,9 +369,6 @@ async def delete_schedule(project_name: str, schedule_id: int): db.delete(schedule) db.commit() - finally: - db.close() - def _calculate_window_end(schedule, now: datetime) -> datetime: """Calculate when the current window ends.""" diff --git a/server/services/assistant_database.py b/server/services/assistant_database.py index 3c5ee44..176768e 100644 --- a/server/services/assistant_database.py +++ b/server/services/assistant_database.py @@ -7,7 +7,7 @@ Each project has its own assistant.db file in the project directory. """ import logging -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from typing import Optional @@ -19,6 +19,11 @@ logger = logging.getLogger(__name__) Base = declarative_base() +def _utc_now() -> datetime: + """Return current UTC time. Replacement for deprecated datetime.utcnow().""" + return datetime.now(timezone.utc) + + class Conversation(Base): """A conversation with the assistant for a project.""" __tablename__ = "conversations" @@ -26,8 +31,8 @@ class Conversation(Base): id = Column(Integer, primary_key=True, index=True) project_name = Column(String(100), nullable=False, index=True) title = Column(String(200), nullable=True) # Optional title, derived from first message - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + created_at = Column(DateTime, default=_utc_now) + updated_at = Column(DateTime, default=_utc_now, onupdate=_utc_now) messages = relationship("ConversationMessage", back_populates="conversation", cascade="all, delete-orphan") @@ -40,7 +45,7 @@ class ConversationMessage(Base): conversation_id = Column(Integer, ForeignKey("conversations.id"), nullable=False, index=True) role = Column(String(20), nullable=False) # "user" | "assistant" | "system" content = Column(Text, nullable=False) - timestamp = Column(DateTime, default=datetime.utcnow) + timestamp = Column(DateTime, default=_utc_now) conversation = relationship("Conversation", back_populates="messages") @@ -175,7 +180,7 @@ def add_message(project_dir: Path, conversation_id: int, role: str, content: str session.add(message) # Update conversation's updated_at timestamp - conversation.updated_at = datetime.utcnow() + conversation.updated_at = _utc_now() # Auto-generate title from first user message if not set if not conversation.title and role == "user": diff --git a/server/services/scheduler_service.py b/server/services/scheduler_service.py index d6fc1b6..eb22a3a 100644 --- a/server/services/scheduler_service.py +++ b/server/services/scheduler_service.py @@ -368,6 +368,14 @@ class SchedulerService: logger.info(f"Agent already running for {project_name}, skipping scheduled start") return + # Register crash callback to enable auto-restart during scheduled windows + async def on_status_change(status: str): + if status == "crashed": + logger.info(f"Crash detected for {project_name}, attempting recovery") + await self.handle_crash_during_window(project_name, project_dir) + + manager.add_status_callback(on_status_change) + logger.info( f"Starting agent for {project_name} " f"(schedule {schedule.id}, yolo={schedule.yolo_mode}, concurrency={schedule.max_concurrency})" @@ -382,6 +390,8 @@ class SchedulerService: logger.info(f"✓ Agent started successfully for {project_name}") else: logger.error(f"✗ Failed to start agent for {project_name}: {msg}") + # Remove callback if start failed + manager.remove_status_callback(on_status_change) async def _stop_agent(self, project_name: str, project_dir: Path): """Stop the agent for a project.""" @@ -457,7 +467,10 @@ class SchedulerService: def _create_override_for_active_schedules( self, project_name: str, project_dir: Path, override_type: str ): - """Create overrides for all active schedule windows.""" + """Create overrides for all active schedule windows. + + Uses atomic delete-then-create pattern to prevent race conditions. + """ from api.database import Schedule, ScheduleOverride, create_database try: @@ -479,17 +492,20 @@ class SchedulerService: # Calculate window end time window_end = self._calculate_window_end(schedule, now) - # Check if override already exists - existing = db.query(ScheduleOverride).filter( + # Atomic operation: delete any existing overrides of this type + # and create a new one in the same transaction + deleted = db.query(ScheduleOverride).filter( ScheduleOverride.schedule_id == schedule.id, ScheduleOverride.override_type == override_type, - ScheduleOverride.expires_at > now, - ).first() + ).delete() - if existing: - continue + if deleted: + logger.debug( + f"Removed {deleted} existing '{override_type}' override(s) " + f"for schedule {schedule.id}" + ) - # Create override + # Create new override override = ScheduleOverride( schedule_id=schedule.id, override_type=override_type, diff --git a/ui/src/components/ScheduleModal.tsx b/ui/src/components/ScheduleModal.tsx index 0d54865..8c412a5 100644 --- a/ui/src/components/ScheduleModal.tsx +++ b/ui/src/components/ScheduleModal.tsx @@ -14,8 +14,9 @@ import { useToggleSchedule, } from '../hooks/useSchedules' import { - utcToLocal, - localToUTC, + utcToLocalWithDayShift, + localToUTCWithDayShift, + adjustDaysForDayShift, formatDuration, DAYS, isDayActive, @@ -109,10 +110,18 @@ export function ScheduleModal({ projectName, isOpen, onClose }: ScheduleModalPro return } - // Convert local time to UTC + // Convert local time to UTC and get day shift + const { time: utcTime, dayShift } = localToUTCWithDayShift(newSchedule.start_time) + + // Adjust days_of_week based on day shift + // If UTC is on the next day (dayShift = 1), shift days forward + // If UTC is on the previous day (dayShift = -1), shift days backward + const adjustedDays = adjustDaysForDayShift(newSchedule.days_of_week, dayShift) + const scheduleToCreate = { ...newSchedule, - start_time: localToUTC(newSchedule.start_time), + start_time: utcTime, + days_of_week: adjustedDays, } await createSchedule.mutateAsync(scheduleToCreate) @@ -203,8 +212,12 @@ export function ScheduleModal({ projectName, isOpen, onClose }: ScheduleModalPro {!isLoading && schedules.length > 0 && (