Compare commits

...

2 Commits

Author SHA1 Message Date
Ralph Khreish
f2f42b0659 chore: cleanup 2025-06-20 15:57:57 +03:00
Ralph Khreish
6b6fe327d3 fix: Fix external provider support 2025-06-20 15:16:54 +03:00
5 changed files with 87 additions and 85 deletions

View File

@@ -13,6 +13,41 @@ import {
disableSilentMode disableSilentMode
} from '../../../../scripts/modules/utils.js'; } from '../../../../scripts/modules/utils.js';
import { createLogWrapper } from '../../tools/utils.js'; import { createLogWrapper } from '../../tools/utils.js';
import { CUSTOM_PROVIDERS_ARRAY } from '../../../../src/constants/providers.js';
// Define supported roles for model setting
const MODEL_ROLES = ['main', 'research', 'fallback'];
/**
 * Resolve the provider hint implied by custom-provider boolean flags.
 * Scans the known custom providers in declaration order and reports the
 * first one whose flag is truthy on `args`.
 * @param {Object} args - Arguments that may carry provider flags (e.g. `{ ollama: true }`)
 * @returns {string|undefined} The matching provider name, or undefined when no flag is set
 */
function getProviderHint(args) {
	for (const provider of CUSTOM_PROVIDERS_ARRAY) {
		if (args[provider]) {
			return provider;
		}
	}
	return undefined;
}
/**
 * Apply the first requested model-role assignment found in `args`.
 * Checks each supported role (main, research, fallback) for its
 * corresponding `setMain`/`setResearch`/`setFallback` argument and, when
 * present, delegates to `setModel` with any custom-provider hint attached.
 * @param {Object} args - Arguments that may contain role-specific model IDs
 * @param {Object} context - Context object with session, mcpLog, projectRoot
 * @returns {Object|null} The `setModel` result, or null when no role was requested
 */
