From 2bb4260966d60aa3d8b444a9c76b5e341f6ea5f0 Mon Sep 17 00:00:00 2001 From: Ralph Khreish <35776126+Crunchyman-ralph@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:59:53 +0300 Subject: [PATCH] fix: Fix external provider support (#726) --- .../src/core/direct-functions/models.js | 95 ++++++++++--------- mcp-server/src/tools/models.js | 16 +++- scripts/modules/commands.js | 52 +++------- scripts/modules/config-manager.js | 7 +- scripts/modules/task-manager/models.js | 2 +- 5 files changed, 87 insertions(+), 85 deletions(-) diff --git a/mcp-server/src/core/direct-functions/models.js b/mcp-server/src/core/direct-functions/models.js index aa0dcff2..f5d43eea 100644 --- a/mcp-server/src/core/direct-functions/models.js +++ b/mcp-server/src/core/direct-functions/models.js @@ -13,6 +13,41 @@ import { disableSilentMode } from '../../../../scripts/modules/utils.js'; import { createLogWrapper } from '../../tools/utils.js'; +import { CUSTOM_PROVIDERS_ARRAY } from '../../../../src/constants/providers.js'; + +// Define supported roles for model setting +const MODEL_ROLES = ['main', 'research', 'fallback']; + +/** + * Determine provider hint from custom provider flags + * @param {Object} args - Arguments containing provider flags + * @returns {string|undefined} Provider hint or undefined if no custom provider flag is set + */ +function getProviderHint(args) { + return CUSTOM_PROVIDERS_ARRAY.find((provider) => args[provider]); +} + +/** + * Handle setting models for different roles + * @param {Object} args - Arguments containing role-specific model IDs + * @param {Object} context - Context object with session, mcpLog, projectRoot + * @returns {Object|null} Result if a model was set, null if no model setting was requested + */ +async function handleModelSetting(args, context) { + for (const role of MODEL_ROLES) { + const roleKey = `set${role.charAt(0).toUpperCase() + role.slice(1)}`; // setMain, setResearch, setFallback + + if (args[roleKey]) { + const providerHint = getProviderHint(args); + + return await setModel(role, args[roleKey], { + ...context, + providerHint + }); + } + } + return null; // No model setting was requested +} /** * Get or update model configuration @@ -31,16 +66,21 @@ export async function modelsDirect(args, log, context = {}) { log.info(`Executing models_direct with args: ${JSON.stringify(args)}`); log.info(`Using project root: ${projectRoot}`); - // Validate flags: cannot use both openrouter and ollama simultaneously - if (args.openrouter && args.ollama) { + // Validate flags: only one custom provider flag can be used simultaneously + const customProviderFlags = CUSTOM_PROVIDERS_ARRAY.filter( + (provider) => args[provider] + ); + + if (customProviderFlags.length > 1) { log.error( - 'Error: Cannot use both openrouter and ollama flags simultaneously.' + 'Error: Cannot use multiple custom provider flags simultaneously.' ); return { success: false, error: { code: 'INVALID_ARGS', - message: 'Cannot use both openrouter and ollama flags simultaneously.' + message: + 'Cannot use multiple custom provider flags simultaneously. Choose only one: openrouter, ollama, bedrock, azure, or vertex.' } }; } @@ -54,55 +94,22 @@ export async function modelsDirect(args, log, context = {}) { return await getAvailableModelsList({ session, mcpLog, - projectRoot // Pass projectRoot to function + projectRoot }); } - // Handle setting a specific model - if (args.setMain) { - return await setModel('main', args.setMain, { - session, - mcpLog, - projectRoot, // Pass projectRoot to function - providerHint: args.openrouter - ? 'openrouter' - : args.ollama - ? 'ollama' - : undefined // Pass hint - }); - } - - if (args.setResearch) { - return await setModel('research', args.setResearch, { - session, - mcpLog, - projectRoot, // Pass projectRoot to function - providerHint: args.openrouter - ? 'openrouter' - : args.ollama - ? 'ollama' - : undefined // Pass hint - }); - } - - if (args.setFallback) { - return await setModel('fallback', args.setFallback, { - session, - mcpLog, - projectRoot, // Pass projectRoot to function - providerHint: args.openrouter - ? 'openrouter' - : args.ollama - ? 'ollama' - : undefined // Pass hint - }); + // Handle setting any model role using unified function + const modelContext = { session, mcpLog, projectRoot }; + const modelSetResult = await handleModelSetting(args, modelContext); + if (modelSetResult) { + return modelSetResult; } // Default action: get current configuration return await getModelConfiguration({ session, mcpLog, - projectRoot // Pass projectRoot to function + projectRoot }); } finally { disableSilentMode(); diff --git a/mcp-server/src/tools/models.js b/mcp-server/src/tools/models.js index ef2ba24f..e38ff308 100644 --- a/mcp-server/src/tools/models.js +++ b/mcp-server/src/tools/models.js @@ -55,7 +55,21 @@ export function registerModelsTool(server) { ollama: z .boolean() .optional() - .describe('Indicates the set model ID is a custom Ollama model.') + .describe('Indicates the set model ID is a custom Ollama model.'), + bedrock: z + .boolean() + .optional() + .describe('Indicates the set model ID is a custom AWS Bedrock model.'), + azure: z + .boolean() + .optional() + .describe('Indicates the set model ID is a custom Azure OpenAI model.'), + vertex: z + .boolean() + .optional() + .describe( + 'Indicates the set model ID is a custom Google Vertex AI model.' + ) }), execute: withNormalizedProjectRoot(async (args, { log, session }) => { try { diff --git a/scripts/modules/commands.js b/scripts/modules/commands.js index 5000ac78..b56ae1b7 100644 --- a/scripts/modules/commands.js +++ b/scripts/modules/commands.js @@ -294,30 +294,14 @@ async function runInteractiveSetup(projectRoot) { } : null; - const customOpenRouterOption = { - name: '* Custom OpenRouter model', // Symbol updated - value: '__CUSTOM_OPENROUTER__' - }; - - const customOllamaOption = { - name: '* Custom Ollama model', // Symbol updated - value: '__CUSTOM_OLLAMA__' - }; - - const customBedrockOption = { - name: '* Custom Bedrock model', // Add Bedrock custom option - value: '__CUSTOM_BEDROCK__' - }; - - const customAzureOption = { - name: '* Custom Azure OpenAI model', // Add Azure custom option - value: '__CUSTOM_AZURE__' - }; - - const customVertexOption = { - name: '* Custom Vertex AI model', // Add Vertex custom option - value: '__CUSTOM_VERTEX__' - }; + // Define custom provider options + const customProviderOptions = [ + { name: '* Custom OpenRouter model', value: '__CUSTOM_OPENROUTER__' }, + { name: '* Custom Ollama model', value: '__CUSTOM_OLLAMA__' }, + { name: '* Custom Bedrock model', value: '__CUSTOM_BEDROCK__' }, + { name: '* Custom Azure model', value: '__CUSTOM_AZURE__' }, + { name: '* Custom Vertex model', value: '__CUSTOM_VERTEX__' } + ]; let choices = []; let defaultIndex = 0; // Default to 'Cancel' @@ -364,24 +348,16 @@ async function runInteractiveSetup(projectRoot) { } systemOptions.push(cancelOption); - const customOptions = [ - customOpenRouterOption, - customOllamaOption, - customBedrockOption, - customAzureOption, - customVertexOption - ]; - const systemLength = systemOptions.length; if (allowNone) { choices = [ ...systemOptions, - new inquirer.Separator('── Standard Models ──'), + new inquirer.Separator('\n── Standard Models ──'), { name: '⚪ None (disable)', value: null }, ...roleChoices, - new inquirer.Separator('── Custom Providers ──'), - ...customOptions + new inquirer.Separator('\n── Custom Providers ──'), + ...customProviderOptions ]; // Adjust default index: System + Sep1 + None (+2) const noneOptionIndex = systemLength + 1; @@ -392,10 +368,10 @@ async function runInteractiveSetup(projectRoot) { } else { choices = [ ...systemOptions, - new inquirer.Separator('── Standard Models ──'), + new inquirer.Separator('\n── Standard Models ──'), ...roleChoices, - new inquirer.Separator('── Custom Providers ──'), - ...customOptions + new inquirer.Separator('\n── Custom Providers ──'), + ...customProviderOptions ]; // Adjust default index: System + Sep (+1) defaultIndex = diff --git a/scripts/modules/config-manager.js b/scripts/modules/config-manager.js index d289e2e2..82959146 100644 --- a/scripts/modules/config-manager.js +++ b/scripts/modules/config-manager.js @@ -509,7 +509,8 @@ function isApiKeySet(providerName, session = null, projectRoot = null) { azure: 'AZURE_OPENAI_API_KEY', openrouter: 'OPENROUTER_API_KEY', xai: 'XAI_API_KEY', - vertex: 'GOOGLE_API_KEY' // Vertex uses the same key as Google + vertex: 'GOOGLE_API_KEY', // Vertex uses the same key as Google + bedrock: 'AWS_ACCESS_KEY_ID' // Bedrock uses AWS credentials // Add other providers as needed }; @@ -605,6 +606,10 @@ function getMcpApiKeyStatus(providerName, projectRoot = null) { apiKeyToCheck = mcpEnv.GOOGLE_API_KEY; // Vertex uses Google API key placeholderValue = 'YOUR_GOOGLE_API_KEY_HERE'; break; + case 'bedrock': + apiKeyToCheck = mcpEnv.AWS_ACCESS_KEY_ID; // Bedrock uses AWS credentials + placeholderValue = 'YOUR_AWS_ACCESS_KEY_ID_HERE'; + break; default: return false; // Unknown provider } diff --git a/scripts/modules/task-manager/models.js b/scripts/modules/task-manager/models.js index 6bfdfae8..0a12c331 100644 --- a/scripts/modules/task-manager/models.js +++ b/scripts/modules/task-manager/models.js @@ -525,7 +525,7 @@ async function setModel(role, modelId, options = {}) { success: false, error: { code: 'MODEL_NOT_FOUND_NO_HINT', - message: `Model ID "${modelId}" not found in Taskmaster's supported models. If this is a custom model, please specify the provider using --openrouter or --ollama.` + message: `Model ID "${modelId}" not found in Taskmaster's supported models. If this is a custom model, please specify the provider using --openrouter, --ollama, --bedrock, --azure, or --vertex.` } }; }