mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-03-17 02:43:09 +00:00
feat: add API provider selection UI and fix stuck features on agent crash
API Provider Selection: - Add provider switcher in Settings modal (Claude, Kimi, GLM, Ollama, Custom) - Auth tokens stored locally only (registry.db), never returned by API - get_effective_sdk_env() builds provider-specific env vars for agent subprocess - All chat sessions (spec, expand, assistant) use provider settings - Backward compatible: defaults to Claude, env vars still work as override Fix Stuck Features: - Add _cleanup_stale_features() to process_manager.py - Reset in_progress features when agent stops, crashes, or fails healthcheck - Prevents features from being permanently stuck after rate limit crashes - Uses separate SQLAlchemy engine to avoid session conflicts with subprocess Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -12,7 +12,7 @@ import sys
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from ..schemas import ModelInfo, ModelsResponse, SettingsResponse, SettingsUpdate
|
||||
from ..schemas import ModelInfo, ModelsResponse, ProviderInfo, ProvidersResponse, SettingsResponse, SettingsUpdate
|
||||
from ..services.chat_constants import ROOT_DIR
|
||||
|
||||
# Mimetype fix for Windows - must run before StaticFiles is mounted
|
||||
@@ -23,9 +23,11 @@ if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT_DIR))
|
||||
|
||||
from registry import (
|
||||
API_PROVIDERS,
|
||||
AVAILABLE_MODELS,
|
||||
DEFAULT_MODEL,
|
||||
get_all_settings,
|
||||
get_setting,
|
||||
set_setting,
|
||||
)
|
||||
|
||||
@@ -50,13 +52,40 @@ def _is_ollama_mode() -> bool:
|
||||
return "localhost:11434" in base_url or "127.0.0.1:11434" in base_url
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ProvidersResponse)
|
||||
async def get_available_providers():
|
||||
"""Get list of available API providers."""
|
||||
current = get_setting("api_provider", "claude") or "claude"
|
||||
providers = []
|
||||
for pid, pdata in API_PROVIDERS.items():
|
||||
providers.append(ProviderInfo(
|
||||
id=pid,
|
||||
name=pdata["name"],
|
||||
base_url=pdata.get("base_url"),
|
||||
models=[ModelInfo(id=m["id"], name=m["name"]) for m in pdata.get("models", [])],
|
||||
default_model=pdata.get("default_model", ""),
|
||||
requires_auth=pdata.get("requires_auth", False),
|
||||
))
|
||||
return ProvidersResponse(providers=providers, current=current)
|
||||
|
||||
|
||||
@router.get("/models", response_model=ModelsResponse)
|
||||
async def get_available_models():
|
||||
"""Get list of available models.
|
||||
|
||||
Frontend should call this to get the current list of models
|
||||
instead of hardcoding them.
|
||||
Returns models for the currently selected API provider.
|
||||
"""
|
||||
current_provider = get_setting("api_provider", "claude") or "claude"
|
||||
provider = API_PROVIDERS.get(current_provider)
|
||||
|
||||
if provider and current_provider != "claude":
|
||||
provider_models = provider.get("models", [])
|
||||
return ModelsResponse(
|
||||
models=[ModelInfo(id=m["id"], name=m["name"]) for m in provider_models],
|
||||
default=provider.get("default_model", ""),
|
||||
)
|
||||
|
||||
# Default: return Claude models
|
||||
return ModelsResponse(
|
||||
models=[ModelInfo(id=m["id"], name=m["name"]) for m in AVAILABLE_MODELS],
|
||||
default=DEFAULT_MODEL,
|
||||
@@ -85,14 +114,24 @@ async def get_settings():
|
||||
"""Get current global settings."""
|
||||
all_settings = get_all_settings()
|
||||
|
||||
api_provider = all_settings.get("api_provider", "claude")
|
||||
|
||||
# Compute glm_mode / ollama_mode from api_provider for backward compat
|
||||
glm_mode = api_provider == "glm" or _is_glm_mode()
|
||||
ollama_mode = api_provider == "ollama" or _is_ollama_mode()
|
||||
|
||||
return SettingsResponse(
|
||||
yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")),
|
||||
model=all_settings.get("model", DEFAULT_MODEL),
|
||||
glm_mode=_is_glm_mode(),
|
||||
ollama_mode=_is_ollama_mode(),
|
||||
glm_mode=glm_mode,
|
||||
ollama_mode=ollama_mode,
|
||||
testing_agent_ratio=_parse_int(all_settings.get("testing_agent_ratio"), 1),
|
||||
playwright_headless=_parse_bool(all_settings.get("playwright_headless"), default=True),
|
||||
batch_size=_parse_int(all_settings.get("batch_size"), 3),
|
||||
api_provider=api_provider,
|
||||
api_base_url=all_settings.get("api_base_url"),
|
||||
api_has_auth_token=bool(all_settings.get("api_auth_token")),
|
||||
api_model=all_settings.get("api_model"),
|
||||
)
|
||||
|
||||
|
||||
@@ -114,14 +153,47 @@ async def update_settings(update: SettingsUpdate):
|
||||
if update.batch_size is not None:
|
||||
set_setting("batch_size", str(update.batch_size))
|
||||
|
||||
# API provider settings
|
||||
if update.api_provider is not None:
|
||||
old_provider = get_setting("api_provider", "claude")
|
||||
set_setting("api_provider", update.api_provider)
|
||||
|
||||
# When provider changes, auto-set defaults for the new provider
|
||||
if update.api_provider != old_provider:
|
||||
provider = API_PROVIDERS.get(update.api_provider)
|
||||
if provider:
|
||||
# Auto-set base URL from provider definition
|
||||
if provider.get("base_url"):
|
||||
set_setting("api_base_url", provider["base_url"])
|
||||
# Auto-set model to provider's default
|
||||
if provider.get("default_model") and update.api_model is None:
|
||||
set_setting("api_model", provider["default_model"])
|
||||
|
||||
if update.api_base_url is not None:
|
||||
set_setting("api_base_url", update.api_base_url)
|
||||
|
||||
if update.api_auth_token is not None:
|
||||
set_setting("api_auth_token", update.api_auth_token)
|
||||
|
||||
if update.api_model is not None:
|
||||
set_setting("api_model", update.api_model)
|
||||
|
||||
# Return updated settings
|
||||
all_settings = get_all_settings()
|
||||
api_provider = all_settings.get("api_provider", "claude")
|
||||
glm_mode = api_provider == "glm" or _is_glm_mode()
|
||||
ollama_mode = api_provider == "ollama" or _is_ollama_mode()
|
||||
|
||||
return SettingsResponse(
|
||||
yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")),
|
||||
model=all_settings.get("model", DEFAULT_MODEL),
|
||||
glm_mode=_is_glm_mode(),
|
||||
ollama_mode=_is_ollama_mode(),
|
||||
glm_mode=glm_mode,
|
||||
ollama_mode=ollama_mode,
|
||||
testing_agent_ratio=_parse_int(all_settings.get("testing_agent_ratio"), 1),
|
||||
playwright_headless=_parse_bool(all_settings.get("playwright_headless"), default=True),
|
||||
batch_size=_parse_int(all_settings.get("batch_size"), 3),
|
||||
api_provider=api_provider,
|
||||
api_base_url=all_settings.get("api_base_url"),
|
||||
api_has_auth_token=bool(all_settings.get("api_auth_token")),
|
||||
api_model=all_settings.get("api_model"),
|
||||
)
|
||||
|
||||
@@ -391,6 +391,22 @@ class ModelInfo(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about an API provider."""
|
||||
id: str
|
||||
name: str
|
||||
base_url: str | None = None
|
||||
models: list[ModelInfo]
|
||||
default_model: str
|
||||
requires_auth: bool = False
|
||||
|
||||
|
||||
class ProvidersResponse(BaseModel):
|
||||
"""Response schema for available providers list."""
|
||||
providers: list[ProviderInfo]
|
||||
current: str
|
||||
|
||||
|
||||
class SettingsResponse(BaseModel):
|
||||
"""Response schema for global settings."""
|
||||
yolo_mode: bool = False
|
||||
@@ -400,6 +416,10 @@ class SettingsResponse(BaseModel):
|
||||
testing_agent_ratio: int = 1 # Regression testing agents (0-3)
|
||||
playwright_headless: bool = True
|
||||
batch_size: int = 3 # Features per coding agent batch (1-3)
|
||||
api_provider: str = "claude"
|
||||
api_base_url: str | None = None
|
||||
api_has_auth_token: bool = False # Never expose actual token
|
||||
api_model: str | None = None
|
||||
|
||||
|
||||
class ModelsResponse(BaseModel):
|
||||
@@ -415,12 +435,21 @@ class SettingsUpdate(BaseModel):
|
||||
testing_agent_ratio: int | None = None # 0-3
|
||||
playwright_headless: bool | None = None
|
||||
batch_size: int | None = None # Features per agent batch (1-3)
|
||||
api_provider: str | None = None
|
||||
api_base_url: str | None = None
|
||||
api_auth_token: str | None = None # Write-only, never returned
|
||||
api_model: str | None = None
|
||||
|
||||
@field_validator('model')
|
||||
@classmethod
|
||||
def validate_model(cls, v: str | None) -> str | None:
|
||||
if v is not None and v not in VALID_MODELS:
|
||||
raise ValueError(f"Invalid model. Must be one of: {VALID_MODELS}")
|
||||
def validate_model(cls, v: str | None, info) -> str | None: # type: ignore[override]
|
||||
if v is not None:
|
||||
# Skip VALID_MODELS check when using an alternative API provider
|
||||
api_provider = info.data.get("api_provider")
|
||||
if api_provider and api_provider != "claude":
|
||||
return v
|
||||
if v not in VALID_MODELS:
|
||||
raise ValueError(f"Invalid model. Must be one of: {VALID_MODELS}")
|
||||
return v
|
||||
|
||||
@field_validator('testing_agent_ratio')
|
||||
|
||||
@@ -258,15 +258,11 @@ class AssistantChatSession:
|
||||
system_cli = shutil.which("claude")
|
||||
|
||||
# Build environment overrides for API configuration
|
||||
sdk_env: dict[str, str] = {}
|
||||
for var in API_ENV_VARS:
|
||||
value = os.getenv(var)
|
||||
if value:
|
||||
sdk_env[var] = value
|
||||
from registry import get_effective_sdk_env
|
||||
sdk_env = get_effective_sdk_env()
|
||||
|
||||
# Determine model from environment or use default
|
||||
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
|
||||
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
||||
# Determine model from SDK env (provider-aware) or fallback to env/default
|
||||
model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
||||
|
||||
try:
|
||||
logger.info("Creating ClaudeSDKClient...")
|
||||
|
||||
@@ -154,16 +154,11 @@ class ExpandChatSession:
|
||||
system_prompt = skill_content.replace("$ARGUMENTS", project_path)
|
||||
|
||||
# Build environment overrides for API configuration
|
||||
# Filter to only include vars that are actually set (non-None)
|
||||
sdk_env: dict[str, str] = {}
|
||||
for var in API_ENV_VARS:
|
||||
value = os.getenv(var)
|
||||
if value:
|
||||
sdk_env[var] = value
|
||||
from registry import get_effective_sdk_env
|
||||
sdk_env = get_effective_sdk_env()
|
||||
|
||||
# Determine model from environment or use default
|
||||
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
|
||||
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
||||
# Determine model from SDK env (provider-aware) or fallback to env/default
|
||||
model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
||||
|
||||
# Build MCP servers config for feature creation
|
||||
mcp_servers = {
|
||||
|
||||
@@ -227,6 +227,46 @@ class AgentProcessManager:
|
||||
"""Remove lock file."""
|
||||
self.lock_file.unlink(missing_ok=True)
|
||||
|
||||
def _cleanup_stale_features(self) -> None:
|
||||
"""Clear in_progress flag for all features when agent stops/crashes.
|
||||
|
||||
When the agent process exits (normally or crash), any features left
|
||||
with in_progress=True were being worked on and didn't complete.
|
||||
Reset them so they can be picked up on next agent start.
|
||||
"""
|
||||
try:
|
||||
from autoforge_paths import get_features_db_path
|
||||
features_db = get_features_db_path(self.project_dir)
|
||||
if not features_db.exists():
|
||||
return
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from api.database import Feature
|
||||
|
||||
engine = create_engine(f"sqlite:///{features_db}")
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
try:
|
||||
stuck = session.query(Feature).filter(
|
||||
Feature.in_progress == True, # noqa: E712
|
||||
Feature.passes == False, # noqa: E712
|
||||
).all()
|
||||
if stuck:
|
||||
for f in stuck:
|
||||
f.in_progress = False
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Cleaned up %d stuck feature(s) for %s",
|
||||
len(stuck), self.project_name,
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
engine.dispose()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup features for %s: %s", self.project_name, e)
|
||||
|
||||
async def _broadcast_output(self, line: str) -> None:
|
||||
"""Broadcast output line to all registered callbacks."""
|
||||
with self._callbacks_lock:
|
||||
@@ -288,6 +328,7 @@ class AgentProcessManager:
|
||||
self.status = "crashed"
|
||||
elif self.status == "running":
|
||||
self.status = "stopped"
|
||||
self._cleanup_stale_features()
|
||||
self._remove_lock()
|
||||
|
||||
async def start(
|
||||
@@ -359,12 +400,22 @@ class AgentProcessManager:
|
||||
# stdin=DEVNULL prevents blocking if Claude CLI or child process tries to read stdin
|
||||
# CREATE_NO_WINDOW on Windows prevents console window pop-ups
|
||||
# PYTHONUNBUFFERED ensures output isn't delayed
|
||||
# Build subprocess environment with API provider settings
|
||||
from registry import get_effective_sdk_env
|
||||
api_env = get_effective_sdk_env()
|
||||
subprocess_env = {
|
||||
**os.environ,
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PLAYWRIGHT_HEADLESS": "true" if playwright_headless else "false",
|
||||
**api_env,
|
||||
}
|
||||
|
||||
popen_kwargs: dict[str, Any] = {
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"stdout": subprocess.PIPE,
|
||||
"stderr": subprocess.STDOUT,
|
||||
"cwd": str(self.project_dir),
|
||||
"env": {**os.environ, "PYTHONUNBUFFERED": "1", "PLAYWRIGHT_HEADLESS": "true" if playwright_headless else "false"},
|
||||
"env": subprocess_env,
|
||||
}
|
||||
if sys.platform == "win32":
|
||||
popen_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
|
||||
@@ -425,6 +476,7 @@ class AgentProcessManager:
|
||||
result.children_terminated, result.children_killed
|
||||
)
|
||||
|
||||
self._cleanup_stale_features()
|
||||
self._remove_lock()
|
||||
self.status = "stopped"
|
||||
self.process = None
|
||||
@@ -502,6 +554,7 @@ class AgentProcessManager:
|
||||
if poll is not None:
|
||||
# Process has terminated
|
||||
if self.status in ("running", "paused"):
|
||||
self._cleanup_stale_features()
|
||||
self.status = "crashed"
|
||||
self._remove_lock()
|
||||
return False
|
||||
|
||||
@@ -140,16 +140,11 @@ class SpecChatSession:
|
||||
system_cli = shutil.which("claude")
|
||||
|
||||
# Build environment overrides for API configuration
|
||||
# Filter to only include vars that are actually set (non-None)
|
||||
sdk_env: dict[str, str] = {}
|
||||
for var in API_ENV_VARS:
|
||||
value = os.getenv(var)
|
||||
if value:
|
||||
sdk_env[var] = value
|
||||
from registry import get_effective_sdk_env
|
||||
sdk_env = get_effective_sdk_env()
|
||||
|
||||
# Determine model from environment or use default
|
||||
# This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
|
||||
model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
||||
# Determine model from SDK env (provider-aware) or fallback to env/default
|
||||
model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
|
||||
|
||||
try:
|
||||
self.client = ClaudeSDKClient(
|
||||
|
||||
Reference in New Issue
Block a user