mirror of
https://github.com/eyaltoledano/claude-task-master.git
synced 2026-01-30 06:12:05 +00:00
607 lines
16 KiB
JavaScript
607 lines
16 KiB
JavaScript
import * as ai from 'ai';
|
|
import { jsonrepair } from 'jsonrepair';
|
|
import { EnvHttpProxyAgent } from 'undici';
|
|
import { isProxyEnabled } from '../../scripts/modules/config-manager.js';
|
|
import { findProjectRoot, log } from '../../scripts/modules/utils.js';
|
|
import { getAITelemetryConfig, hashProjectRoot } from '../telemetry/sentry.js';
|
|
|
|
// Pull the AI SDK members used in this module off the namespace import.
// `jsonSchema` is read as a plain property and aliased to `jsonSchemaHelper`;
// buildSafeSchema checks that it is a function before calling it.
const {
	JSONParseError,
	NoObjectGeneratedError,
	generateObject,
	generateText,
	streamObject,
	streamText,
	zodSchema,
	jsonSchema: jsonSchemaHelper
} = ai;
|
|
|
|
// Numeric-bound keywords that sanitizeIntegerConstraints removes from any
// schema node whose `type` is (or includes) 'integer'.
const INTEGER_CONSTRAINT_KEYS = new Set([
	'minimum', 'maximum', 'exclusiveMinimum', 'exclusiveMaximum'
]);

// Keywords whose value is a single nested schema.
const SCHEMA_OBJECT_KEYS = [
	'additionalProperties', 'contains', 'if', 'then', 'else', 'not', 'propertyNames'
];

// Keywords whose value is a list of nested schemas.
const SCHEMA_ARRAY_KEYS = [
	'allOf',
	'anyOf',
	'oneOf',
	'prefixItems'
];

// Keywords whose value is a name -> schema record.
const SCHEMA_RECORD_KEYS = [
	'definitions', '$defs', 'dependentSchemas', 'patternProperties', 'properties'
];
|
|
|
|
/**
 * Reports whether a JSON Schema `type` value denotes an integer.
 * @param {string|string[]|undefined|null} type - A single type name or a list of type names
 * @returns {boolean} true when `type` is 'integer' or is an array containing 'integer'
 */
const isIntegerType = (type) =>
	// Missing/empty values are neither arrays nor the string 'integer',
	// so they fall through to the strict comparison and yield false.
	Array.isArray(type) ? type.includes('integer') : type === 'integer';
|
|
|
|
/**
 * Returns a copy of a JSON Schema in which every integer-typed node has its
 * numeric-bound keywords (INTEGER_CONSTRAINT_KEYS) removed. Nested schemas
 * reachable through the known single-schema, schema-list, schema-record and
 * `items` keywords are processed recursively; the input is not mutated.
 * @param {*} schema - Schema node, list of schema nodes, or any other value
 * @returns {*} Sanitized copy; non-object values are returned unchanged
 */
const sanitizeIntegerConstraints = (schema) => {
	// Primitives and null carry no nested schemas.
	if (schema === null || typeof schema !== 'object') {
		return schema;
	}

	if (Array.isArray(schema)) {
		return schema.map((entry) => sanitizeIntegerConstraints(entry));
	}

	// Shallow copy so the caller's schema object stays untouched.
	const copy = { ...schema };

	if (isIntegerType(copy.type)) {
		INTEGER_CONSTRAINT_KEYS.forEach((constraint) => {
			delete copy[constraint];
		});
	}

	// Keywords holding exactly one sub-schema (falsy values such as `false` are left alone).
	for (const keyword of SCHEMA_OBJECT_KEYS) {
		const child = copy[keyword];
		if (child) {
			copy[keyword] = sanitizeIntegerConstraints(child);
		}
	}

	// Keywords holding a list of sub-schemas.
	for (const keyword of SCHEMA_ARRAY_KEYS) {
		const children = copy[keyword];
		if (Array.isArray(children)) {
			copy[keyword] = children.map((child) => sanitizeIntegerConstraints(child));
		}
	}

	// Keywords holding a name -> sub-schema record.
	for (const keyword of SCHEMA_RECORD_KEYS) {
		const record = copy[keyword];
		if (record && typeof record === 'object') {
			copy[keyword] = Object.entries(record).reduce((rebuilt, [name, child]) => {
				rebuilt[name] = sanitizeIntegerConstraints(child);
				return rebuilt;
			}, {});
		}
	}

	// `items` may be one schema or a list; the recursive call handles both.
	if (copy.items) {
		copy.items = sanitizeIntegerConstraints(copy.items);
	}

	return copy;
};
|
|
|
|
/**
 * Wraps a schema with `zodSchema` and, when the wrapper exposes a `jsonSchema`
 * property, swaps it for a copy with integer numeric bounds removed.
 * @param {*} schema - Schema handed to `zodSchema`
 * @returns {*} The `zodSchema` result unchanged when it is not an object or has
 *   no `jsonSchema`; otherwise a schema carrying the sanitized JSON Schema and
 *   the wrapper's `validate`
 */
