diff --git a/apps/server/src/providers/provider-factory.ts b/apps/server/src/providers/provider-factory.ts index ed784a90..7da470e8 100644 --- a/apps/server/src/providers/provider-factory.ts +++ b/apps/server/src/providers/provider-factory.ts @@ -1,17 +1,43 @@ /** * Provider Factory - Routes model IDs to the appropriate provider * - * This factory implements model-based routing to automatically select - * the correct provider based on the model string. This makes adding - * new providers (Cursor, OpenCode, etc.) trivial - just add one line. + * Uses a registry pattern for dynamic provider registration. + * Providers register themselves on import, making it easy to add new providers. */ import { BaseProvider } from './base-provider.js'; -import { ClaudeProvider } from './claude-provider.js'; -import { CursorProvider } from './cursor-provider.js'; import type { InstallationStatus, ModelDefinition } from './types.js'; import { CURSOR_MODEL_MAP, type ModelProvider } from '@automaker/types'; +/** + * Provider registration entry + */ +interface ProviderRegistration { + /** Factory function to create provider instance */ + factory: () => BaseProvider; + /** Aliases for this provider (e.g., 'anthropic' for 'claude') */ + aliases?: string[]; + /** Function to check if this provider can handle a model ID */ + canHandleModel?: (modelId: string) => boolean; + /** Priority for model matching (higher = checked first) */ + priority?: number; +} + +/** + * Provider registry - stores registered providers + */ +const providerRegistry = new Map(); + +/** + * Register a provider with the factory + * + * @param name Provider name (e.g., 'claude', 'cursor') + * @param registration Provider registration config + */ +export function registerProvider(name: string, registration: ProviderRegistration): void { + providerRegistry.set(name.toLowerCase(), registration); +} + export class ProviderFactory { /** * Determine which provider to use for a given model @@ -22,26 +48,26 @@ export class ProviderFactory { static getProviderNameForModel(model: string): ModelProvider { const lowerModel = model.toLowerCase(); - // Check for explicit cursor prefix - if (lowerModel.startsWith('cursor-')) { - return 'cursor'; + // Get all registered providers sorted by priority (descending) + const registrations = Array.from(providerRegistry.entries()).sort( + ([, a], [, b]) => (b.priority ?? 0) - (a.priority ?? 0) + ); + + // Check each provider's canHandleModel function + for (const [name, reg] of registrations) { + if (reg.canHandleModel?.(lowerModel)) { + return name as ModelProvider; + } } - // Check if it's a known Cursor model ID (without prefix) - const cursorModelId = lowerModel.replace('cursor-', ''); - if (cursorModelId in CURSOR_MODEL_MAP) { - return 'cursor'; + // Fallback: Check for explicit prefixes + for (const [name] of registrations) { + if (lowerModel.startsWith(`${name}-`)) { + return name as ModelProvider; + } } - // Check for Claude model patterns - if ( - lowerModel.startsWith('claude-') || - ['opus', 'sonnet', 'haiku'].some((n) => lowerModel.includes(n)) - ) { - return 'claude'; - } - - // Default to Claude + // Default to claude (first registered provider or claude) return 'claude'; } @@ -53,19 +79,25 @@ export class ProviderFactory { */ static getProviderForModel(modelId: string): BaseProvider { const providerName = this.getProviderNameForModel(modelId); + const provider = this.getProviderByName(providerName); - if (providerName === 'cursor') { - return new CursorProvider(); + if (!provider) { + // Fallback to claude if provider not found + const claudeReg = providerRegistry.get('claude'); + if (claudeReg) { + return claudeReg.factory(); + } + throw new Error(`No provider found for model: ${modelId}`); } - return new ClaudeProvider(); + return provider; } /** * Get all available providers */ static getAllProviders(): BaseProvider[] { - return [new ClaudeProvider(), new CursorProvider()]; + return Array.from(providerRegistry.values()).map((reg) => reg.factory()); } /** @@ -74,11 +106,10 @@ export class ProviderFactory { * @returns Map of provider name to installation status */ static async checkAllProviders(): Promise> { - const providers = this.getAllProviders(); const statuses: Record = {}; - for (const provider of providers) { - const name = provider.getName(); + for (const [name, reg] of providerRegistry.entries()) { + const provider = reg.factory(); const status = await provider.detectInstallation(); statuses[name] = status; } @@ -89,23 +120,26 @@ export class ProviderFactory { /** * Get provider by name (for direct access if needed) * - * @param name Provider name (e.g., "claude", "cursor") + * @param name Provider name (e.g., "claude", "cursor") or alias (e.g., "anthropic") * @returns Provider instance or null if not found */ static getProviderByName(name: string): BaseProvider | null { const lowerName = name.toLowerCase(); - switch (lowerName) { - case 'claude': - case 'anthropic': - return new ClaudeProvider(); - - case 'cursor': - return new CursorProvider(); - - default: - return null; + // Direct lookup + const directReg = providerRegistry.get(lowerName); + if (directReg) { + return directReg.factory(); } + + // Check aliases + for (const [, reg] of providerRegistry.entries()) { + if (reg.aliases?.includes(lowerName)) { + return reg.factory(); + } + } + + return null; } /** @@ -115,4 +149,46 @@ export class ProviderFactory { const providers = this.getAllProviders(); return providers.flatMap((p) => p.getAvailableModels()); } + + /** + * Get list of registered provider names + */ + static getRegisteredProviderNames(): string[] { + return Array.from(providerRegistry.keys()); + } } + +// ============================================================================= +// Provider Registrations +// ============================================================================= + +// Import providers for registration side-effects +import { ClaudeProvider } from './claude-provider.js'; +import { CursorProvider } from './cursor-provider.js'; + +// Register Claude provider +registerProvider('claude', { + factory: () => new ClaudeProvider(), + aliases: ['anthropic'], + canHandleModel: (model: string) => { + return ( + model.startsWith('claude-') || ['opus', 'sonnet', 'haiku'].some((n) => model.includes(n)) + ); + }, + priority: 0, // Default priority +}); + +// Register Cursor provider +registerProvider('cursor', { + factory: () => new CursorProvider(), + canHandleModel: (model: string) => { + // Check for explicit cursor prefix + if (model.startsWith('cursor-')) { + return true; + } + // Check if it's a known Cursor model ID + const modelId = model.replace('cursor-', ''); + return modelId in CURSOR_MODEL_MAP; + }, + priority: 10, // Higher priority - check Cursor models first +});