fix(update): pass projectRoot through update command flow

Modified ai-services-unified.js, update.js tool, and update-tasks.js direct function to correctly pass projectRoot. This enables the .env file API key fallback mechanism for the update command when running via MCP, ensuring consistent key resolution with the CLI context.
This commit is contained in:
Eyal Toledano
2025-05-01 13:45:11 -04:00
parent 40df57f969
commit 2a07d366be
4 changed files with 48 additions and 25 deletions

View File

@@ -20,7 +20,7 @@ import { createLogWrapper } from '../../tools/utils.js';
*/ */
export async function updateTasksDirect(args, log, context = {}) { export async function updateTasksDirect(args, log, context = {}) {
const { session } = context; // Extract session const { session } = context; // Extract session
const { tasksJsonPath, from, prompt, research } = args; const { tasksJsonPath, from, prompt, research, projectRoot } = args;
// Create the standard logger wrapper // Create the standard logger wrapper
const logWrapper = { const logWrapper = {
@@ -85,21 +85,23 @@ export async function updateTasksDirect(args, log, context = {}) {
const useResearch = research === true; const useResearch = research === true;
// --- End Input Validation --- // --- End Input Validation ---
log.info(`Updating tasks from ID ${fromId}. Research: ${useResearch}`); log.info(
`Updating tasks from ID ${fromId}. Research: ${useResearch}. Project Root: ${projectRoot}`
);
enableSilentMode(); // Enable silent mode enableSilentMode(); // Enable silent mode
try { try {
// Create logger wrapper using the utility // Create logger wrapper using the utility
const mcpLog = createLogWrapper(log); const mcpLog = createLogWrapper(log);
// Execute core updateTasks function, passing session context // Execute core updateTasks function, passing session context AND projectRoot
await updateTasks( await updateTasks(
tasksJsonPath, tasksJsonPath,
fromId, fromId,
prompt, prompt,
useResearch, useResearch,
// Pass context with logger wrapper and session // Pass context with logger wrapper, session, AND projectRoot
{ mcpLog, session }, { mcpLog, session, projectRoot },
'json' // Explicitly request JSON format for MCP 'json' // Explicitly request JSON format for MCP
); );

View File

@@ -70,7 +70,8 @@ export function registerUpdateTool(server) {
tasksJsonPath: tasksJsonPath, tasksJsonPath: tasksJsonPath,
from: args.from, from: args.from,
prompt: args.prompt, prompt: args.prompt,
research: args.research research: args.research,
projectRoot: rootFolder
}, },
log, log,
{ session } { session }

View File

@@ -16,7 +16,7 @@ import {
getFallbackModelId, getFallbackModelId,
getParametersForRole getParametersForRole
} from './config-manager.js'; } from './config-manager.js';
import { log, resolveEnvVariable } from './utils.js'; import { log, resolveEnvVariable, findProjectRoot } from './utils.js';
import * as anthropic from '../../src/ai-providers/anthropic.js'; import * as anthropic from '../../src/ai-providers/anthropic.js';
import * as perplexity from '../../src/ai-providers/perplexity.js'; import * as perplexity from '../../src/ai-providers/perplexity.js';
@@ -136,10 +136,11 @@ function _extractErrorMessage(error) {
* Internal helper to resolve the API key for a given provider. * Internal helper to resolve the API key for a given provider.
* @param {string} providerName - The name of the provider (lowercase). * @param {string} providerName - The name of the provider (lowercase).
* @param {object|null} session - Optional MCP session object. * @param {object|null} session - Optional MCP session object.
* @param {string|null} projectRoot - Optional project root path for .env fallback.
* @returns {string|null} The API key or null if not found/needed. * @returns {string|null} The API key or null if not found/needed.
* @throws {Error} If a required API key is missing. * @throws {Error} If a required API key is missing.
*/ */
function _resolveApiKey(providerName, session) { function _resolveApiKey(providerName, session, projectRoot = null) {
const keyMap = { const keyMap = {
openai: 'OPENAI_API_KEY', openai: 'OPENAI_API_KEY',
anthropic: 'ANTHROPIC_API_KEY', anthropic: 'ANTHROPIC_API_KEY',
@@ -163,10 +164,10 @@ function _resolveApiKey(providerName, session) {
); );
} }
const apiKey = resolveEnvVariable(envVarName, session); const apiKey = resolveEnvVariable(envVarName, session, projectRoot);
if (!apiKey) { if (!apiKey) {
throw new Error( throw new Error(
`Required API key ${envVarName} for provider '${providerName}' is not set in environment or session.` `Required API key ${envVarName} for provider '${providerName}' is not set in environment, session, or .env file.`
); );
} }
return apiKey; return apiKey;
@@ -241,27 +242,35 @@ async function _attemptProviderCallWithRetries(
* Base logic for unified service functions. * Base logic for unified service functions.
* @param {string} serviceType - Type of service ('generateText', 'streamText', 'generateObject'). * @param {string} serviceType - Type of service ('generateText', 'streamText', 'generateObject').
* @param {object} params - Original parameters passed to the service function. * @param {object} params - Original parameters passed to the service function.
* @param {string} [params.projectRoot] - Optional project root path.
* @returns {Promise<any>} Result from the underlying provider call. * @returns {Promise<any>} Result from the underlying provider call.
*/ */
async function _unifiedServiceRunner(serviceType, params) { async function _unifiedServiceRunner(serviceType, params) {
const { const {
role: initialRole, role: initialRole,
session, session,
projectRoot,
systemPrompt, systemPrompt,
prompt, prompt,
schema, schema,
objectName, objectName,
...restApiParams ...restApiParams
} = params; } = params;
log('info', `${serviceType}Service called`, { role: initialRole }); log('info', `${serviceType}Service called`, {
role: initialRole,
projectRoot
});
// Determine the effective project root (passed in or detected)
const effectiveProjectRoot = projectRoot || findProjectRoot();
let sequence; let sequence;
if (initialRole === 'main') { if (initialRole === 'main') {
sequence = ['main', 'fallback', 'research']; sequence = ['main', 'fallback', 'research'];
} else if (initialRole === 'fallback') {
sequence = ['fallback', 'research'];
} else if (initialRole === 'research') { } else if (initialRole === 'research') {
sequence = ['research', 'fallback']; sequence = ['research', 'fallback', 'main'];
} else if (initialRole === 'fallback') {
sequence = ['fallback', 'main', 'research'];
} else { } else {
log( log(
'warn', 'warn',
@@ -281,16 +290,16 @@ async function _unifiedServiceRunner(serviceType, params) {
log('info', `New AI service call with role: ${currentRole}`); log('info', `New AI service call with role: ${currentRole}`);
// 1. Get Config: Provider, Model, Parameters for the current role // 1. Get Config: Provider, Model, Parameters for the current role
// Call individual getters based on the current role // Pass effectiveProjectRoot to config getters
if (currentRole === 'main') { if (currentRole === 'main') {
providerName = getMainProvider(); providerName = getMainProvider(effectiveProjectRoot);
modelId = getMainModelId(); modelId = getMainModelId(effectiveProjectRoot);
} else if (currentRole === 'research') { } else if (currentRole === 'research') {
providerName = getResearchProvider(); providerName = getResearchProvider(effectiveProjectRoot);
modelId = getResearchModelId(); modelId = getResearchModelId(effectiveProjectRoot);
} else if (currentRole === 'fallback') { } else if (currentRole === 'fallback') {
providerName = getFallbackProvider(); providerName = getFallbackProvider(effectiveProjectRoot);
modelId = getFallbackModelId(); modelId = getFallbackModelId(effectiveProjectRoot);
} else { } else {
log( log(
'error', 'error',
@@ -314,7 +323,8 @@ async function _unifiedServiceRunner(serviceType, params) {
continue; continue;
} }
roleParams = getParametersForRole(currentRole); // Pass effectiveProjectRoot to getParametersForRole
roleParams = getParametersForRole(currentRole, effectiveProjectRoot);
// 2. Get Provider Function Set // 2. Get Provider Function Set
providerFnSet = PROVIDER_FUNCTIONS[providerName?.toLowerCase()]; providerFnSet = PROVIDER_FUNCTIONS[providerName?.toLowerCase()];
@@ -345,7 +355,12 @@ async function _unifiedServiceRunner(serviceType, params) {
} }
// 3. Resolve API Key (will throw if required and missing) // 3. Resolve API Key (will throw if required and missing)
apiKey = _resolveApiKey(providerName?.toLowerCase(), session); // Pass effectiveProjectRoot to _resolveApiKey
apiKey = _resolveApiKey(
providerName?.toLowerCase(),
session,
effectiveProjectRoot
);
// 4. Construct Messages Array // 4. Construct Messages Array
const messages = []; const messages = [];
@@ -443,6 +458,7 @@ async function _unifiedServiceRunner(serviceType, params) {
* @param {object} params - Parameters for the service call. * @param {object} params - Parameters for the service call.
* @param {string} params.role - The initial client role ('main', 'research', 'fallback'). * @param {string} params.role - The initial client role ('main', 'research', 'fallback').
* @param {object} [params.session=null] - Optional MCP session object. * @param {object} [params.session=null] - Optional MCP session object.
* @param {string} [params.projectRoot=null] - Optional project root path for .env fallback.
* @param {string} params.prompt - The prompt for the AI. * @param {string} params.prompt - The prompt for the AI.
* @param {string} [params.systemPrompt] - Optional system prompt. * @param {string} [params.systemPrompt] - Optional system prompt.
* // Other specific generateText params can be included here. * // Other specific generateText params can be included here.
@@ -459,6 +475,7 @@ async function generateTextService(params) {
* @param {object} params - Parameters for the service call. * @param {object} params - Parameters for the service call.
* @param {string} params.role - The initial client role ('main', 'research', 'fallback'). * @param {string} params.role - The initial client role ('main', 'research', 'fallback').
* @param {object} [params.session=null] - Optional MCP session object. * @param {object} [params.session=null] - Optional MCP session object.
* @param {string} [params.projectRoot=null] - Optional project root path for .env fallback.
* @param {string} params.prompt - The prompt for the AI. * @param {string} params.prompt - The prompt for the AI.
* @param {string} [params.systemPrompt] - Optional system prompt. * @param {string} [params.systemPrompt] - Optional system prompt.
* // Other specific streamText params can be included here. * // Other specific streamText params can be included here.
@@ -475,6 +492,7 @@ async function streamTextService(params) {
* @param {object} params - Parameters for the service call. * @param {object} params - Parameters for the service call.
* @param {string} params.role - The initial client role ('main', 'research', 'fallback'). * @param {string} params.role - The initial client role ('main', 'research', 'fallback').
* @param {object} [params.session=null] - Optional MCP session object. * @param {object} [params.session=null] - Optional MCP session object.
* @param {string} [params.projectRoot=null] - Optional project root path for .env fallback.
* @param {import('zod').ZodSchema} params.schema - The Zod schema for the expected object. * @param {import('zod').ZodSchema} params.schema - The Zod schema for the expected object.
* @param {string} params.prompt - The prompt for the AI. * @param {string} params.prompt - The prompt for the AI.
* @param {string} [params.systemPrompt] - Optional system prompt. * @param {string} [params.systemPrompt] - Optional system prompt.

View File

@@ -21,6 +21,7 @@ import {
import { getDebugFlag } from '../config-manager.js'; import { getDebugFlag } from '../config-manager.js';
import generateTaskFiles from './generate-task-files.js'; import generateTaskFiles from './generate-task-files.js';
import { generateTextService } from '../ai-services-unified.js'; import { generateTextService } from '../ai-services-unified.js';
import { getModelConfiguration } from './models.js';
// Zod schema for validating the structure of tasks AFTER parsing // Zod schema for validating the structure of tasks AFTER parsing
const updatedTaskSchema = z const updatedTaskSchema = z
@@ -173,7 +174,7 @@ async function updateTasks(
context = {}, context = {},
outputFormat = 'text' // Default to text for CLI outputFormat = 'text' // Default to text for CLI
) { ) {
const { session, mcpLog } = context; const { session, mcpLog, projectRoot } = context;
// Use mcpLog if available, otherwise use the imported consoleLog function // Use mcpLog if available, otherwise use the imported consoleLog function
const logFn = mcpLog || consoleLog; const logFn = mcpLog || consoleLog;
// Flag to easily check which logger type we have // Flag to easily check which logger type we have
@@ -312,7 +313,8 @@ The changes described in the prompt should be applied to ALL tasks in the list.`
prompt: userPrompt, prompt: userPrompt,
systemPrompt: systemPrompt, systemPrompt: systemPrompt,
role, role,
session session,
projectRoot
}); });
if (isMCP) logFn.info('Successfully received text response'); if (isMCP) logFn.info('Successfully received text response');
else else