const buildSafeSchema = (schema) => {
	const wrapped = zodSchema(schema);

	// Nothing to sanitize: not an object, or no JSON Schema attached.
	const isObject = Boolean(wrapped) && typeof wrapped === 'object';
	if (!isObject || !wrapped.jsonSchema) {
		return wrapped;
	}

	const cleaned = sanitizeIntegerConstraints(wrapped.jsonSchema);

	// Prefer rebuilding through the SDK helper when it is available; otherwise
	// fall back to a shallow copy with the JSON Schema replaced.
	return typeof jsonSchemaHelper === 'function'
		? jsonSchemaHelper(cleaned, { validate: wrapped.validate })
		: { ...wrapped, jsonSchema: cleaned };
};
|
|
|
|
/**
 * Builds the Sentry telemetry configuration for one AI call.
 * The function ID has the form `provider.model.command.operation`.
 * @param {string} providerName - Name of the provider issuing the call
 * @param {object} params - Call parameters (modelId, commandName, outputType, tag, projectRoot, userId, briefId)
 * @param {string} operation - Operation suffix, e.g. 'generateText' or 'generateObject.<objectName>'
 * @returns {*} Whatever getAITelemetryConfig returns for the ID and metadata
 */
function buildTelemetryConfig(providerName, params, operation) {
	const commandName = params.commandName || 'unknown';
	const functionId = `${providerName}.${params.modelId}.${commandName}.${operation}`;

	// Metadata for enhanced filtering/grouping in Sentry
	const metadata = {
		command: commandName,
		outputType: params.outputType,
		tag: params.tag,
		projectHash: hashProjectRoot(params.projectRoot),
		userId: params.userId, // Hamster user ID if authenticated
		briefId: params.briefId // Hamster brief ID if connected
	};

	return getAITelemetryConfig(functionId, metadata);
}

/**
 * Builds the optional temperature setting for an AI call.
 * @param {boolean} supportsTemperature - Whether the provider accepts a temperature
 * @param {number|undefined} temperature - Requested temperature
 * @returns {object} `{ temperature }` when supported and provided, otherwise `{}`
 */
function buildTemperatureParam(supportsTemperature, temperature) {
	return supportsTemperature && temperature !== undefined ? { temperature } : {};
}

/**
 * Normalizes a usage object into the shape returned by the provider methods.
 * Prefers inputTokens/outputTokens, falls back to promptTokens/completionTokens,
 * and derives totalTokens from the two counts when it is absent.
 * @param {object|undefined|null} usage - Usage data from a result or an error
 * @returns {{inputTokens: number, outputTokens: number, totalTokens: number}}
 */
function normalizeUsage(usage) {
	const inputTokens = usage?.inputTokens ?? usage?.promptTokens ?? 0;
	const outputTokens = usage?.outputTokens ?? usage?.completionTokens ?? 0;
	const totalTokens = usage?.totalTokens ?? inputTokens + outputTokens;
	return { inputTokens, outputTokens, totalTokens };
}

/**
 * Base class for all AI providers
 */