async function handleModelSetting(args, context) {
	for (const role of MODEL_ROLES) {
		// e.g. 'main' -> 'setMain'
		const flagName = `set${role[0].toUpperCase()}${role.slice(1)}`;
		const modelId = args[flagName];
		if (!modelId) {
			continue;
		}
		const providerHint = getProviderHint(args);
		return await setModel(role, modelId, { ...context, providerHint });
	}
	// No set* argument was supplied for any role.
	return null;
}
/** /**
* Get or update model configuration * Get or update model configuration
@@ -31,16 +66,21 @@ export async function modelsDirect(args, log, context = {}) {
log.info(`Executing models_direct with args: ${JSON.stringify(args)}`); log.info(`Executing models_direct with args: ${JSON.stringify(args)}`);
log.info(`Using project root: ${projectRoot}`); log.info(`Using project root: ${projectRoot}`);
// Validate flags: cannot use both openrouter and ollama simultaneously // Validate flags: only one custom provider flag can be used simultaneously
if (args.openrouter && args.ollama) { const customProviderFlags = CUSTOM_PROVIDERS_ARRAY.filter(
(provider) => args[provider]
);
if (customProviderFlags.length > 1) {
log.error( log.error(
'Error: Cannot use both openrouter and ollama flags simultaneously.' 'Error: Cannot use multiple custom provider flags simultaneously.'
); );
return { return {
success: false, success: false,
error: { error: {
code: 'INVALID_ARGS', code: 'INVALID_ARGS',
message: 'Cannot use both openrouter and ollama flags simultaneously.' message:
'Cannot use multiple custom provider flags simultaneously. Choose only one: openrouter, ollama, bedrock, azure, or vertex.'
} }
}; };
} }
@@ -54,55 +94,22 @@ export async function modelsDirect(args, log, context = {}) {
return await getAvailableModelsList({ return await getAvailableModelsList({
session, session,
mcpLog, mcpLog,
projectRoot // Pass projectRoot to function projectRoot
}); });
} }
// Handle setting a specific model // Handle setting any model role using unified function
if (args.setMain) { const modelContext = { session, mcpLog, projectRoot };
return await setModel('main', args.setMain, { const modelSetResult = await handleModelSetting(args, modelContext);
session, if (modelSetResult) {
mcpLog, return modelSetResult;
projectRoot, // Pass projectRoot to function
providerHint: args.openrouter
? 'openrouter'
: args.ollama
? 'ollama'
: undefined // Pass hint
});
}
if (args.setResearch) {
return await setModel('research', args.setResearch, {
session,
mcpLog,
projectRoot, // Pass projectRoot to function
providerHint: args.openrouter
? 'openrouter'
: args.ollama
? 'ollama'
: undefined // Pass hint
});
}
if (args.setFallback) {
return await setModel('fallback', args.setFallback, {
session,
mcpLog,
projectRoot, // Pass projectRoot to function
providerHint: args.openrouter
? 'openrouter'
: args.ollama
? 'ollama'
: undefined // Pass hint
});
} }
// Default action: get current configuration // Default action: get current configuration
return await getModelConfiguration({ return await getModelConfiguration({
session, session,
mcpLog, mcpLog,
projectRoot // Pass projectRoot to function projectRoot
}); });
} finally { } finally {
disableSilentMode(); disableSilentMode();

View File

@@ -55,7 +55,21 @@ export function registerModelsTool(server) {
ollama: z ollama: z
.boolean() .boolean()
.optional() .optional()
.describe('Indicates the set model ID is a custom Ollama model.') .describe('Indicates the set model ID is a custom Ollama model.'),
bedrock: z
.boolean()
.optional()
.describe('Indicates the set model ID is a custom AWS Bedrock model.'),
azure: z
.boolean()
.optional()
.describe('Indicates the set model ID is a custom Azure OpenAI model.'),
vertex: z
.boolean()
.optional()
.describe(
'Indicates the set model ID is a custom Google Vertex AI model.'
)
}), }),
execute: withNormalizedProjectRoot(async (args, { log, session }) => { execute: withNormalizedProjectRoot(async (args, { log, session }) => {
try { try {

View File

@@ -294,30 +294,14 @@ async function runInteractiveSetup(projectRoot) {
} }
: null; : null;
const customOpenRouterOption = { // Define custom provider options
name: '* Custom OpenRouter model', // Symbol updated const customProviderOptions = [
value: '__CUSTOM_OPENROUTER__' { name: '* Custom OpenRouter model', value: '__CUSTOM_OPENROUTER__' },
}; { name: '* Custom Ollama model', value: '__CUSTOM_OLLAMA__' },
{ name: '* Custom Bedrock model', value: '__CUSTOM_BEDROCK__' },
const customOllamaOption = { { name: '* Custom Azure model', value: '__CUSTOM_AZURE__' },
name: '* Custom Ollama model', // Symbol updated { name: '* Custom Vertex model', value: '__CUSTOM_VERTEX__' }
value: '__CUSTOM_OLLAMA__' ];
};
const customBedrockOption = {
name: '* Custom Bedrock model', // Add Bedrock custom option
value: '__CUSTOM_BEDROCK__'
};
const customAzureOption = {
name: '* Custom Azure OpenAI model', // Add Azure custom option
value: '__CUSTOM_AZURE__'
};
const customVertexOption = {
name: '* Custom Vertex AI model', // Add Vertex custom option
value: '__CUSTOM_VERTEX__'
};
let choices = []; let choices = [];
let defaultIndex = 0; // Default to 'Cancel' let defaultIndex = 0; // Default to 'Cancel'
@@ -364,24 +348,16 @@ async function runInteractiveSetup(projectRoot) {
} }
systemOptions.push(cancelOption); systemOptions.push(cancelOption);
const customOptions = [
customOpenRouterOption,
customOllamaOption,
customBedrockOption,
customAzureOption,
customVertexOption
];
const systemLength = systemOptions.length; const systemLength = systemOptions.length;
if (allowNone) { if (allowNone) {
choices = [ choices = [
...systemOptions, ...systemOptions,
new inquirer.Separator('── Standard Models ──'), new inquirer.Separator('\n── Standard Models ──'),
{ name: '⚪ None (disable)', value: null }, { name: '⚪ None (disable)', value: null },
...roleChoices, ...roleChoices,
new inquirer.Separator('── Custom Providers ──'), new inquirer.Separator('\n── Custom Providers ──'),
...customOptions ...customProviderOptions
]; ];
// Adjust default index: System + Sep1 + None (+2) // Adjust default index: System + Sep1 + None (+2)
const noneOptionIndex = systemLength + 1; const noneOptionIndex = systemLength + 1;
@@ -392,10 +368,10 @@ async function runInteractiveSetup(projectRoot) {
} else { } else {
choices = [ choices = [
...systemOptions, ...systemOptions,
new inquirer.Separator('── Standard Models ──'), new inquirer.Separator('\n── Standard Models ──'),
...roleChoices, ...roleChoices,
new inquirer.Separator('── Custom Providers ──'), new inquirer.Separator('\n── Custom Providers ──'),
...customOptions ...customProviderOptions
]; ];
// Adjust default index: System + Sep (+1) // Adjust default index: System + Sep (+1)
defaultIndex = defaultIndex =

View File

@@ -509,7 +509,8 @@ function isApiKeySet(providerName, session = null, projectRoot = null) {
azure: 'AZURE_OPENAI_API_KEY', azure: 'AZURE_OPENAI_API_KEY',
openrouter: 'OPENROUTER_API_KEY', openrouter: 'OPENROUTER_API_KEY',
xai: 'XAI_API_KEY', xai: 'XAI_API_KEY',
vertex: 'GOOGLE_API_KEY' // Vertex uses the same key as Google vertex: 'GOOGLE_API_KEY', // Vertex uses the same key as Google
bedrock: 'AWS_ACCESS_KEY_ID' // Bedrock uses AWS credentials
// Add other providers as needed // Add other providers as needed
}; };
@@ -605,6 +606,10 @@ function getMcpApiKeyStatus(providerName, projectRoot = null) {
apiKeyToCheck = mcpEnv.GOOGLE_API_KEY; // Vertex uses Google API key apiKeyToCheck = mcpEnv.GOOGLE_API_KEY; // Vertex uses Google API key
placeholderValue = 'YOUR_GOOGLE_API_KEY_HERE'; placeholderValue = 'YOUR_GOOGLE_API_KEY_HERE';
break; break;
case 'bedrock':
apiKeyToCheck = mcpEnv.AWS_ACCESS_KEY_ID; // Bedrock uses AWS credentials
placeholderValue = 'YOUR_AWS_ACCESS_KEY_ID_HERE';
break;
default: default:
return false; // Unknown provider return false; // Unknown provider
} }

View File

@@ -525,7 +525,7 @@ async function setModel(role, modelId, options = {}) {
success: false, success: false,
error: { error: {
code: 'MODEL_NOT_FOUND_NO_HINT', code: 'MODEL_NOT_FOUND_NO_HINT',
message: `Model ID "${modelId}" not found in Taskmaster's supported models. If this is a custom model, please specify the provider using --openrouter or --ollama.` message: `Model ID "${modelId}" not found in Taskmaster's supported models. If this is a custom model, please specify the provider using --openrouter, --ollama, --bedrock, --azure, or --vertex.`
} }
}; };
} }