feat(ai): Integrate OpenAI provider and enhance model config

- Add OpenAI provider implementation using @ai-sdk/openai.\n- Update `models` command/tool to display API key status for configured providers.\n- Implement model-specific `maxTokens` override logic in `config-manager.js` using `supported-models.json`.\n- Improve AI error message parsing in `ai-services-unified.js` for better clarity.
This commit is contained in:
Eyal Toledano
2025-04-27 03:56:23 -04:00
parent cbc3576642
commit 49e1137eab
21 changed files with 1350 additions and 662 deletions

View File

@@ -25,7 +25,8 @@ import { log, resolveEnvVariable } from './utils.js';
import * as anthropic from '../../src/ai-providers/anthropic.js';
import * as perplexity from '../../src/ai-providers/perplexity.js';
import * as google from '../../src/ai-providers/google.js'; // Import Google provider
// TODO: Import other provider modules when implemented (openai, ollama, etc.)
import * as openai from '../../src/ai-providers/openai.js'; // ADD: Import OpenAI provider
// TODO: Import other provider modules when implemented (ollama, etc.)
// --- Provider Function Map ---
// Maps provider names (lowercase) to their respective service functions
@@ -47,8 +48,14 @@ const PROVIDER_FUNCTIONS = {
generateText: google.generateGoogleText,
streamText: google.streamGoogleText,
generateObject: google.generateGoogleObject
},
openai: {
// ADD: OpenAI entry
generateText: openai.generateOpenAIText,
streamText: openai.streamOpenAIText,
generateObject: openai.generateOpenAIObject
}
// TODO: Add entries for openai, ollama, etc. when implemented
// TODO: Add entries for ollama, etc. when implemented
};
// --- Configuration for Retries ---
@@ -71,6 +78,54 @@ function isRetryableError(error) {
);
}
/**
* Extracts a user-friendly error message from a potentially complex AI error object.
* Prioritizes nested messages and falls back to the top-level message.
* @param {Error | object | any} error - The error object.
* @returns {string} A concise error message.
*/
function _extractErrorMessage(error) {
try {
// Attempt 1: Look for Vercel SDK specific nested structure (common)
if (error?.data?.error?.message) {
return error.data.error.message;
}
// Attempt 2: Look for nested error message directly in the error object
if (error?.error?.message) {
return error.error.message;
}
// Attempt 3: Look for nested error message in response body if it's JSON string
if (typeof error?.responseBody === 'string') {
try {
const body = JSON.parse(error.responseBody);
if (body?.error?.message) {
return body.error.message;
}
} catch (parseError) {
// Ignore if responseBody is not valid JSON
}
}
// Attempt 4: Use the top-level message if it exists
if (typeof error?.message === 'string' && error.message) {
return error.message;
}
// Attempt 5: Handle simple string errors
if (typeof error === 'string') {
return error;
}
// Fallback
return 'An unknown AI service error occurred.';
} catch (e) {
// Safety net
return 'Failed to extract error message.';
}
}
/**
* Internal helper to resolve the API key for a given provider.
* @param {string} providerName - The name of the provider (lowercase).
@@ -87,8 +142,7 @@ function _resolveApiKey(providerName, session) {
mistral: 'MISTRAL_API_KEY',
azure: 'AZURE_OPENAI_API_KEY',
openrouter: 'OPENROUTER_API_KEY',
xai: 'XAI_API_KEY',
ollama: 'OLLAMA_API_KEY'
xai: 'XAI_API_KEY'
};
// Double check this -- I have had to use an api key for ollama in the past
@@ -211,6 +265,8 @@ async function _unifiedServiceRunner(serviceType, params) {
}
let lastError = null;
let lastCleanErrorMessage =
'AI service call failed for all configured roles.';
for (const currentRole of sequence) {
let providerName, modelId, apiKey, roleParams, providerFnSet, providerApiFn;
@@ -344,23 +400,21 @@ async function _unifiedServiceRunner(serviceType, params) {
return result; // Return original result for other cases
} catch (error) {
const cleanMessage = _extractErrorMessage(error); // Extract clean message
log(
'error', // Log as error since this role attempt failed
`Service call failed for role ${currentRole} (Provider: ${providerName || 'unknown'}): ${error.message}`
`Service call failed for role ${currentRole} (Provider: ${providerName || 'unknown'}): ${cleanMessage}` // Log the clean message
);
lastError = error; // Store the error to throw if all roles fail
// Log reason and continue (handled within the loop now)
lastError = error; // Store the original error for potential debugging
lastCleanErrorMessage = cleanMessage; // Store the clean message for final throw
// Continue to the next role in the sequence
}
}
// If loop completes, all roles failed
log('error', `All roles in the sequence [${sequence.join(', ')}] failed.`);
throw (
lastError ||
new Error(
`AI service call (${serviceType}) failed for all configured roles in the sequence.`
)
);
// Throw a new error with the cleaner message from the last failure
throw new Error(lastCleanErrorMessage);
}
/**

File diff suppressed because it is too large Load Diff

View File

@@ -255,8 +255,6 @@ function getModelConfigForRole(role, explicitRoot = null) {
const config = getConfig(explicitRoot);
const roleConfig = config?.models?.[role];
if (!roleConfig) {
// This shouldn't happen if _loadAndValidateConfig ensures defaults
// But as a safety net, log and return defaults
log(
'warn',
`No model configuration found for role: ${role}. Returning default.`
@@ -363,16 +361,64 @@ function getOllamaBaseUrl(explicitRoot = null) {
}
/**
* Gets model parameters (maxTokens, temperature) for a specific role.
* Gets model parameters (maxTokens, temperature) for a specific role,
* considering model-specific overrides from supported-models.json.
* @param {string} role - The role ('main', 'research', 'fallback').
* @param {string|null} explicitRoot - Optional explicit path to the project root.
* @returns {{maxTokens: number, temperature: number}}
*/
function getParametersForRole(role, explicitRoot = null) {
const roleConfig = getModelConfigForRole(role, explicitRoot);
const roleMaxTokens = roleConfig.maxTokens;
const roleTemperature = roleConfig.temperature;
const modelId = roleConfig.modelId;
const providerName = roleConfig.provider;
let effectiveMaxTokens = roleMaxTokens; // Start with the role's default
try {
// Find the model definition in MODEL_MAP
const providerModels = MODEL_MAP[providerName];
if (providerModels && Array.isArray(providerModels)) {
const modelDefinition = providerModels.find((m) => m.id === modelId);
// Check if a model-specific max_tokens is defined and valid
if (
modelDefinition &&
typeof modelDefinition.max_tokens === 'number' &&
modelDefinition.max_tokens > 0
) {
const modelSpecificMaxTokens = modelDefinition.max_tokens;
// Use the minimum of the role default and the model specific limit
effectiveMaxTokens = Math.min(roleMaxTokens, modelSpecificMaxTokens);
log(
'debug',
`Applying model-specific max_tokens (${modelSpecificMaxTokens}) for ${modelId}. Effective limit: ${effectiveMaxTokens}`
);
} else {
log(
'debug',
`No valid model-specific max_tokens override found for ${modelId}. Using role default: ${roleMaxTokens}`
);
}
} else {
log(
'debug',
`No model definitions found for provider ${providerName} in MODEL_MAP. Using role default maxTokens: ${roleMaxTokens}`
);
}
} catch (lookupError) {
log(
'warn',
`Error looking up model-specific max_tokens for ${modelId}: ${lookupError.message}. Using role default: ${roleMaxTokens}`
);
// Fallback to role default on error
effectiveMaxTokens = roleMaxTokens;
}
return {
maxTokens: roleConfig.maxTokens,
temperature: roleConfig.temperature
maxTokens: effectiveMaxTokens,
temperature: roleTemperature
};
}
@@ -385,16 +431,19 @@ function getParametersForRole(role, explicitRoot = null) {
*/
function isApiKeySet(providerName, session = null) {
// Define the expected environment variable name for each provider
if (providerName?.toLowerCase() === 'ollama') {
return true; // Indicate key status is effectively "OK"
}
const keyMap = {
openai: 'OPENAI_API_KEY',
anthropic: 'ANTHROPIC_API_KEY',
google: 'GOOGLE_API_KEY',
perplexity: 'PERPLEXITY_API_KEY',
mistral: 'MISTRAL_API_KEY',
azure: 'AZURE_OPENAI_API_KEY', // Azure needs endpoint too, but key presence is a start
azure: 'AZURE_OPENAI_API_KEY',
openrouter: 'OPENROUTER_API_KEY',
xai: 'XAI_API_KEY',
ollama: 'OLLAMA_API_KEY'
xai: 'XAI_API_KEY'
// Add other providers as needed
};
@@ -405,8 +454,15 @@ function isApiKeySet(providerName, session = null) {
}
const envVarName = keyMap[providerKey];
// Use resolveEnvVariable to check both process.env and session.env
return !!resolveEnvVariable(envVarName, session);
const apiKeyValue = resolveEnvVariable(envVarName, session);
// Check if the key exists, is not empty, and is not a placeholder
return (
apiKeyValue &&
apiKeyValue.trim() !== '' &&
!/YOUR_.*_API_KEY_HERE/.test(apiKeyValue) && // General placeholder check
!apiKeyValue.includes('KEY_HERE')
); // Another common placeholder pattern
}
/**
@@ -482,7 +538,7 @@ function getMcpApiKeyStatus(providerName, projectRoot = null) {
return false; // Unknown provider
}
return !!apiKeyToCheck && apiKeyToCheck !== placeholderValue;
return !!apiKeyToCheck && !/KEY_HERE$/.test(apiKeyToCheck);
} catch (error) {
console.error(
chalk.red(`Error reading or parsing .cursor/mcp.json: ${error.message}`)
@@ -589,6 +645,14 @@ function isConfigFilePresent(explicitRoot = null) {
return fs.existsSync(configPath);
}
/**
* Gets a list of all provider names defined in the MODEL_MAP.
* @returns {string[]} An array of provider names.
*/
function getAllProviders() {
return Object.keys(MODEL_MAP || {});
}
export {
// Core config access
getConfig,
@@ -628,5 +692,8 @@ export {
// API Key Checkers (still relevant)
isApiKeySet,
getMcpApiKeyStatus
getMcpApiKeyStatus,
// ADD: Function to get all provider names
getAllProviders
};

View File

@@ -4,25 +4,29 @@
"id": "claude-3-7-sonnet-20250219",
"swe_score": 0.623,
"cost_per_1m_tokens": { "input": 3.0, "output": 15.0 },
"allowed_roles": ["main", "fallback"]
"allowed_roles": ["main", "fallback"],
"max_tokens": 120000
},
{
"id": "claude-3-5-sonnet-20241022",
"swe_score": 0.49,
"cost_per_1m_tokens": { "input": 3.0, "output": 15.0 },
"allowed_roles": ["main", "fallback"]
"allowed_roles": ["main", "fallback"],
"max_tokens": 64000
},
{
"id": "claude-3-5-haiku-20241022",
"swe_score": 0.406,
"cost_per_1m_tokens": { "input": 0.8, "output": 4.0 },
"allowed_roles": ["main", "fallback"]
"allowed_roles": ["main", "fallback"],
"max_tokens": 64000
},
{
"id": "claude-3-opus-20240229",
"swe_score": 0,
"cost_per_1m_tokens": { "input": 15, "output": 75 },
"allowed_roles": ["main", "fallback"]
"allowed_roles": ["main", "fallback"],
"max_tokens": 64000
}
],
"openai": [
@@ -48,7 +52,8 @@
"id": "o3-mini",
"swe_score": 0.493,
"cost_per_1m_tokens": { "input": 1.1, "output": 4.4 },
"allowed_roles": ["main", "fallback"]
"allowed_roles": ["main", "fallback"],
"max_tokens": 100000
},
{
"id": "o4-mini",
@@ -68,12 +73,6 @@
"cost_per_1m_tokens": { "input": 150.0, "output": 600.0 },
"allowed_roles": ["main", "fallback"]
},
{
"id": "gpt-4-1",
"swe_score": 0.55,
"cost_per_1m_tokens": { "input": 2.0, "output": 8.0 },
"allowed_roles": ["main", "fallback"]
},
{
"id": "gpt-4-5-preview",
"swe_score": 0.38,
@@ -148,31 +147,36 @@
"id": "sonar-pro",
"swe_score": 0,
"cost_per_1m_tokens": { "input": 3, "output": 15 },
"allowed_roles": ["research"]
"allowed_roles": ["research"],
"max_tokens": 8700
},
{
"id": "sonar",
"swe_score": 0,
"cost_per_1m_tokens": { "input": 1, "output": 1 },
"allowed_roles": ["research"]
"allowed_roles": ["research"],
"max_tokens": 8700
},
{
"id": "deep-research",
"swe_score": 0.211,
"cost_per_1m_tokens": { "input": 2, "output": 8 },
"allowed_roles": ["research"]
"allowed_roles": ["research"],
"max_tokens": 8700
},
{
"id": "sonar-reasoning-pro",
"swe_score": 0.211,
"cost_per_1m_tokens": { "input": 2, "output": 8 },
"allowed_roles": ["main", "fallback"]
"allowed_roles": ["main", "fallback"],
"max_tokens": 8700
},
{
"id": "sonar-reasoning",
"swe_score": 0.211,
"cost_per_1m_tokens": { "input": 1, "output": 5 },
"allowed_roles": ["main", "fallback"]
"allowed_roles": ["main", "fallback"],
"max_tokens": 8700
}
],
"ollama": [

View File

@@ -17,7 +17,8 @@ import {
getMcpApiKeyStatus,
getConfig,
writeConfig,
isConfigFilePresent
isConfigFilePresent,
getAllProviders
} from '../config-manager.js';
/**
@@ -382,4 +383,61 @@ async function setModel(role, modelId, options = {}) {
}
}
export { getModelConfiguration, getAvailableModelsList, setModel };
/**
* Get API key status for all known providers.
* @param {Object} [options] - Options for the operation
* @param {Object} [options.session] - Session object containing environment variables (for MCP)
* @param {Function} [options.mcpLog] - MCP logger object (for MCP)
* @param {string} [options.projectRoot] - Project root directory
* @returns {Object} RESTful response with API key status report
*/
async function getApiKeyStatusReport(options = {}) {
const { mcpLog, projectRoot, session } = options;
const report = (level, ...args) => {
if (mcpLog && typeof mcpLog[level] === 'function') {
mcpLog[level](...args);
}
};
try {
const providers = getAllProviders();
const providersToCheck = providers.filter(
(p) => p.toLowerCase() !== 'ollama'
); // Ollama is not a provider, it's a service, doesn't need an api key usually
const statusReport = providersToCheck.map((provider) => {
// Use provided projectRoot for MCP status check
const cliOk = isApiKeySet(provider, session); // Pass session for CLI check too
const mcpOk = getMcpApiKeyStatus(provider, projectRoot);
return {
provider,
cli: cliOk,
mcp: mcpOk
};
});
report('info', 'Successfully generated API key status report.');
return {
success: true,
data: {
report: statusReport,
message: 'API key status report generated.'
}
};
} catch (error) {
report('error', `Error generating API key status report: ${error.message}`);
return {
success: false,
error: {
code: 'API_KEY_STATUS_ERROR',
message: error.message
}
};
}
}
export {
getModelConfiguration,
getAvailableModelsList,
setModel,
getApiKeyStatusReport
};

View File

@@ -1814,6 +1814,210 @@ async function confirmTaskOverwrite(tasksPath) {
return answer.toLowerCase() === 'y' || answer.toLowerCase() === 'yes';
}
/**
* Displays the API key status for different providers.
* @param {Array<{provider: string, cli: boolean, mcp: boolean}>} statusReport - The report generated by getApiKeyStatusReport.
*/
function displayApiKeyStatus(statusReport) {
if (!statusReport || statusReport.length === 0) {
console.log(chalk.yellow('No API key status information available.'));
return;
}
const table = new Table({
head: [
chalk.cyan('Provider'),
chalk.cyan('CLI Key (.env)'),
chalk.cyan('MCP Key (mcp.json)')
],
colWidths: [15, 20, 25],
chars: { mid: '', 'left-mid': '', 'mid-mid': '', 'right-mid': '' }
});
statusReport.forEach(({ provider, cli, mcp }) => {
const cliStatus = cli ? chalk.green('✅ Found') : chalk.red('❌ Missing');
const mcpStatus = mcp ? chalk.green('✅ Found') : chalk.red('❌ Missing');
// Capitalize provider name for display
const providerName = provider.charAt(0).toUpperCase() + provider.slice(1);
table.push([providerName, cliStatus, mcpStatus]);
});
console.log(chalk.bold('\n🔑 API Key Status:'));
console.log(table.toString());
console.log(
chalk.gray(
' Note: Some providers (e.g., Azure, Ollama) may require additional endpoint configuration in .taskmasterconfig.'
)
);
}
// --- Formatting Helpers (Potentially move some to utils.js if reusable) ---
const formatSweScoreWithTertileStars = (score, allModels) => {
// ... (Implementation from previous version or refine) ...
if (score === null || score === undefined || score <= 0) return 'N/A';
const formattedPercentage = `${(score * 100).toFixed(1)}%`;
const validScores = allModels
.map((m) => m.sweScore)
.filter((s) => s !== null && s !== undefined && s > 0);
const sortedScores = [...validScores].sort((a, b) => b - a);
const n = sortedScores.length;
let stars = chalk.gray('☆☆☆');
if (n > 0) {
const topThirdIndex = Math.max(0, Math.floor(n / 3) - 1);
const midThirdIndex = Math.max(0, Math.floor((2 * n) / 3) - 1);
if (score >= sortedScores[topThirdIndex]) stars = chalk.yellow('★★★');
else if (score >= sortedScores[midThirdIndex])
stars = chalk.yellow('★★') + chalk.gray('☆');
else stars = chalk.yellow('★') + chalk.gray('☆☆');
}
return `${formattedPercentage} ${stars}`;
};
const formatCost = (costObj) => {
// ... (Implementation from previous version or refine) ...
if (!costObj) return 'N/A';
if (costObj.input === 0 && costObj.output === 0) {
return chalk.green('Free');
}
const formatSingleCost = (costValue) => {
if (costValue === null || costValue === undefined) return 'N/A';
const isInteger = Number.isInteger(costValue);
return `$${costValue.toFixed(isInteger ? 0 : 2)}`;
};
return `${formatSingleCost(costObj.input)} in, ${formatSingleCost(costObj.output)} out`;
};
// --- Display Functions ---
/**
* Displays the currently configured active models.
* @param {ConfigData} configData - The active configuration data.
* @param {AvailableModel[]} allAvailableModels - Needed for SWE score tertiles.
*/
function displayModelConfiguration(configData, allAvailableModels = []) {
console.log(chalk.cyan.bold('\nActive Model Configuration:'));
const active = configData.activeModels;
const activeTable = new Table({
head: [
'Role',
'Provider',
'Model ID',
'SWE Score',
'Cost ($/1M tkns)'
// 'API Key Status' // Removed, handled by separate displayApiKeyStatus
].map((h) => chalk.cyan.bold(h)),
colWidths: [10, 14, 30, 18, 20 /*, 28 */], // Adjusted widths
style: { head: ['cyan', 'bold'] }
});
activeTable.push([
chalk.white('Main'),
active.main.provider,
active.main.modelId,
formatSweScoreWithTertileStars(active.main.sweScore, allAvailableModels),
formatCost(active.main.cost)
// getCombinedStatus(active.main.keyStatus) // Removed
]);
activeTable.push([
chalk.white('Research'),
active.research.provider,
active.research.modelId,
formatSweScoreWithTertileStars(
active.research.sweScore,
allAvailableModels
),
formatCost(active.research.cost)
// getCombinedStatus(active.research.keyStatus) // Removed
]);
if (active.fallback && active.fallback.provider && active.fallback.modelId) {
activeTable.push([
chalk.white('Fallback'),
active.fallback.provider,
active.fallback.modelId,
formatSweScoreWithTertileStars(
active.fallback.sweScore,
allAvailableModels
),
formatCost(active.fallback.cost)
// getCombinedStatus(active.fallback.keyStatus) // Removed
]);
} else {
activeTable.push([
chalk.white('Fallback'),
chalk.gray('-'),
chalk.gray('(Not Set)'),
chalk.gray('-'),
chalk.gray('-')
// chalk.gray('-') // Removed
]);
}
console.log(activeTable.toString());
}
/**
* Displays the list of available models not currently configured.
* @param {AvailableModel[]} availableModels - List of available models.
*/
function displayAvailableModels(availableModels) {
if (!availableModels || availableModels.length === 0) {
console.log(
chalk.gray('\n(No other models available or all are configured)')
);
return;
}
console.log(chalk.cyan.bold('\nOther Available Models:'));
const availableTable = new Table({
head: ['Provider', 'Model ID', 'SWE Score', 'Cost ($/1M tkns)'].map((h) =>
chalk.cyan.bold(h)
),
colWidths: [15, 40, 18, 25],
style: { head: ['cyan', 'bold'] }
});
availableModels.forEach((model) => {
availableTable.push([
model.provider,
model.modelId,
formatSweScoreWithTertileStars(model.sweScore, availableModels), // Pass itself for comparison
formatCost(model.cost)
]);
});
console.log(availableTable.toString());
// --- Suggested Actions Section (moved here from models command) ---
console.log(
boxen(
chalk.white.bold('Next Steps:') +
'\n' +
chalk.cyan(
`1. Set main model: ${chalk.yellow('task-master models --set-main <model_id>')}`
) +
'\n' +
chalk.cyan(
`2. Set research model: ${chalk.yellow('task-master models --set-research <model_id>')}`
) +
'\n' +
chalk.cyan(
`3. Set fallback model: ${chalk.yellow('task-master models --set-fallback <model_id>')}`
) +
'\n' +
chalk.cyan(
`4. Run interactive setup: ${chalk.yellow('task-master models --setup')}`
),
{
padding: 1,
borderColor: 'yellow',
borderStyle: 'round',
margin: { top: 1 }
}
)
);
}
// Export UI functions
export {
displayBanner,
@@ -1828,5 +2032,8 @@ export {
displayTaskById,
displayComplexityReport,
generateComplexityAnalysisPrompt,
confirmTaskOverwrite
confirmTaskOverwrite,
displayApiKeyStatus,
displayModelConfiguration,
displayAvailableModels
};