export class BaseAIProvider {
	constructor() {
		if (this.constructor === BaseAIProvider) {
			throw new Error('BaseAIProvider cannot be instantiated directly');
		}

		// Each provider must set their name
		this.name = this.constructor.name;

		// Cache proxy agent to avoid creating multiple instances
		this._proxyAgent = null;

		/**
		 * Whether this provider needs explicit schema in JSON mode
		 * Can be overridden by subclasses
		 * @type {boolean}
		 */
		this.needsExplicitJsonSchema = false;

		/**
		 * Whether this provider supports temperature parameter
		 * Can be overridden by subclasses
		 * @type {boolean}
		 */
		this.supportsTemperature = true;
	}

	/**
	 * Validates authentication parameters - can be overridden by providers
	 * @param {object} params - Parameters to validate
	 * @throws {Error} When no API key is supplied
	 */
	validateAuth(params) {
		// Default: require API key (most providers need this)
		if (!params.apiKey) {
			throw new Error(`${this.name} API key is required`);
		}
	}

	/**
	 * Creates a custom fetch function with proxy support.
	 * Only enables proxy when TASKMASTER_ENABLE_PROXY environment variable is set to 'true'
	 * or enableProxy is set to true in config.json.
	 * Automatically reads http_proxy/https_proxy environment variables when enabled.
	 * @returns {Function} Custom fetch function with proxy support, or undefined if proxy is disabled
	 */
	createProxyFetch() {
		// Cache project root to avoid repeated lookups
		if (!this._projectRoot) {
			this._projectRoot = findProjectRoot();
		}

		if (!isProxyEnabled(null, this._projectRoot)) {
			// Return undefined to use default fetch without proxy
			return undefined;
		}

		// Proxy is enabled: create the agent once and reuse it for every request
		if (!this._proxyAgent) {
			this._proxyAgent = new EnvHttpProxyAgent();
		}

		return (url, options = {}) =>
			fetch(url, {
				...options,
				dispatcher: this._proxyAgent
			});
	}

	/**
	 * Validates common parameters across all methods
	 * @param {object} params - Parameters to validate
	 * @throws {Error} When authentication, the model ID or an optional parameter is invalid
	 */
	validateParams(params) {
		// Validate authentication (can be overridden by providers)
		this.validateAuth(params);

		// Validate required model ID
		if (!params.modelId) {
			throw new Error(`${this.name} Model ID is required`);
		}

		// Validate optional parameters
		this.validateOptionalParams(params);
	}

	/**
	 * Validates optional parameters like temperature and maxTokens
	 * @param {object} params - Parameters to validate
	 * @throws {Error} When temperature is outside [0, 1] or maxTokens is not a finite number above 0
	 */
	validateOptionalParams(params) {
		if (
			params.temperature !== undefined &&
			(params.temperature < 0 || params.temperature > 1)
		) {
			throw new Error('Temperature must be between 0 and 1');
		}

		if (params.maxTokens !== undefined) {
			const maxTokens = Number(params.maxTokens);
			if (!Number.isFinite(maxTokens) || maxTokens <= 0) {
				throw new Error('maxTokens must be a finite number greater than 0');
			}
		}
	}

	/**
	 * Validates message array structure
	 * @param {Array<object>} messages - Messages to validate; each needs a truthy role and content
	 * @throws {Error} When the array is missing/empty or an entry lacks role or content
	 */
	validateMessages(messages) {
		if (!messages || !Array.isArray(messages) || messages.length === 0) {
			throw new Error('Invalid or empty messages array provided');
		}

		for (const msg of messages) {
			// A null/undefined entry is reported as a format error rather than a TypeError
			if (!msg || !msg.role || !msg.content) {
				throw new Error(
					'Invalid message format. Each message must have role and content'
				);
			}
		}
	}

	/**
	 * Common error handler: logs the failure and rethrows it wrapped in a
	 * provider-specific Error whose `cause` is the original error.
	 * @param {string} operation - Human-readable name of the failed operation
	 * @param {*} error - The caught error
	 * @throws {Error} Always
	 */
	handleError(operation, error) {
		const errorMessage = error?.message || 'Unknown error occurred';
		log('error', `${this.name} ${operation} failed: ${errorMessage}`, {
			error
		});
		throw new Error(
			`${this.name} API error during ${operation}: ${errorMessage}`,
			{ cause: error }
		);
	}

