feat(providers): Add provider registry pattern

Replace hardcoded switch statements with dynamic registry pattern.
Providers register with factory using registerProvider() function.

New features:
- registerProvider() function for dynamic registration
- canHandleModel() callback for model routing
- priority field for controlling match order
- aliases support (e.g., 'anthropic' -> 'claude')
- getRegisteredProviderNames() for introspection

Adding new providers now only requires calling registerProvider()
with a factory function and model matching logic.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Kacper
2025-12-30 00:42:17 +01:00
parent 55bd9b0dc7
commit dc8c06e447

View File

@@ -1,17 +1,43 @@
/** /**
* Provider Factory - Routes model IDs to the appropriate provider * Provider Factory - Routes model IDs to the appropriate provider
* *
* This factory implements model-based routing to automatically select * Uses a registry pattern for dynamic provider registration.
* the correct provider based on the model string. This makes adding * Providers register themselves on import, making it easy to add new providers.
* new providers (Cursor, OpenCode, etc.) trivial - just add one line.
*/ */
import { BaseProvider } from './base-provider.js'; import { BaseProvider } from './base-provider.js';
import { ClaudeProvider } from './claude-provider.js';
import { CursorProvider } from './cursor-provider.js';
import type { InstallationStatus, ModelDefinition } from './types.js'; import type { InstallationStatus, ModelDefinition } from './types.js';
import { CURSOR_MODEL_MAP, type ModelProvider } from '@automaker/types'; import { CURSOR_MODEL_MAP, type ModelProvider } from '@automaker/types';
/**
* Provider registration entry
*/
interface ProviderRegistration {
/** Factory function to create provider instance */
factory: () => BaseProvider;
/** Aliases for this provider (e.g., 'anthropic' for 'claude') */
aliases?: string[];
/** Function to check if this provider can handle a model ID */
canHandleModel?: (modelId: string) => boolean;
/** Priority for model matching (higher = checked first) */
priority?: number;
}
/**
* Provider registry - stores registered providers
*/
const providerRegistry = new Map<string, ProviderRegistration>();
/**
* Register a provider with the factory
*
* @param name Provider name (e.g., 'claude', 'cursor')
* @param registration Provider registration config
*/
export function registerProvider(name: string, registration: ProviderRegistration): void {
providerRegistry.set(name.toLowerCase(), registration);
}
export class ProviderFactory { export class ProviderFactory {
/** /**
* Determine which provider to use for a given model * Determine which provider to use for a given model
@@ -22,26 +48,26 @@ export class ProviderFactory {
static getProviderNameForModel(model: string): ModelProvider { static getProviderNameForModel(model: string): ModelProvider {
const lowerModel = model.toLowerCase(); const lowerModel = model.toLowerCase();
// Check for explicit cursor prefix // Get all registered providers sorted by priority (descending)
if (lowerModel.startsWith('cursor-')) { const registrations = Array.from(providerRegistry.entries()).sort(
return 'cursor'; ([, a], [, b]) => (b.priority ?? 0) - (a.priority ?? 0)
);
// Check each provider's canHandleModel function
for (const [name, reg] of registrations) {
if (reg.canHandleModel?.(lowerModel)) {
return name as ModelProvider;
}
} }
// Check if it's a known Cursor model ID (without prefix) // Fallback: Check for explicit prefixes
const cursorModelId = lowerModel.replace('cursor-', ''); for (const [name] of registrations) {
if (cursorModelId in CURSOR_MODEL_MAP) { if (lowerModel.startsWith(`${name}-`)) {
return 'cursor'; return name as ModelProvider;
}
} }
// Check for Claude model patterns // Default to claude (first registered provider or claude)
if (
lowerModel.startsWith('claude-') ||
['opus', 'sonnet', 'haiku'].some((n) => lowerModel.includes(n))
) {
return 'claude';
}
// Default to Claude
return 'claude'; return 'claude';
} }
@@ -53,19 +79,25 @@ export class ProviderFactory {
*/ */
static getProviderForModel(modelId: string): BaseProvider { static getProviderForModel(modelId: string): BaseProvider {
const providerName = this.getProviderNameForModel(modelId); const providerName = this.getProviderNameForModel(modelId);
const provider = this.getProviderByName(providerName);
if (providerName === 'cursor') { if (!provider) {
return new CursorProvider(); // Fallback to claude if provider not found
const claudeReg = providerRegistry.get('claude');
if (claudeReg) {
return claudeReg.factory();
}
throw new Error(`No provider found for model: ${modelId}`);
} }
return new ClaudeProvider(); return provider;
} }
/** /**
* Get all available providers * Get all available providers
*/ */
static getAllProviders(): BaseProvider[] { static getAllProviders(): BaseProvider[] {
return [new ClaudeProvider(), new CursorProvider()]; return Array.from(providerRegistry.values()).map((reg) => reg.factory());
} }
/** /**
@@ -74,11 +106,10 @@ export class ProviderFactory {
* @returns Map of provider name to installation status * @returns Map of provider name to installation status
*/ */
static async checkAllProviders(): Promise<Record<string, InstallationStatus>> { static async checkAllProviders(): Promise<Record<string, InstallationStatus>> {
const providers = this.getAllProviders();
const statuses: Record<string, InstallationStatus> = {}; const statuses: Record<string, InstallationStatus> = {};
for (const provider of providers) { for (const [name, reg] of providerRegistry.entries()) {
const name = provider.getName(); const provider = reg.factory();
const status = await provider.detectInstallation(); const status = await provider.detectInstallation();
statuses[name] = status; statuses[name] = status;
} }
@@ -89,23 +120,26 @@ export class ProviderFactory {
/** /**
* Get provider by name (for direct access if needed) * Get provider by name (for direct access if needed)
* *
* @param name Provider name (e.g., "claude", "cursor") * @param name Provider name (e.g., "claude", "cursor") or alias (e.g., "anthropic")
* @returns Provider instance or null if not found * @returns Provider instance or null if not found
*/ */
static getProviderByName(name: string): BaseProvider | null { static getProviderByName(name: string): BaseProvider | null {
const lowerName = name.toLowerCase(); const lowerName = name.toLowerCase();
switch (lowerName) { // Direct lookup
case 'claude': const directReg = providerRegistry.get(lowerName);
case 'anthropic': if (directReg) {
return new ClaudeProvider(); return directReg.factory();
case 'cursor':
return new CursorProvider();
default:
return null;
} }
// Check aliases
for (const [, reg] of providerRegistry.entries()) {
if (reg.aliases?.includes(lowerName)) {
return reg.factory();
}
}
return null;
} }
/** /**
@@ -115,4 +149,46 @@ export class ProviderFactory {
const providers = this.getAllProviders(); const providers = this.getAllProviders();
return providers.flatMap((p) => p.getAvailableModels()); return providers.flatMap((p) => p.getAvailableModels());
} }
/**
* Get list of registered provider names
*/
static getRegisteredProviderNames(): string[] {
return Array.from(providerRegistry.keys());
} }
}
// =============================================================================
// Provider Registrations
// =============================================================================
// Import providers for registration side-effects
import { ClaudeProvider } from './claude-provider.js';
import { CursorProvider } from './cursor-provider.js';
// Register Claude provider
registerProvider('claude', {
factory: () => new ClaudeProvider(),
aliases: ['anthropic'],
canHandleModel: (model: string) => {
return (
model.startsWith('claude-') || ['opus', 'sonnet', 'haiku'].some((n) => model.includes(n))
);
},
priority: 0, // Default priority
});
// Register Cursor provider
registerProvider('cursor', {
factory: () => new CursorProvider(),
canHandleModel: (model: string) => {
// Check for explicit cursor prefix
if (model.startsWith('cursor-')) {
return true;
}
// Check if it's a known Cursor model ID
const modelId = model.replace('cursor-', '');
return modelId in CURSOR_MODEL_MAP;
},
priority: 10, // Higher priority - check Cursor models first
});