Compare commits

..

3 Commits

Author SHA1 Message Date
Ralph Khreish
c5e1054b3c chore: fix CI 2025-06-20 14:12:47 +03:00
Ralph Khreish
52e6ef8792 chore: improve changelog 2025-06-20 14:04:23 +03:00
Ralph Khreish
51ce1f59de fix: providers config for azure, bedrock, and vertex 2025-06-20 14:04:23 +03:00
5 changed files with 85 additions and 87 deletions

View File

@@ -13,41 +13,6 @@ import {
disableSilentMode disableSilentMode
} from '../../../../scripts/modules/utils.js'; } from '../../../../scripts/modules/utils.js';
import { createLogWrapper } from '../../tools/utils.js'; import { createLogWrapper } from '../../tools/utils.js';
import { CUSTOM_PROVIDERS_ARRAY } from '../../../../src/constants/providers.js';
// Define supported roles for model setting
const MODEL_ROLES = ['main', 'research', 'fallback'];
/**
* Determine provider hint from custom provider flags
* @param {Object} args - Arguments containing provider flags
* @returns {string|undefined} Provider hint or undefined if no custom provider flag is set
*/
function getProviderHint(args) {
return CUSTOM_PROVIDERS_ARRAY.find((provider) => args[provider]);
}
/**
* Handle setting models for different roles
* @param {Object} args - Arguments containing role-specific model IDs
* @param {Object} context - Context object with session, mcpLog, projectRoot
* @returns {Object|null} Result if a model was set, null if no model setting was requested
*/
async function handleModelSetting(args, context) {
for (const role of MODEL_ROLES) {
const roleKey = `set${role.charAt(0).toUpperCase() + role.slice(1)}`; // setMain, setResearch, setFallback
if (args[roleKey]) {
const providerHint = getProviderHint(args);
return await setModel(role, args[roleKey], {
...context,
providerHint
});
}
}
return null; // No model setting was requested
}
/** /**
* Get or update model configuration * Get or update model configuration
@@ -66,21 +31,16 @@ export async function modelsDirect(args, log, context = {}) {
log.info(`Executing models_direct with args: ${JSON.stringify(args)}`); log.info(`Executing models_direct with args: ${JSON.stringify(args)}`);
log.info(`Using project root: ${projectRoot}`); log.info(`Using project root: ${projectRoot}`);
// Validate flags: only one custom provider flag can be used simultaneously // Validate flags: cannot use both openrouter and ollama simultaneously
const customProviderFlags = CUSTOM_PROVIDERS_ARRAY.filter( if (args.openrouter && args.ollama) {
(provider) => args[provider]
);
if (customProviderFlags.length > 1) {
log.error( log.error(
'Error: Cannot use multiple custom provider flags simultaneously.' 'Error: Cannot use both openrouter and ollama flags simultaneously.'
); );
return { return {
success: false, success: false,
error: { error: {
code: 'INVALID_ARGS', code: 'INVALID_ARGS',
message: message: 'Cannot use both openrouter and ollama flags simultaneously.'
'Cannot use multiple custom provider flags simultaneously. Choose only one: openrouter, ollama, bedrock, azure, or vertex.'
} }
}; };
} }
@@ -94,22 +54,55 @@ export async function modelsDirect(args, log, context = {}) {
return await getAvailableModelsList({ return await getAvailableModelsList({
session, session,
mcpLog, mcpLog,
projectRoot projectRoot // Pass projectRoot to function
}); });
} }
// Handle setting any model role using unified function // Handle setting a specific model
const modelContext = { session, mcpLog, projectRoot }; if (args.setMain) {
const modelSetResult = await handleModelSetting(args, modelContext); return await setModel('main', args.setMain, {
if (modelSetResult) { session,
return modelSetResult; mcpLog,
projectRoot, // Pass projectRoot to function
providerHint: args.openrouter
? 'openrouter'
: args.ollama
? 'ollama'
: undefined // Pass hint
});
}
if (args.setResearch) {
return await setModel('research', args.setResearch, {
session,
mcpLog,
projectRoot, // Pass projectRoot to function
providerHint: args.openrouter
? 'openrouter'
: args.ollama
? 'ollama'
: undefined // Pass hint
});
}
if (args.setFallback) {
return await setModel('fallback', args.setFallback, {
session,
mcpLog,
projectRoot, // Pass projectRoot to function
providerHint: args.openrouter
? 'openrouter'
: args.ollama
? 'ollama'
: undefined // Pass hint
});
} }
// Default action: get current configuration // Default action: get current configuration
return await getModelConfiguration({ return await getModelConfiguration({
session, session,
mcpLog, mcpLog,
projectRoot projectRoot // Pass projectRoot to function
}); });
} finally { } finally {
disableSilentMode(); disableSilentMode();

View File

@@ -55,21 +55,7 @@ export function registerModelsTool(server) {
ollama: z ollama: z
.boolean() .boolean()
.optional() .optional()
.describe('Indicates the set model ID is a custom Ollama model.'), .describe('Indicates the set model ID is a custom Ollama model.')
bedrock: z
.boolean()
.optional()
.describe('Indicates the set model ID is a custom AWS Bedrock model.'),
azure: z
.boolean()
.optional()
.describe('Indicates the set model ID is a custom Azure OpenAI model.'),
vertex: z
.boolean()
.optional()
.describe(
'Indicates the set model ID is a custom Google Vertex AI model.'
)
}), }),
execute: withNormalizedProjectRoot(async (args, { log, session }) => { execute: withNormalizedProjectRoot(async (args, { log, session }) => {
try { try {

View File

@@ -294,14 +294,30 @@ async function runInteractiveSetup(projectRoot) {
} }
: null; : null;
// Define custom provider options const customOpenRouterOption = {
const customProviderOptions = [ name: '* Custom OpenRouter model', // Symbol updated
{ name: '* Custom OpenRouter model', value: '__CUSTOM_OPENROUTER__' }, value: '__CUSTOM_OPENROUTER__'
{ name: '* Custom Ollama model', value: '__CUSTOM_OLLAMA__' }, };
{ name: '* Custom Bedrock model', value: '__CUSTOM_BEDROCK__' },
{ name: '* Custom Azure model', value: '__CUSTOM_AZURE__' }, const customOllamaOption = {
{ name: '* Custom Vertex model', value: '__CUSTOM_VERTEX__' } name: '* Custom Ollama model', // Symbol updated
]; value: '__CUSTOM_OLLAMA__'
};
const customBedrockOption = {
name: '* Custom Bedrock model', // Add Bedrock custom option
value: '__CUSTOM_BEDROCK__'
};
const customAzureOption = {
name: '* Custom Azure OpenAI model', // Add Azure custom option
value: '__CUSTOM_AZURE__'
};
const customVertexOption = {
name: '* Custom Vertex AI model', // Add Vertex custom option
value: '__CUSTOM_VERTEX__'
};
let choices = []; let choices = [];
let defaultIndex = 0; // Default to 'Cancel' let defaultIndex = 0; // Default to 'Cancel'
@@ -348,16 +364,24 @@ async function runInteractiveSetup(projectRoot) {
} }
systemOptions.push(cancelOption); systemOptions.push(cancelOption);
const customOptions = [
customOpenRouterOption,
customOllamaOption,
customBedrockOption,
customAzureOption,
customVertexOption
];
const systemLength = systemOptions.length; const systemLength = systemOptions.length;
if (allowNone) { if (allowNone) {
choices = [ choices = [
...systemOptions, ...systemOptions,
new inquirer.Separator('\n── Standard Models ──'), new inquirer.Separator('── Standard Models ──'),
{ name: '⚪ None (disable)', value: null }, { name: '⚪ None (disable)', value: null },
...roleChoices, ...roleChoices,
new inquirer.Separator('\n── Custom Providers ──'), new inquirer.Separator('── Custom Providers ──'),
...customProviderOptions ...customOptions
]; ];
// Adjust default index: System + Sep1 + None (+2) // Adjust default index: System + Sep1 + None (+2)
const noneOptionIndex = systemLength + 1; const noneOptionIndex = systemLength + 1;
@@ -368,10 +392,10 @@ async function runInteractiveSetup(projectRoot) {
} else { } else {
choices = [ choices = [
...systemOptions, ...systemOptions,
new inquirer.Separator('\n── Standard Models ──'), new inquirer.Separator('── Standard Models ──'),
...roleChoices, ...roleChoices,
new inquirer.Separator('\n── Custom Providers ──'), new inquirer.Separator('── Custom Providers ──'),
...customProviderOptions ...customOptions
]; ];
// Adjust default index: System + Sep (+1) // Adjust default index: System + Sep (+1)
defaultIndex = defaultIndex =

View File

@@ -509,8 +509,7 @@ function isApiKeySet(providerName, session = null, projectRoot = null) {
azure: 'AZURE_OPENAI_API_KEY', azure: 'AZURE_OPENAI_API_KEY',
openrouter: 'OPENROUTER_API_KEY', openrouter: 'OPENROUTER_API_KEY',
xai: 'XAI_API_KEY', xai: 'XAI_API_KEY',
vertex: 'GOOGLE_API_KEY', // Vertex uses the same key as Google vertex: 'GOOGLE_API_KEY' // Vertex uses the same key as Google
bedrock: 'AWS_ACCESS_KEY_ID' // Bedrock uses AWS credentials
// Add other providers as needed // Add other providers as needed
}; };
@@ -606,10 +605,6 @@ function getMcpApiKeyStatus(providerName, projectRoot = null) {
apiKeyToCheck = mcpEnv.GOOGLE_API_KEY; // Vertex uses Google API key apiKeyToCheck = mcpEnv.GOOGLE_API_KEY; // Vertex uses Google API key
placeholderValue = 'YOUR_GOOGLE_API_KEY_HERE'; placeholderValue = 'YOUR_GOOGLE_API_KEY_HERE';
break; break;
case 'bedrock':
apiKeyToCheck = mcpEnv.AWS_ACCESS_KEY_ID; // Bedrock uses AWS credentials
placeholderValue = 'YOUR_AWS_ACCESS_KEY_ID_HERE';
break;
default: default:
return false; // Unknown provider return false; // Unknown provider
} }

View File

@@ -525,7 +525,7 @@ async function setModel(role, modelId, options = {}) {
success: false, success: false,
error: { error: {
code: 'MODEL_NOT_FOUND_NO_HINT', code: 'MODEL_NOT_FOUND_NO_HINT',
message: `Model ID "${modelId}" not found in Taskmaster's supported models. If this is a custom model, please specify the provider using --openrouter, --ollama, --bedrock, --azure, or --vertex.` message: `Model ID "${modelId}" not found in Taskmaster's supported models. If this is a custom model, please specify the provider using --openrouter or --ollama.`
} }
}; };
} }