	/**
	 * Creates and returns a client instance for the provider
	 * @abstract
	 */
	getClient(params) {
		throw new Error('getClient must be implemented by provider');
	}

	/**
	 * Returns if the API key is required
	 * @abstract
	 * @returns {boolean} if the API key is required, defaults to true
	 */
	isRequiredApiKey() {
		return true;
	}

	/**
	 * Returns the required API key environment variable name
	 * @abstract
	 * @returns {string|null} The environment variable name, or null if no API key is required
	 */
	getRequiredApiKeyName() {
		throw new Error('getRequiredApiKeyName must be implemented by provider');
	}

	/**
	 * Prepares token limit parameter based on model requirements
	 * @param {string} modelId - The model ID
	 * @param {number} maxTokens - The maximum tokens value
	 * @returns {object} `{ maxOutputTokens }` with the value floored to an integer,
	 *   or `{}` when maxTokens is undefined
	 */
	prepareTokenParam(modelId, maxTokens) {
		if (maxTokens === undefined) {
			return {};
		}

		// Ensure maxTokens is an integer
		const tokenValue = Math.floor(Number(maxTokens));

		return { maxOutputTokens: tokenValue };
	}

	/**
	 * Generates text using the provider's model
	 * @param {object} params - Call parameters (apiKey, modelId, messages, optional maxTokens/temperature and telemetry fields)
	 * @returns {Promise<{text: string, usage: {inputTokens: number, outputTokens: number, totalTokens: number}}>}
	 * @throws {Error} Wrapped by handleError on validation or API failure
	 */
	async generateText(params) {
		try {
			this.validateParams(params);
			this.validateMessages(params.messages);

			log(
				'debug',
				`Generating ${this.name} text with model: ${params.modelId}`
			);

			const client = await this.getClient(params);
			const telemetryConfig = buildTelemetryConfig(
				this.name,
				params,
				'generateText'
			);

			const result = await generateText({
				model: client(params.modelId),
				messages: params.messages,
				...this.prepareTokenParam(params.modelId, params.maxTokens),
				...buildTemperatureParam(this.supportsTemperature, params.temperature),
				...(telemetryConfig && { experimental_telemetry: telemetryConfig })
			});

			log(
				'debug',
				`${this.name} generateText completed successfully for model: ${params.modelId}`
			);

			return {
				text: result.text,
				usage: normalizeUsage(result.usage)
			};
		} catch (error) {
			this.handleError('text generation', error);
		}
	}

	/**
	 * Streams text using the provider's model
	 * @param {object} params - Call parameters; `experimental_transform` is forwarded when set
	 * @returns {Promise<*>} The value returned by the SDK's streamText
	 * @throws {Error} Wrapped by handleError on validation or API failure
	 */
	async streamText(params) {
		try {
			this.validateParams(params);
			this.validateMessages(params.messages);

			log('debug', `Streaming ${this.name} text with model: ${params.modelId}`);

			const client = await this.getClient(params);
			const telemetryConfig = buildTelemetryConfig(
				this.name,
				params,
				'streamText'
			);

			const stream = await streamText({
				model: client(params.modelId),
				messages: params.messages,
				...this.prepareTokenParam(params.modelId, params.maxTokens),
				...buildTemperatureParam(this.supportsTemperature, params.temperature),
				...(telemetryConfig && { experimental_telemetry: telemetryConfig }),
				...(params.experimental_transform && {
					experimental_transform: params.experimental_transform
				})
			});

			log(
				'debug',
				`${this.name} streamText initiated successfully for model: ${params.modelId}`
			);

			return stream;
		} catch (error) {
			this.handleError('text streaming', error);
		}
	}

