feat: CLI & MCP progress tracking for parse-prd command (#1048)

* initial cutover

* update log to debug

* update tracker to pass units

* update test to match new base tracker format

* add streamTextService mocks

* remove unused imports

* Ensure the CLI waits for async main() completion

* refactor to reduce code duplication

* update comment

* reuse function

* ensure targetTag is defined in streaming mode

* avoid throwing inside process.exit spy

* check for null

* remove reference to generate

* fix formatting

* fix textStream assignment

* ensure no division by 0

* fix jest chalk mocks

* refactor for maintainability

* Improve bar chart calculation logic for consistent visual representation

* use custom streaming error types; fix mocks

* Update streamText extraction in parse-prd.js to match actual service response

* remove check - doesn't belong here

* update mocks

* remove streaming test that wasn't really doing anything

* add comment

* make parsing logic more DRY

* fix formatting

* Fix textStream extraction to match actual service response

* fix mock

* Add a cleanup method to ensure proper resource disposal and prevent memory leaks

* debounce progress updates to reduce UI flicker during rapid updates

* Implement timeout protection for streaming operations (60-second timeout) with automatic fallback to non-streaming mode.

* clear timeout properly

* Add a maximum buffer size limit (1MB) to prevent unbounded memory growth with very large streaming responses. (The debounce, timeout, and buffer-cap guards are sketched together after this list.)

* fix formatting

* remove duplicate mock

* better docs

* fix formatting

* sanitize the dynamic property name

* Fix incorrect remaining progress calculation

* Use onError callback instead of console.warn

* Remove unused chalk import

* Add missing custom validator in fallback parsing configuration

* add custom validator parameter in fallback parsing

* chore: fix package-lock.json

* chore: large code refactor

* chore: increase timeout from 1 minute to 3 minutes

* fix: refactor and fix streaming

* Merge remote-tracking branch 'origin/next' into joedanz/parse-prd-progress

* fix: cleanup and fix unit tests

* chore: fix unit tests

* chore: fix format

* chore: run format

* chore: fix weird CI unit test error

* chore: fix format
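
A minimal sketch of how the debounce, timeout, and buffer-cap guards above can compose. The function name, reader shape, and debounce interval are illustrative; only the 60-second timeout and 1MB cap come from the commit messages:

const DEBOUNCE_MS = 100; // assumed interval; the commits do not state one
const STREAM_TIMEOUT_MS = 60_000; // 60-second timeout per the commit message
const MAX_BUFFER_BYTES = 1024 * 1024; // 1MB cap per the commit message

async function consumeStreamWithGuards(textStream, onProgress) {
	let buffer = '';
	let lastUpdate = 0;

	const reader = (async () => {
		for await (const chunk of textStream) {
			buffer += chunk;
			// Buffer cap: fail fast before memory grows unbounded.
			if (Buffer.byteLength(buffer, 'utf8') > MAX_BUFFER_BYTES) {
				throw new Error('Streaming buffer exceeded 1MB limit');
			}
			// Debounce: only emit a progress update if enough time has passed.
			const now = Date.now();
			if (now - lastUpdate >= DEBOUNCE_MS) {
				lastUpdate = now;
				onProgress(buffer.length);
			}
		}
		return buffer;
	})();

	// Timeout: race the reader against a timer, clearing it either way.
	let timer;
	const timeout = new Promise((_, reject) => {
		timer = setTimeout(
			() => reject(new Error('Streaming timed out after 60s')),
			STREAM_TIMEOUT_MS
		);
	});
	try {
		return await Promise.race([reader, timeout]);
	} finally {
		clearTimeout(timer);
	}
}

// On a timeout error the caller can fall back to the non-streaming path.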

---------

Co-authored-by: Ralph Khreish <35776126+Crunchyman-ralph@users.noreply.github.com>
Author: Joe Danziger
Date: 2025-08-12 16:37:07 -04:00
Committed by: GitHub
Parent: fc47714340
Commit: e3ed4d7c14
39 changed files with 6993 additions and 1137 deletions


@@ -91,93 +91,117 @@ function _getProvider(providerName) {
 // Helper function to get cost for a specific model
 function _getCostForModel(providerName, modelId) {
+	const DEFAULT_COST = { inputCost: 0, outputCost: 0, currency: 'USD' };
 	if (!MODEL_MAP || !MODEL_MAP[providerName]) {
 		log(
 			'warn',
 			`Provider "${providerName}" not found in MODEL_MAP. Cannot determine cost for model ${modelId}.`
 		);
-		return { inputCost: 0, outputCost: 0, currency: 'USD' }; // Default to zero cost
+		return DEFAULT_COST;
 	}
 	const modelData = MODEL_MAP[providerName].find((m) => m.id === modelId);
-	if (!modelData || !modelData.cost_per_1m_tokens) {
+	if (!modelData?.cost_per_1m_tokens) {
 		log(
 			'debug',
 			`Cost data not found for model "${modelId}" under provider "${providerName}". Assuming zero cost.`
 		);
-		return { inputCost: 0, outputCost: 0, currency: 'USD' }; // Default to zero cost
+		return DEFAULT_COST;
 	}
-	// Ensure currency is part of the returned object, defaulting if not present
-	const currency = modelData.cost_per_1m_tokens.currency || 'USD';
+	const costs = modelData.cost_per_1m_tokens;
 	return {
-		inputCost: modelData.cost_per_1m_tokens.input || 0,
-		outputCost: modelData.cost_per_1m_tokens.output || 0,
-		currency: currency
+		inputCost: costs.input || 0,
+		outputCost: costs.output || 0,
+		currency: costs.currency || 'USD'
 	};
 }
+/**
+ * Calculate cost from token counts and cost per million
+ * @param {number} inputTokens - Number of input tokens
+ * @param {number} outputTokens - Number of output tokens
+ * @param {number} inputCost - Cost per million input tokens
+ * @param {number} outputCost - Cost per million output tokens
+ * @returns {number} Total calculated cost
+ */
+function _calculateCost(inputTokens, outputTokens, inputCost, outputCost) {
+	const calculatedCost =
+		((inputTokens || 0) / 1_000_000) * inputCost +
+		((outputTokens || 0) / 1_000_000) * outputCost;
+	return parseFloat(calculatedCost.toFixed(6));
+}
 // Helper function to get tag information for responses
 function _getTagInfo(projectRoot) {
+	const DEFAULT_TAG_INFO = { currentTag: 'master', availableTags: ['master'] };
 	try {
 		if (!projectRoot) {
-			return { currentTag: 'master', availableTags: ['master'] };
+			return DEFAULT_TAG_INFO;
 		}
-		const currentTag = getCurrentTag(projectRoot);
+		const currentTag = getCurrentTag(projectRoot) || 'master';
+		const availableTags = _readAvailableTags(projectRoot);
-		// Read available tags from tasks.json
-		let availableTags = ['master']; // Default fallback
-		try {
-			const path = require('path');
-			const fs = require('fs');
-			const tasksPath = path.join(
-				projectRoot,
-				'.taskmaster',
-				'tasks',
-				'tasks.json'
-			);
-			if (fs.existsSync(tasksPath)) {
-				const tasksData = JSON.parse(fs.readFileSync(tasksPath, 'utf8'));
-				if (tasksData && typeof tasksData === 'object') {
-					// Check if it's tagged format (has tag-like keys with tasks arrays)
-					const potentialTags = Object.keys(tasksData).filter(
-						(key) =>
-							tasksData[key] &&
-							typeof tasksData[key] === 'object' &&
-							Array.isArray(tasksData[key].tasks)
-					);
-					if (potentialTags.length > 0) {
-						availableTags = potentialTags;
-					}
-				}
-			}
-		} catch (readError) {
-			// Silently fall back to default if we can't read tasks file
-			if (getDebugFlag()) {
-				log(
-					'debug',
-					`Could not read tasks file for available tags: ${readError.message}`
-				);
-			}
-		}
-		return {
-			currentTag: currentTag || 'master',
-			availableTags: availableTags
-		};
+		return { currentTag, availableTags };
 	} catch (error) {
 		if (getDebugFlag()) {
 			log('debug', `Error getting tag information: ${error.message}`);
 		}
-		return { currentTag: 'master', availableTags: ['master'] };
+		return DEFAULT_TAG_INFO;
 	}
 }
+// Extract method for reading available tags
+function _readAvailableTags(projectRoot) {
+	const DEFAULT_TAGS = ['master'];
+	try {
+		const path = require('path');
+		const fs = require('fs');
+		const tasksPath = path.join(
+			projectRoot,
+			'.taskmaster',
+			'tasks',
+			'tasks.json'
+		);
+		if (!fs.existsSync(tasksPath)) {
+			return DEFAULT_TAGS;
+		}
+		const tasksData = JSON.parse(fs.readFileSync(tasksPath, 'utf8'));
+		if (!tasksData || typeof tasksData !== 'object') {
+			return DEFAULT_TAGS;
+		}
+		// Check if it's tagged format (has tag-like keys with tasks arrays)
+		const potentialTags = Object.keys(tasksData).filter((key) =>
+			_isValidTaggedTask(tasksData[key])
+		);
+		return potentialTags.length > 0 ? potentialTags : DEFAULT_TAGS;
+	} catch (readError) {
+		if (getDebugFlag()) {
+			log(
+				'debug',
+				`Could not read tasks file for available tags: ${readError.message}`
+			);
+		}
+		return DEFAULT_TAGS;
+	}
+}
+// Helper to validate tagged task structure
+function _isValidTaggedTask(taskData) {
+	return (
+		taskData && typeof taskData === 'object' && Array.isArray(taskData.tasks)
+	);
+}
 // --- Configuration for Retries ---
 const MAX_RETRIES = 2;
 const INITIAL_RETRY_DELAY_MS = 1000;
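
The extracted _calculateCost helper above is plain per-million arithmetic; a worked example with made-up rates:

// Hypothetical rates: $3 per 1M input tokens, $15 per 1M output tokens.
// _calculateCost(1500, 400, 3, 15)
//   = (1500 / 1_000_000) * 3 + (400 / 1_000_000) * 15
//   = 0.0045 + 0.006
//   = 0.0105
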
@@ -244,6 +268,65 @@ function _extractErrorMessage(error) {
 	}
 }
+/**
+ * Get role configuration (provider and model) based on role type
+ * @param {string} role - The role ('main', 'research', 'fallback')
+ * @param {string} projectRoot - Project root path
+ * @returns {Object|null} Configuration object with provider and modelId
+ */
+function _getRoleConfiguration(role, projectRoot) {
+	const roleConfigs = {
+		main: {
+			provider: getMainProvider(projectRoot),
+			modelId: getMainModelId(projectRoot)
+		},
+		research: {
+			provider: getResearchProvider(projectRoot),
+			modelId: getResearchModelId(projectRoot)
+		},
+		fallback: {
+			provider: getFallbackProvider(projectRoot),
+			modelId: getFallbackModelId(projectRoot)
+		}
+	};
+	return roleConfigs[role] || null;
+}
+/**
+ * Get Vertex AI specific configuration
+ * @param {string} projectRoot - Project root path
+ * @param {Object} session - Session object
+ * @returns {Object} Vertex AI configuration parameters
+ */
+function _getVertexConfiguration(projectRoot, session) {
+	const projectId =
+		getVertexProjectId(projectRoot) ||
+		resolveEnvVariable('VERTEX_PROJECT_ID', session, projectRoot);
+	const location =
+		getVertexLocation(projectRoot) ||
+		resolveEnvVariable('VERTEX_LOCATION', session, projectRoot) ||
+		'us-central1';
+	const credentialsPath = resolveEnvVariable(
+		'GOOGLE_APPLICATION_CREDENTIALS',
+		session,
+		projectRoot
+	);
+	log(
+		'debug',
+		`Using Vertex AI configuration: Project ID=${projectId}, Location=${location}`
+	);
+	return {
+		projectId,
+		location,
+		...(credentialsPath && { credentials: { credentialsFromEnv: true } })
+	};
+}
 /**
  * Internal helper to resolve the API key for a given provider.
  * @param {string} providerName - The name of the provider (lowercase).
@@ -424,18 +507,13 @@ async function _unifiedServiceRunner(serviceType, params) {
 	let telemetryData = null;
 	try {
-		log('info', `New AI service call with role: ${currentRole}`);
+		log('debug', `New AI service call with role: ${currentRole}`);
-		if (currentRole === 'main') {
-			providerName = getMainProvider(effectiveProjectRoot);
-			modelId = getMainModelId(effectiveProjectRoot);
-		} else if (currentRole === 'research') {
-			providerName = getResearchProvider(effectiveProjectRoot);
-			modelId = getResearchModelId(effectiveProjectRoot);
-		} else if (currentRole === 'fallback') {
-			providerName = getFallbackProvider(effectiveProjectRoot);
-			modelId = getFallbackModelId(effectiveProjectRoot);
-		} else {
+		const roleConfig = _getRoleConfiguration(
+			currentRole,
+			effectiveProjectRoot
+		);
+		if (!roleConfig) {
 			log(
 				'error',
 				`Unknown role encountered in _unifiedServiceRunner: ${currentRole}`
@@ -444,6 +522,8 @@ async function _unifiedServiceRunner(serviceType, params) {
 				lastError || new Error(`Unknown AI role specified: ${currentRole}`);
 			continue;
 		}
+		providerName = roleConfig.provider;
+		modelId = roleConfig.modelId;
 		if (!providerName || !modelId) {
 			log(
@@ -517,41 +597,9 @@ async function _unifiedServiceRunner(serviceType, params) {
 			// Handle Vertex AI specific configuration
 			if (providerName?.toLowerCase() === 'vertex') {
-				// Get Vertex project ID and location
-				const projectId =
-					getVertexProjectId(effectiveProjectRoot) ||
-					resolveEnvVariable(
-						'VERTEX_PROJECT_ID',
-						session,
-						effectiveProjectRoot
-					);
-				const location =
-					getVertexLocation(effectiveProjectRoot) ||
-					resolveEnvVariable(
-						'VERTEX_LOCATION',
-						session,
-						effectiveProjectRoot
-					) ||
-					'us-central1';
-				// Get credentials path if available
-				const credentialsPath = resolveEnvVariable(
-					'GOOGLE_APPLICATION_CREDENTIALS',
-					session,
-					effectiveProjectRoot
-				);
-				// Add Vertex-specific parameters
-				providerSpecificParams = {
-					projectId,
-					location,
-					...(credentialsPath && { credentials: { credentialsFromEnv: true } })
-				};
-				log(
-					'debug',
-					`Using Vertex AI configuration: Project ID=${projectId}, Location=${location}`
+				providerSpecificParams = _getVertexConfiguration(
+					effectiveProjectRoot,
+					session
 				);
 			}
@@ -594,7 +642,8 @@ async function _unifiedServiceRunner(serviceType, params) {
 				temperature: roleParams.temperature,
 				messages,
 				...(baseURL && { baseURL }),
-				...(serviceType === 'generateObject' && { schema, objectName }),
+				...((serviceType === 'generateObject' ||
+					serviceType === 'streamObject') && { schema, objectName }),
 				...providerSpecificParams,
 				...restApiParams
 			};
@@ -635,7 +684,10 @@ async function _unifiedServiceRunner(serviceType, params) {
 				finalMainResult = providerResponse.text;
 			} else if (serviceType === 'generateObject') {
 				finalMainResult = providerResponse.object;
-			} else if (serviceType === 'streamText') {
+			} else if (
+				serviceType === 'streamText' ||
+				serviceType === 'streamObject'
+			) {
 				finalMainResult = providerResponse;
 			} else {
 				log(
@@ -651,7 +703,9 @@ async function _unifiedServiceRunner(serviceType, params) {
 			return {
 				mainResult: finalMainResult,
 				telemetryData: telemetryData,
-				tagInfo: tagInfo
+				tagInfo: tagInfo,
+				providerName: providerName,
+				modelId: modelId
 			};
 		} catch (error) {
 			const cleanMessage = _extractErrorMessage(error);
@@ -732,6 +786,31 @@ async function streamTextService(params) {
 	return _unifiedServiceRunner('streamText', combinedParams);
 }
+/**
+ * Unified service function for streaming structured objects.
+ * Uses Vercel AI SDK's streamObject for proper JSON streaming.
+ *
+ * @param {object} params - Parameters for the service call.
+ * @param {string} params.role - The initial client role ('main', 'research', 'fallback').
+ * @param {object} [params.session=null] - Optional MCP session object.
+ * @param {string} [params.projectRoot=null] - Optional project root path for .env fallback.
+ * @param {import('zod').ZodSchema} params.schema - The Zod schema for the expected object.
+ * @param {string} params.prompt - The prompt for the AI.
+ * @param {string} [params.systemPrompt] - Optional system prompt.
+ * @param {string} params.commandName - Name of the command invoking the service.
+ * @param {string} [params.outputType='cli'] - 'cli' or 'mcp'.
+ * @returns {Promise<object>} Result object containing the stream and usage data.
+ */
+async function streamObjectService(params) {
+	const defaults = { outputType: 'cli' };
+	const combinedParams = { ...defaults, ...params };
+	// Stream object requires a schema
+	if (!combinedParams.schema) {
+		throw new Error('streamObjectService requires a schema parameter');
+	}
+	return _unifiedServiceRunner('streamObject', combinedParams);
+}
 /**
  * Unified service function for generating structured objects.
  * Handles client retrieval, retries, and fallback sequence.
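
A minimal usage sketch for the new streamObjectService; the schema, prompt, and import path are illustrative, and it assumes mainResult is the Vercel AI SDK streamObject result, which exposes an async-iterable partialObjectStream:

import { z } from 'zod';
import { streamObjectService } from './ai-services-unified.js'; // hypothetical path

const schema = z.object({
	tasks: z.array(z.object({ id: z.number(), title: z.string() }))
});

const { mainResult } = await streamObjectService({
	role: 'main',
	schema,
	objectName: 'tasks', // passed through for streamObject calls (see hunk above)
	prompt: 'Break this PRD into tasks...',
	commandName: 'parse-prd',
	outputType: 'cli'
});

for await (const partial of mainResult.partialObjectStream) {
	// Each iteration yields a progressively more complete object.
	console.log(partial.tasks?.length ?? 0, 'tasks so far');
}
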
@@ -792,9 +871,12 @@ async function logAiUsage({
 		modelId
 	);
-	const totalCost =
-		((inputTokens || 0) / 1_000_000) * inputCost +
-		((outputTokens || 0) / 1_000_000) * outputCost;
+	const totalCost = _calculateCost(
+		inputTokens,
+		outputTokens,
+		inputCost,
+		outputCost
+	);
 	const telemetryData = {
 		timestamp,
@@ -805,7 +887,7 @@ async function logAiUsage({
 		inputTokens: inputTokens || 0,
 		outputTokens: outputTokens || 0,
 		totalTokens,
-		totalCost: parseFloat(totalCost.toFixed(6)),
+		totalCost,
 		currency // Add currency to the telemetry data
 	};
@@ -828,6 +910,7 @@ async function logAiUsage({
 export {
 	generateTextService,
 	streamTextService,
+	streamObjectService,
 	generateObjectService,
 	logAiUsage
 };
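
Because the runner now returns providerName and modelId alongside the result (see the return-object hunk above), callers can report which model actually served a request. A small sketch; the prompt and log message are illustrative:

const { mainResult, telemetryData, providerName, modelId } =
	await generateTextService({
		role: 'main',
		prompt: 'Summarize the PRD...',
		commandName: 'demo'
	});
log('debug', `Request served by ${providerName} / ${modelId}`);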