claude-task-master/scripts/modules/ai-client-factory.js
Eyal Toledano 11b8d1bda5 feat(ai-client-factory): Add xAI and OpenRouter provider support, enhance tests
- Integrate @ai-sdk/xai for Grok models and @openrouter/ai-sdk-provider for OpenRouter into the AI client factory (ai-client-factory.js).
- Install the necessary provider dependencies (@ai-sdk/xai, @openrouter/ai-sdk-provider, and other related @ai-sdk packages, plus an updated AI SDK core).
- Update the environment variable checks (validateEnvironment) and client creation logic (createClientInstance) for the new providers.
- Add and correct unit tests covering xAI and OpenRouter instantiation, error handling, and environment variable resolution.
- Correct mock paths and names in tests to align with the official package names.
- Verify that all 28 tests pass for the AI client factory module.
- Confirm test coverage remains high (~90%) after additions.
2025-04-19 17:00:47 -04:00

349 lines
12 KiB
JavaScript

import fs from 'fs';
import path from 'path';
import { createOpenAI } from '@ai-sdk/openai';
import { createAnthropic } from '@ai-sdk/anthropic';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createPerplexity } from '@ai-sdk/perplexity';
import { createOllama } from 'ollama-ai-provider';
import { createMistral } from '@ai-sdk/mistral';
import { createAzure } from '@ai-sdk/azure';
import { createXai } from '@ai-sdk/xai';
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import {
getProviderAndModelForRole,
findProjectRoot // Assuming config-manager exports this
} from './config-manager.js';
const clientCache = new Map();
// Using a Symbol for a unique, unmistakable value
const VALIDATION_SKIPPED = Symbol('validation_skipped');
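// Illustrative note: a Symbol is never strictly equal to true, false, or any other value,
// so callers can reliably distinguish "validation skipped" from a genuine pass/fail result.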
// --- Load Supported Models Data (Lazily) ---
let supportedModelsData = null;
let modelsDataLoaded = false;
function loadSupportedModelsData() {
console.log(
`DEBUG: loadSupportedModelsData called. modelsDataLoaded=${modelsDataLoaded}`
);
if (modelsDataLoaded) {
console.log('DEBUG: Returning cached supported models data.');
return supportedModelsData;
}
try {
const projectRoot = findProjectRoot(process.cwd());
const supportedModelsPath = path.join(
projectRoot,
'data',
'supported-models.json'
);
console.log(
`DEBUG: Checking for supported models at: ${supportedModelsPath}`
);
const exists = fs.existsSync(supportedModelsPath);
console.log(`DEBUG: fs.existsSync result: ${exists}`);
if (exists) {
const fileContent = fs.readFileSync(supportedModelsPath, 'utf-8');
supportedModelsData = JSON.parse(fileContent);
console.log(
'DEBUG: Successfully loaded and parsed supported-models.json'
);
} else {
console.warn(
`Warning: Could not find supported models file at ${supportedModelsPath}. Skipping model validation.`
);
supportedModelsData = {}; // Treat as empty if not found, allowing skip
}
} catch (error) {
console.error(
`Error loading or parsing supported models file: ${error.message}`
);
console.error('Stack Trace:', error.stack);
supportedModelsData = {}; // Treat as empty on error, allowing skip
}
modelsDataLoaded = true;
console.log(
`DEBUG: Setting modelsDataLoaded=true, returning: ${JSON.stringify(supportedModelsData)}`
);
return supportedModelsData;
}
/**
* Validates if a model is supported for a given provider and role.
* @param {string} providerName - The name of the provider.
* @param {string} modelId - The ID of the model.
* @param {string} role - The role ('main', 'research', 'fallback').
* @returns {boolean|Symbol} True if valid, false if invalid, VALIDATION_SKIPPED if data was missing.
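* @example
* // Illustrative only; 'anthropic' and the model id below are placeholders that must
* // exist in supported-models.json for this call to return true.
* const result = isModelSupportedAndAllowed('anthropic', 'claude-3-opus', 'main');
* if (result === VALIDATION_SKIPPED) {
*   // supported-models.json was missing or unreadable, so validation was skipped
* } else if (!result) {
*   // model is unknown or the role is not listed in its allowed_roles
* }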
*/
function isModelSupportedAndAllowed(providerName, modelId, role) {
const modelsData = loadSupportedModelsData();
if (
!modelsData ||
typeof modelsData !== 'object' ||
Object.keys(modelsData).length === 0
) {
console.warn(
'Skipping model validation as supported models data is unavailable or invalid.'
);
// Return the specific symbol instead of true
return VALIDATION_SKIPPED;
}
// Ensure consistent casing for provider lookup
const providerKey = providerName?.toLowerCase();
if (!providerKey || !modelsData.hasOwnProperty(providerKey)) {
console.warn(
`Provider '${providerName}' not found in supported-models.json.`
);
return false;
}
const providerModels = modelsData[providerKey];
if (!Array.isArray(providerModels)) {
console.warn(
`Invalid format for provider '${providerName}' models in supported-models.json. Expected an array.`
);
return false;
}
const modelInfo = providerModels.find((m) => m && m.id === modelId);
if (!modelInfo) {
console.warn(
`Model '${modelId}' not found for provider '${providerName}' in supported-models.json.`
);
return false;
}
// Check if the role is allowed for this model
if (!Array.isArray(modelInfo.allowed_roles)) {
console.warn(
`Model '${modelId}' (Provider: '${providerName}') has invalid or missing 'allowed_roles' array in supported-models.json.`
);
return false;
}
const isAllowed = modelInfo.allowed_roles.includes(role);
if (!isAllowed) {
console.warn(
`Role '${role}' is not allowed for model '${modelId}' (Provider: '${providerName}'). Allowed roles: ${modelInfo.allowed_roles.join(', ')}`
);
}
return isAllowed;
}
/**
* Resolves an environment variable by checking process.env first, then session.env.
* @param {string} varName - The name of the environment variable.
* @param {object|null} session - The MCP session object (optional).
* @returns {string|undefined} The value of the environment variable or undefined if not found.
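* @example
* // Illustrative: process.env takes precedence over session.env.
* // resolveEnvVariable('OPENAI_API_KEY', { env: { OPENAI_API_KEY: 'sk-session' } })
* // returns process.env.OPENAI_API_KEY when it is set, otherwise 'sk-session'.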
*/
function resolveEnvVariable(varName, session) {
return process.env[varName] ?? session?.env?.[varName];
}
/**
* Validates if the required environment variables are set for a given provider,
* checking process.env and falling back to session.env.
* Throws an error if any required variable is missing.
* @param {string} providerName - The name of the provider (e.g., 'openai', 'anthropic').
* @param {object|null} session - The MCP session object (optional).
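* @example
* // Illustrative: for 'azure', both variables below must resolve via process.env or session.env,
* // otherwise an error listing the missing names is thrown.
* // validateEnvironment('azure', session); // needs AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT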
*/
function validateEnvironment(providerName, session) {
// Define requirements based on the provider
const requirements = {
openai: ['OPENAI_API_KEY'],
anthropic: ['ANTHROPIC_API_KEY'],
google: ['GOOGLE_API_KEY'],
perplexity: ['PERPLEXITY_API_KEY'],
ollama: ['OLLAMA_BASE_URL'], // Ollama only needs Base URL typically
mistral: ['MISTRAL_API_KEY'],
azure: ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT'],
openrouter: ['OPENROUTER_API_KEY'],
xai: ['XAI_API_KEY']
// Add requirements for other providers
};
const providerKey = providerName?.toLowerCase();
if (!providerKey || !requirements[providerKey]) {
// If the provider itself isn't in our requirements list, we can't validate.
// This might happen if config has an unsupported provider. Validation should happen earlier.
// Or, we could throw an error here if the provider is unknown.
console.warn(
`Cannot validate environment for unknown or unsupported provider: ${providerName}`
);
return; // Proceed without validation for unknown providers
}
const missing =
requirements[providerKey]?.filter(
(envVar) => !resolveEnvVariable(envVar, session)
) || [];
if (missing.length > 0) {
throw new Error(
`Missing environment variables for provider '${providerName}': ${missing.join(', ')}. Please check your .env file or session configuration.`
);
}
}
/**
* Creates an AI client instance for the specified provider.
* Assumes environment validation has already passed.
* @param {string} providerName - The name of the provider.
* @param {object|null} session - The MCP session object (optional).
* @param {object} [options={}] - Additional options for the client creation (e.g., model).
* @returns {object} The created AI client instance.
* @throws {Error} If the provider is unsupported.
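* @example
* // Illustrative sketch; assumes OPENROUTER_API_KEY resolves from process.env or session.env.
* const openrouter = createClientInstance('openrouter', session);
* // AI SDK provider instances are typically invoked with a model id afterwards,
* // e.g. openrouter(modelId), to obtain a language model.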
*/
function createClientInstance(providerName, session, options = {}) {
// Validation is now done before calling this function
const getEnv = (varName) => resolveEnvVariable(varName, session);
switch (providerName?.toLowerCase()) {
case 'openai':
return createOpenAI({ apiKey: getEnv('OPENAI_API_KEY'), ...options });
case 'anthropic':
return createAnthropic({
apiKey: getEnv('ANTHROPIC_API_KEY'),
...options
});
case 'google':
return createGoogleGenerativeAI({ apiKey: getEnv('GOOGLE_API_KEY'), ...options });
case 'perplexity':
return createPerplexity({
apiKey: getEnv('PERPLEXITY_API_KEY'),
...options
});
case 'ollama': {
// ollama-ai-provider uses baseURL directly
const ollamaBaseUrl =
getEnv('OLLAMA_BASE_URL') || 'http://localhost:11434/api'; // Default from ollama-ai-provider docs
return createOllama({ baseURL: ollamaBaseUrl, ...options });
}
case 'mistral':
return createMistral({ apiKey: getEnv('MISTRAL_API_KEY'), ...options });
case 'azure':
return createAzure({
apiKey: getEnv('AZURE_OPENAI_API_KEY'),
endpoint: getEnv('AZURE_OPENAI_ENDPOINT'),
...(options.model && { deploymentName: options.model }), // Azure often uses deployment name
...options
});
case 'openrouter':
return createOpenRouter({
apiKey: getEnv('OPENROUTER_API_KEY'),
...options
});
case 'xai':
return createXai({ apiKey: getEnv('XAI_API_KEY'), ...options });
default:
throw new Error(`Unsupported AI provider specified: ${providerName}`);
}
}
/**
* Gets or creates an AI client instance based on the configured model for a specific role.
* Validates the configured model against supported models and role allowances.
* @param {string} role - The role ('main', 'research', or 'fallback').
* @param {object|null} [session=null] - The MCP session object (optional).
* @param {object} [overrideOptions={}] - Optional overrides for { provider, modelId }.
* @returns {object} The cached or newly created AI client instance.
* @throws {Error} If configuration is missing, invalid, or environment validation fails.
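* @example
* // Illustrative usage; provider/model normally come from .taskmasterconfig via getProviderAndModelForRole.
* const mainClient = getClient('main');
* // 'grok-3' is a placeholder model id here; it must be listed in supported-models.json.
* const researchClient = getClient('research', session, { provider: 'xai', modelId: 'grok-3' });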
*/
export function getClient(role, session = null, overrideOptions = {}) {
if (!role) {
throw new Error(
`Client role ('main', 'research', 'fallback') must be specified.`
);
}
// 1. Determine Provider and Model ID
let providerName = overrideOptions.provider;
let modelId = overrideOptions.modelId;
if (!providerName || !modelId) {
// If not fully overridden, get from config
try {
const config = getProviderAndModelForRole(role); // Fetch from config manager
providerName = providerName || config.provider;
modelId = modelId || config.modelId;
} catch (configError) {
throw new Error(
`Failed to get configuration for role '${role}': ${configError.message}`
);
}
}
if (!providerName || !modelId) {
throw new Error(
`Could not determine provider or modelId for role '${role}' from configuration or overrides.`
);
}
// 2. Validate Provider/Model Combination and Role Allowance
const validationResult = isModelSupportedAndAllowed(
providerName,
modelId,
role
);
// Only throw if validation explicitly returned false (meaning invalid/disallowed)
// If it returned VALIDATION_SKIPPED, we proceed but skip strict validation.
if (validationResult === false) {
throw new Error(
`Model '${modelId}' from provider '${providerName}' is either not supported or not allowed for the '${role}' role. Check supported-models.json and your .taskmasterconfig.`
);
}
// Note: If validationResult === VALIDATION_SKIPPED, we continue to env validation
// 3. Validate Environment Variables for the chosen provider
// Let any missing-variable error propagate unchanged for clearer test messages.
validateEnvironment(providerName, session);
// 4. Check Cache
const cacheKey = `${providerName.toLowerCase()}:${modelId}`;
if (clientCache.has(cacheKey)) {
return clientCache.get(cacheKey);
}
// 5. Create New Client Instance
console.log(
`Creating new client for role '${role}': Provider=${providerName}, Model=${modelId}`
);
try {
const clientInstance = createClientInstance(providerName, session, {
model: modelId
});
clientCache.set(cacheKey, clientInstance);
return clientInstance;
} catch (creationError) {
throw new Error(
`Failed to create client instance for provider '${providerName}' (role: '${role}'): ${creationError.message}`
);
}
}
// Optional: Function to clear the cache if needed
export function clearClientCache() {
clientCache.clear();
console.log('AI client cache cleared.');
}
// Exported for testing purposes only
export function _resetSupportedModelsCache() {
console.log('DEBUG: Resetting supported models cache...');
supportedModelsData = null;
modelsDataLoaded = false;
console.log('DEBUG: Supported models cache reset.');
}