	/**
	 * Streams a structured object using the provider's model
	 * @param {object} params - Call parameters; `schema` is required, `mode` defaults to 'auto'
	 * @returns {Promise<*>} The value returned by the SDK's streamObject
	 * @throws {Error} Wrapped by handleError on validation or API failure
	 */
	async streamObject(params) {
		try {
			this.validateParams(params);
			this.validateMessages(params.messages);

			if (!params.schema) {
				throw new Error('Schema is required for object streaming');
			}

			log(
				'debug',
				`Streaming ${this.name} object with model: ${params.modelId}`
			);

			const client = await this.getClient(params);
			const telemetryConfig = buildTelemetryConfig(
				this.name,
				params,
				'streamObject'
			);
			const schema = buildSafeSchema(params.schema);

			const result = await streamObject({
				model: client(params.modelId),
				messages: params.messages,
				schema,
				mode: params.mode || 'auto',
				// Same token-limit preparation as the text methods (coerced and floored)
				...this.prepareTokenParam(params.modelId, params.maxTokens),
				...buildTemperatureParam(this.supportsTemperature, params.temperature),
				...(telemetryConfig && { experimental_telemetry: telemetryConfig })
			});

			log(
				'debug',
				`${this.name} streamObject initiated successfully for model: ${params.modelId}`
			);

			// Return the stream result directly
			// The stream result contains partialObjectStream and other properties
			return result;
		} catch (error) {
			this.handleError('object streaming', error);
		}
	}

	/**
	 * Generates a structured object using the provider's model.
	 * When the SDK reports malformed JSON, the raw text is repaired with
	 * jsonrepair and parsed before giving up.
	 * @param {object} params - Call parameters; `schema` and `objectName` are required
	 * @returns {Promise<{object: *, usage: {inputTokens: number, outputTokens: number, totalTokens: number}}>}
	 * @throws {Error} Wrapped by handleError on validation or API failure, or when repair fails
	 */
	async generateObject(params) {
		try {
			this.validateParams(params);
			this.validateMessages(params.messages);

			if (!params.schema) {
				throw new Error('Schema is required for object generation');
			}
			if (!params.objectName) {
				throw new Error('Object name is required for object generation');
			}

			log(
				'debug',
				`Generating ${this.name} object ('${params.objectName}') with model: ${params.modelId}`
			);

			const client = await this.getClient(params);

			// Function ID format: provider.model.command.method.objectName
			const telemetryConfig = buildTelemetryConfig(
				this.name,
				params,
				`generateObject.${params.objectName}`
			);
			const schema = buildSafeSchema(params.schema);

			const result = await generateObject({
				model: client(params.modelId),
				messages: params.messages,
				schema,
				mode: this.needsExplicitJsonSchema ? 'json' : 'auto',
				schemaName: params.objectName,
				schemaDescription: `Generate a valid JSON object for ${params.objectName}`,
				// Send the limit under the same key as every other method in this class
				...this.prepareTokenParam(params.modelId, params.maxTokens),
				...buildTemperatureParam(this.supportsTemperature, params.temperature),
				...(telemetryConfig && { experimental_telemetry: telemetryConfig })
			});

			log(
				'debug',
				`${this.name} generateObject completed successfully for model: ${params.modelId}`
			);

			return {
				object: result.object,
				usage: normalizeUsage(result.usage)
			};
		} catch (error) {
			// Check if this is a JSON parsing error that we can potentially fix
			if (
				NoObjectGeneratedError.isInstance(error) &&
				error.cause instanceof JSONParseError &&
				error.cause.text
			) {
				log(
					'warn',
					`${this.name} generated malformed JSON, attempting to repair...`
				);

				try {
					// Use jsonrepair to fix the malformed JSON
					const repairedJson = jsonrepair(error.cause.text);
					const parsed = JSON.parse(repairedJson);

					log('info', `Successfully repaired ${this.name} JSON output`);

					// Return in the expected format, with usage taken from the error if available
					return {
						object: parsed,
						usage: normalizeUsage(error.usage)
					};
				} catch (repairError) {
					log(
						'error',
						`Failed to repair ${this.name} JSON: ${repairError.message}`
					);
					// Fall through to handleError with original error
				}
			}

			this.handleError('object generation', error);
		}
	}
}
|