diff --git a/.changeset/forty-plums-stay.md b/.changeset/forty-plums-stay.md
new file mode 100644
index 00000000..d49e0653
--- /dev/null
+++ b/.changeset/forty-plums-stay.md
@@ -0,0 +1,5 @@
+---
+'task-master-ai': minor
+---
+
+`.taskmasterconfig` now supports a `baseUrl` field per model role (`main`, `research`, `fallback`), allowing endpoint overrides for any provider.
diff --git a/docs/configuration.md b/docs/configuration.md
index f1e57560..615e184f 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -15,13 +15,15 @@ Taskmaster uses two primary methods for configuration:
 			"provider": "anthropic",
 			"modelId": "claude-3-7-sonnet-20250219",
 			"maxTokens": 64000,
-			"temperature": 0.2
+			"temperature": 0.2,
+			"baseUrl": "https://api.anthropic.com/v1"
 		},
 		"research": {
 			"provider": "perplexity",
 			"modelId": "sonar-pro",
 			"maxTokens": 8700,
-			"temperature": 0.1
+			"temperature": 0.1,
+			"baseUrl": "https://api.perplexity.ai"
 		},
 		"fallback": {
 			"provider": "anthropic",
@@ -56,8 +58,9 @@ Taskmaster uses two primary methods for configuration:
   - `AZURE_OPENAI_API_KEY`: Your Azure OpenAI API key (also requires `AZURE_OPENAI_ENDPOINT`).
   - `OPENROUTER_API_KEY`: Your OpenRouter API key.
   - `XAI_API_KEY`: Your X-AI API key.
-  - **Optional Endpoint Overrides (in .taskmasterconfig):**
-  - `AZURE_OPENAI_ENDPOINT`: Required if using Azure OpenAI key.
+  - **Optional Endpoint Overrides:**
+  - **Per-role `baseUrl` in `.taskmasterconfig`:** You can add a `baseUrl` property to any model role (`main`, `research`, `fallback`) to override the default API endpoint for that provider. If omitted, the provider's standard endpoint is used.
+  - `AZURE_OPENAI_ENDPOINT`: Required if using an Azure OpenAI key (can also be set as `baseUrl` for the Azure model role).
   - `OLLAMA_BASE_URL`: Override the default Ollama API URL (Default: `http://localhost:11434/api`).
 
 **Important:** Settings like model ID selections (`main`, `research`, `fallback`), `maxTokens`, `temperature`, `logLevel`, `defaultSubtasks`, `defaultPriority`, and `projectName` are **managed in `.taskmasterconfig`**, not environment variables.
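For illustration, a single role entry that routes traffic through an OpenAI-compatible proxy could look like the fragment below. This is a sketch only: the enclosing role map of `.taskmasterconfig` is omitted, and the provider, model, and proxy URL are placeholder values, not recommendations.

```json
"main": {
	"provider": "openai",
	"modelId": "gpt-4o",
	"maxTokens": 16000,
	"temperature": 0.2,
	"baseUrl": "https://llm-proxy.internal.example/v1"
}
```

Omitting `baseUrl` leaves the provider's standard endpoint in effect, matching the documentation change above.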
diff --git a/package-lock.json b/package-lock.json index ff03b4e2..2a437a53 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "task-master-ai", - "version": "0.12", + "version": "0.13.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "task-master-ai", - "version": "0.12", + "version": "0.13.2", "license": "MIT WITH Commons-Clause", "dependencies": { "@ai-sdk/anthropic": "^1.2.10", diff --git a/scripts/modules/ai-services-unified.js b/scripts/modules/ai-services-unified.js index fead4ad3..da958986 100644 --- a/scripts/modules/ai-services-unified.js +++ b/scripts/modules/ai-services-unified.js @@ -14,7 +14,8 @@ import { getResearchModelId, getFallbackProvider, getFallbackModelId, - getParametersForRole + getParametersForRole, + getBaseUrlForRole } from './config-manager.js'; import { log, resolveEnvVariable, findProjectRoot } from './utils.js'; @@ -284,7 +285,13 @@ async function _unifiedServiceRunner(serviceType, params) { 'AI service call failed for all configured roles.'; for (const currentRole of sequence) { - let providerName, modelId, apiKey, roleParams, providerFnSet, providerApiFn; + let providerName, + modelId, + apiKey, + roleParams, + providerFnSet, + providerApiFn, + baseUrl; try { log('info', `New AI service call with role: ${currentRole}`); @@ -325,6 +332,7 @@ async function _unifiedServiceRunner(serviceType, params) { // Pass effectiveProjectRoot to getParametersForRole roleParams = getParametersForRole(currentRole, effectiveProjectRoot); + baseUrl = getBaseUrlForRole(currentRole, effectiveProjectRoot); // 2. Get Provider Function Set providerFnSet = PROVIDER_FUNCTIONS[providerName?.toLowerCase()]; @@ -401,6 +409,7 @@ async function _unifiedServiceRunner(serviceType, params) { maxTokens: roleParams.maxTokens, temperature: roleParams.temperature, messages, + baseUrl, ...(serviceType === 'generateObject' && { schema, objectName }), ...restApiParams }; diff --git a/scripts/modules/config-manager.js b/scripts/modules/config-manager.js index e9302d08..a4ed94e5 100644 --- a/scripts/modules/config-manager.js +++ b/scripts/modules/config-manager.js @@ -677,6 +677,13 @@ function getAllProviders() { return Object.keys(MODEL_MAP || {}); } +function getBaseUrlForRole(role, explicitRoot = null) { + const roleConfig = getModelConfigForRole(role, explicitRoot); + return roleConfig && typeof roleConfig.baseUrl === 'string' + ? roleConfig.baseUrl + : undefined; +} + export { // Core config access getConfig, @@ -704,6 +711,7 @@ export { getFallbackModelId, getFallbackMaxTokens, getFallbackTemperature, + getBaseUrlForRole, // Global setting getters (No env var overrides) getLogLevel, diff --git a/src/ai-providers/anthropic.js b/src/ai-providers/anthropic.js index 1fa36f3d..27602757 100644 --- a/src/ai-providers/anthropic.js +++ b/src/ai-providers/anthropic.js @@ -5,7 +5,7 @@ * using the Vercel AI SDK. 
*/ import { createAnthropic } from '@ai-sdk/anthropic'; -import { generateText, streamText, generateObject, streamObject } from 'ai'; +import { generateText, streamText, generateObject } from 'ai'; import { log } from '../../scripts/modules/utils.js'; // Assuming utils is accessible // TODO: Implement standardized functions for generateText, streamText, generateObject @@ -17,7 +17,7 @@ import { log } from '../../scripts/modules/utils.js'; // Assuming utils is acces // Remove the global variable and caching logic // let anthropicClient; -function getClient(apiKey) { +function getClient(apiKey, baseUrl) { if (!apiKey) { // In a real scenario, this would use the config resolver. // Throwing error here if key isn't passed for simplicity. @@ -30,14 +30,12 @@ function getClient(apiKey) { // Create and return a new instance directly with standard version header return createAnthropic({ apiKey: apiKey, - baseURL: 'https://api.anthropic.com/v1', + ...(baseUrl && { baseURL: baseUrl }), // Use standard version header instead of beta headers: { 'anthropic-beta': 'output-128k-2025-02-19' } }); - // } - // return anthropicClient; } // --- Standardized Service Function Implementations --- @@ -51,6 +49,7 @@ function getClient(apiKey) { * @param {Array} params.messages - The messages array (e.g., [{ role: 'user', content: '...' }]). * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. + * @param {string} [params.baseUrl] - The base URL for the Anthropic API. * @returns {Promise} The generated text content. * @throws {Error} If the API call fails. */ @@ -59,11 +58,12 @@ export async function generateAnthropicText({ modelId, messages, maxTokens, - temperature + temperature, + baseUrl }) { log('debug', `Generating Anthropic text with model: ${modelId}`); try { - const client = getClient(apiKey); + const client = getClient(apiKey, baseUrl); const result = await generateText({ model: client(modelId), messages: messages, @@ -93,6 +93,7 @@ export async function generateAnthropicText({ * @param {Array} params.messages - The messages array. * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. + * @param {string} [params.baseUrl] - The base URL for the Anthropic API. * @returns {Promise} The full stream result object from the Vercel AI SDK. * @throws {Error} If the API call fails to initiate the stream. 
*/ @@ -101,20 +102,20 @@ export async function streamAnthropicText({ modelId, messages, maxTokens, - temperature + temperature, + baseUrl }) { log('debug', `Streaming Anthropic text with model: ${modelId}`); try { - const client = getClient(apiKey); + const client = getClient(apiKey, baseUrl); - // --- DEBUG LOGGING --- >> log( 'debug', '[streamAnthropicText] Parameters received by streamText:', JSON.stringify( { - modelId: modelId, // Log modelId being used - messages: messages, // Log the messages array + modelId: modelId, + messages: messages, maxTokens: maxTokens, temperature: temperature }, @@ -122,25 +123,19 @@ export async function streamAnthropicText({ 2 ) ); - // --- << DEBUG LOGGING --- const stream = await streamText({ model: client(modelId), messages: messages, maxTokens: maxTokens, temperature: temperature - // Beta header moved to client initialization // TODO: Add other relevant parameters }); // *** RETURN THE FULL STREAM OBJECT, NOT JUST stream.textStream *** return stream; } catch (error) { - log( - 'error', - `Anthropic streamText failed: ${error.message}`, - error.stack // Log stack trace for more details - ); + log('error', `Anthropic streamText failed: ${error.message}`, error.stack); throw error; } } @@ -160,6 +155,7 @@ export async function streamAnthropicText({ * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. * @param {number} [params.maxRetries] - Max retries for validation/generation. + * @param {string} [params.baseUrl] - The base URL for the Anthropic API. * @returns {Promise} The generated object matching the schema. * @throws {Error} If generation or validation fails. */ @@ -171,24 +167,22 @@ export async function generateAnthropicObject({ objectName = 'generated_object', maxTokens, temperature, - maxRetries = 3 + maxRetries = 3, + baseUrl }) { log( 'debug', `Generating Anthropic object ('${objectName}') with model: ${modelId}` ); try { - const client = getClient(apiKey); - - // Log basic debug info + const client = getClient(apiKey, baseUrl); log( 'debug', `Using maxTokens: ${maxTokens}, temperature: ${temperature}, model: ${modelId}` ); - const result = await generateObject({ model: client(modelId), - mode: 'tool', // Anthropic generally uses 'tool' mode for structured output + mode: 'tool', schema: schema, messages: messages, tool: { @@ -199,14 +193,12 @@ export async function generateAnthropicObject({ temperature: temperature, maxRetries: maxRetries }); - log( 'debug', `Anthropic generateObject result received. Tokens: ${result.usage.completionTokens}/${result.usage.promptTokens}` ); return result.object; } catch (error) { - // Simple error logging log( 'error', `Anthropic generateObject ('${objectName}') failed: ${error.message}` diff --git a/src/ai-providers/google.js b/src/ai-providers/google.js index 037f9a3c..7428816b 100644 --- a/src/ai-providers/google.js +++ b/src/ai-providers/google.js @@ -12,6 +12,16 @@ import { log } from '../../scripts/modules/utils.js'; // Import logging utility const DEFAULT_MODEL = 'gemini-2.0-pro'; // Or a suitable default const DEFAULT_TEMPERATURE = 0.2; // Or a suitable default +function getClient(apiKey, baseUrl) { + if (!apiKey) { + throw new Error('Google API key is required.'); + } + return createGoogleGenerativeAI({ + apiKey: apiKey, + ...(baseUrl && { baseURL: baseUrl }) + }); +} + /** * Generates text using a Google AI model. 
* @@ -29,7 +39,8 @@ async function generateGoogleText({ modelId = DEFAULT_MODEL, temperature = DEFAULT_TEMPERATURE, messages, - maxTokens // Note: Vercel SDK might handle this differently, needs verification + maxTokens, + baseUrl }) { if (!apiKey) { throw new Error('Google API key is required.'); @@ -37,28 +48,21 @@ async function generateGoogleText({ log('info', `Generating text with Google model: ${modelId}`); try { - // const google = new GoogleGenerativeAI({ apiKey }); // Incorrect instantiation - const googleProvider = createGoogleGenerativeAI({ apiKey }); // Correct instantiation - // const model = google.getGenerativeModel({ model: modelId }); // Incorrect model retrieval - const model = googleProvider(modelId); // Correct model retrieval - - // Construct payload suitable for Vercel SDK's generateText - // Note: The exact structure might depend on how messages are passed + const googleProvider = getClient(apiKey, baseUrl); + const model = googleProvider(modelId); const result = await generateText({ - model, // Pass the model instance - messages, // Pass the messages array directly + model, + messages, temperature, - maxOutputTokens: maxTokens // Map to correct Vercel SDK param if available + maxOutputTokens: maxTokens }); - - // Assuming result structure provides text directly or within a property - return result.text; // Adjust based on actual SDK response + return result.text; } catch (error) { log( 'error', `Error generating text with Google (${modelId}): ${error.message}` ); - throw error; // Re-throw for unified service handler + throw error; } } @@ -79,7 +83,8 @@ async function streamGoogleText({ modelId = DEFAULT_MODEL, temperature = DEFAULT_TEMPERATURE, messages, - maxTokens + maxTokens, + baseUrl }) { if (!apiKey) { throw new Error('Google API key is required.'); @@ -87,19 +92,15 @@ async function streamGoogleText({ log('info', `Streaming text with Google model: ${modelId}`); try { - // const google = new GoogleGenerativeAI({ apiKey }); // Incorrect instantiation - const googleProvider = createGoogleGenerativeAI({ apiKey }); // Correct instantiation - // const model = google.getGenerativeModel({ model: modelId }); // Incorrect model retrieval - const model = googleProvider(modelId); // Correct model retrieval - + const googleProvider = getClient(apiKey, baseUrl); + const model = googleProvider(modelId); const stream = await streamText({ - model, // Pass the model instance + model, messages, temperature, maxOutputTokens: maxTokens }); - - return stream; // Return the stream directly + return stream; } catch (error) { log( 'error', @@ -130,7 +131,8 @@ async function generateGoogleObject({ messages, schema, objectName, // Note: Vercel SDK might use this differently or not at all - maxTokens + maxTokens, + baseUrl }) { if (!apiKey) { throw new Error('Google API key is required.'); @@ -138,23 +140,16 @@ async function generateGoogleObject({ log('info', `Generating object with Google model: ${modelId}`); try { - // const google = new GoogleGenerativeAI({ apiKey }); // Incorrect instantiation - const googleProvider = createGoogleGenerativeAI({ apiKey }); // Correct instantiation - // const model = google.getGenerativeModel({ model: modelId }); // Incorrect model retrieval - const model = googleProvider(modelId); // Correct model retrieval - + const googleProvider = getClient(apiKey, baseUrl); + const model = googleProvider(modelId); const { object } = await generateObject({ - model, // Pass the model instance + model, schema, messages, temperature, maxOutputTokens: maxTokens - // 
Note: 'objectName' or 'mode' might not be directly applicable here - // depending on how `@ai-sdk/google` handles `generateObject`. - // Check SDK docs if specific tool calling/JSON mode needs explicit setup. }); - - return object; // Return the parsed object + return object; } catch (error) { log( 'error', diff --git a/src/ai-providers/openai.js b/src/ai-providers/openai.js index ce34e957..3a0f2090 100644 --- a/src/ai-providers/openai.js +++ b/src/ai-providers/openai.js @@ -1,16 +1,26 @@ -import { createOpenAI, openai } from '@ai-sdk/openai'; // Using openai provider from Vercel AI SDK -import { generateText, streamText, generateObject } from 'ai'; // Import necessary functions from 'ai' +import { createOpenAI } from '@ai-sdk/openai'; // Using openai provider from Vercel AI SDK +import { generateObject } from 'ai'; // Import necessary functions from 'ai' import { log } from '../../scripts/modules/utils.js'; +function getClient(apiKey, baseUrl) { + if (!apiKey) { + throw new Error('OpenAI API key is required.'); + } + return createOpenAI({ + apiKey: apiKey, + ...(baseUrl && { baseURL: baseUrl }) + }); +} + /** * Generates text using OpenAI models via Vercel AI SDK. * - * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature. + * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature, baseUrl. * @returns {Promise} The generated text content. * @throws {Error} If API call fails. */ export async function generateOpenAIText(params) { - const { apiKey, modelId, messages, maxTokens, temperature } = params; + const { apiKey, modelId, messages, maxTokens, temperature, baseUrl } = params; log('debug', `generateOpenAIText called with model: ${modelId}`); if (!apiKey) { @@ -23,18 +33,15 @@ export async function generateOpenAIText(params) { throw new Error('Invalid or empty messages array provided for OpenAI.'); } - const openaiClient = createOpenAI({ apiKey }); + const openaiClient = getClient(apiKey, baseUrl); try { const result = await openaiClient.chat(messages, { - // Updated: Use openaiClient.chat directly model: modelId, max_tokens: maxTokens, temperature }); - // Adjust based on actual Vercel SDK response structure for openaiClient.chat - // This might need refinement based on testing the SDK's output. const textContent = result?.choices?.[0]?.message?.content?.trim(); if (!textContent) { @@ -65,12 +72,12 @@ export async function generateOpenAIText(params) { /** * Streams text using OpenAI models via Vercel AI SDK. * - * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature. + * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature, baseUrl. * @returns {Promise} A readable stream of text deltas. * @throws {Error} If API call fails. 
*/ export async function streamOpenAIText(params) { - const { apiKey, modelId, messages, maxTokens, temperature } = params; + const { apiKey, modelId, messages, maxTokens, temperature, baseUrl } = params; log('debug', `streamOpenAIText called with model: ${modelId}`); if (!apiKey) { @@ -85,12 +92,10 @@ export async function streamOpenAIText(params) { ); } - const openaiClient = createOpenAI({ apiKey }); + const openaiClient = getClient(apiKey, baseUrl); try { - // Use the streamText function from Vercel AI SDK core const stream = await openaiClient.chat.stream(messages, { - // Updated: Use openaiClient.chat.stream model: modelId, max_tokens: maxTokens, temperature @@ -100,7 +105,6 @@ export async function streamOpenAIText(params) { 'debug', `OpenAI streamText initiated successfully for model: ${modelId}` ); - // The Vercel SDK's streamText should directly return the stream object return stream; } catch (error) { log( @@ -117,7 +121,7 @@ export async function streamOpenAIText(params) { /** * Generates structured objects using OpenAI models via Vercel AI SDK. * - * @param {object} params - Parameters including apiKey, modelId, messages, schema, objectName, maxTokens, temperature. + * @param {object} params - Parameters including apiKey, modelId, messages, schema, objectName, maxTokens, temperature, baseUrl. * @returns {Promise} The generated object matching the schema. * @throws {Error} If API call fails or object generation fails. */ @@ -129,7 +133,8 @@ export async function generateOpenAIObject(params) { schema, objectName, maxTokens, - temperature + temperature, + baseUrl } = params; log( 'debug', @@ -145,10 +150,9 @@ export async function generateOpenAIObject(params) { if (!objectName) throw new Error('Object name is required for OpenAI object generation.'); - const openaiClient = createOpenAI({ apiKey }); + const openaiClient = getClient(apiKey, baseUrl); try { - // Use the imported generateObject function from 'ai' package const result = await generateObject({ model: openaiClient(modelId), schema: schema, diff --git a/src/ai-providers/openrouter.js b/src/ai-providers/openrouter.js index 594d208c..f842cbf2 100644 --- a/src/ai-providers/openrouter.js +++ b/src/ai-providers/openrouter.js @@ -2,6 +2,14 @@ import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import { generateText, streamText, generateObject } from 'ai'; import { log } from '../../scripts/modules/utils.js'; // Assuming utils.js is in scripts/modules +function getClient(apiKey, baseUrl) { + if (!apiKey) throw new Error('OpenRouter API key is required.'); + return createOpenRouter({ + apiKey, + ...(baseUrl && { baseURL: baseUrl }) + }); +} + /** * Generates text using an OpenRouter chat model. * @@ -11,6 +19,7 @@ import { log } from '../../scripts/modules/utils.js'; // Assuming utils.js is in * @param {Array} params.messages - Array of message objects (system, user, assistant). * @param {number} [params.maxTokens] - Maximum tokens to generate. * @param {number} [params.temperature] - Sampling temperature. + * @param {string} [params.baseUrl] - Base URL for the OpenRouter API. * @returns {Promise} The generated text content. * @throws {Error} If the API call fails. 
 */
@@ -20,6 +29,7 @@ async function generateOpenRouterText({
 	modelId,
 	messages,
 	maxTokens,
 	temperature,
+	baseUrl,
 	...rest // Capture any other Vercel AI SDK compatible parameters
 }) {
 	if (!apiKey) throw new Error('OpenRouter API key is required.');
@@ -28,7 +38,7 @@
 		throw new Error('Messages array cannot be empty.');
 
 	try {
-		const openrouter = createOpenRouter({ apiKey });
+		const openrouter = getClient(apiKey, baseUrl);
 		const model = openrouter.chat(modelId); // Assuming chat model
 
 		const { text } = await generateText({
@@ -58,6 +68,7 @@
  * @param {Array} params.messages - Array of message objects (system, user, assistant).
  * @param {number} [params.maxTokens] - Maximum tokens to generate.
  * @param {number} [params.temperature] - Sampling temperature.
+ * @param {string} [params.baseUrl] - Base URL for the OpenRouter API.
  * @returns {Promise<ReadableStream<string>>} A readable stream of text deltas.
  * @throws {Error} If the API call fails.
  */
@@ -67,6 +78,7 @@ async function streamOpenRouterText({
 	modelId,
 	messages,
 	maxTokens,
 	temperature,
+	baseUrl,
 	...rest
 }) {
 	if (!apiKey) throw new Error('OpenRouter API key is required.');
@@ -75,7 +87,7 @@
 		throw new Error('Messages array cannot be empty.');
 
 	try {
-		const openrouter = createOpenRouter({ apiKey });
+		const openrouter = getClient(apiKey, baseUrl);
 		const model = openrouter.chat(modelId);
 
 		// Directly return the stream from the Vercel AI SDK function
@@ -108,6 +120,7 @@
  * @param {number} [params.maxRetries=3] - Max retries for object generation.
  * @param {number} [params.maxTokens] - Maximum tokens.
  * @param {number} [params.temperature] - Temperature.
+ * @param {string} [params.baseUrl] - Base URL for the OpenRouter API.
  * @returns {Promise} The generated object matching the schema.
  * @throws {Error} If the API call fails or validation fails.
  */
@@ -120,6 +133,7 @@
 	maxRetries = 3,
 	maxTokens,
 	temperature,
+	baseUrl,
 	...rest
 }) {
 	if (!apiKey) throw new Error('OpenRouter API key is required.');
@@ -129,7 +143,7 @@
 		throw new Error('Messages array cannot be empty.');
 
 	try {
-		const openrouter = createOpenRouter({ apiKey });
+		const openrouter = getClient(apiKey, baseUrl);
 		const model = openrouter.chat(modelId);
 
 		const { object } = await generateObject({
diff --git a/src/ai-providers/perplexity.js b/src/ai-providers/perplexity.js
index e8982d6f..7255753d 100644
--- a/src/ai-providers/perplexity.js
+++ b/src/ai-providers/perplexity.js
@@ -10,13 +10,13 @@ import { log } from '../../scripts/modules/utils.js';
 
 // --- Client Instantiation ---
 // Similar to Anthropic, this expects the resolved API key to be passed in.
-function getClient(apiKey) {
+function getClient(apiKey, baseUrl) {
 	if (!apiKey) {
 		throw new Error('Perplexity API key is required.');
 	}
-	// Create and return a new instance directly
 	return createPerplexity({
-		apiKey: apiKey
+		apiKey: apiKey,
+		...(baseUrl && { baseURL: baseUrl })
 	});
 }
 
@@ -31,6 +31,7 @@ function getClient(apiKey) {
  * @param {Array} params.messages - The messages array.
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
+ * @param {string} [params.baseUrl] - Base URL for the Perplexity API.
  * @returns {Promise} The generated text content.
  * @throws {Error} If the API call fails.
*/ @@ -39,11 +40,12 @@ export async function generatePerplexityText({ modelId, messages, maxTokens, - temperature + temperature, + baseUrl }) { log('debug', `Generating Perplexity text with model: ${modelId}`); try { - const client = getClient(apiKey); + const client = getClient(apiKey, baseUrl); const result = await generateText({ model: client(modelId), messages: messages, @@ -70,6 +72,7 @@ export async function generatePerplexityText({ * @param {Array} params.messages - The messages array. * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. + * @param {string} [params.baseUrl] - Base URL for the Perplexity API. * @returns {Promise} The full stream result object from the Vercel AI SDK. * @throws {Error} If the API call fails to initiate the stream. */ @@ -78,11 +81,12 @@ export async function streamPerplexityText({ modelId, messages, maxTokens, - temperature + temperature, + baseUrl }) { log('debug', `Streaming Perplexity text with model: ${modelId}`); try { - const client = getClient(apiKey); + const client = getClient(apiKey, baseUrl); const stream = await streamText({ model: client(modelId), messages: messages, @@ -112,6 +116,7 @@ export async function streamPerplexityText({ * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. * @param {number} [params.maxRetries] - Max retries for validation/generation. + * @param {string} [params.baseUrl] - Base URL for the Perplexity API. * @returns {Promise} The generated object matching the schema. * @throws {Error} If generation or validation fails or is unsupported. */ @@ -123,7 +128,8 @@ export async function generatePerplexityObject({ objectName = 'generated_object', maxTokens, temperature, - maxRetries = 1 // Lower retries as support might be limited + maxRetries = 1, + baseUrl }) { log( 'debug', @@ -134,8 +140,7 @@ export async function generatePerplexityObject({ 'generateObject support for Perplexity might be limited or experimental.' ); try { - const client = getClient(apiKey); - // Attempt using generateObject, but be prepared for potential issues + const client = getClient(apiKey, baseUrl); const result = await generateObject({ model: client(modelId), schema: schema, diff --git a/src/ai-providers/xai.js b/src/ai-providers/xai.js index 1886e787..fa2cc954 100644 --- a/src/ai-providers/xai.js +++ b/src/ai-providers/xai.js @@ -9,14 +9,13 @@ import { generateText, streamText, generateObject } from 'ai'; // Only import wh import { log } from '../../scripts/modules/utils.js'; // Assuming utils is accessible // --- Client Instantiation --- -function getClient(apiKey) { +function getClient(apiKey, baseUrl) { if (!apiKey) { throw new Error('xAI API key is required.'); } - // Create and return a new instance directly return createXai({ - apiKey: apiKey - // Add baseURL or other options if needed later + apiKey: apiKey, + ...(baseUrl && { baseURL: baseUrl }) }); } @@ -31,6 +30,7 @@ function getClient(apiKey) { * @param {Array} params.messages - The messages array (e.g., [{ role: 'user', content: '...' }]). * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. + * @param {string} [params.baseUrl] - The base URL for the xAI API. * @returns {Promise} The generated text content. * @throws {Error} If the API call fails. 
*/ @@ -39,13 +39,14 @@ export async function generateXaiText({ modelId, messages, maxTokens, - temperature + temperature, + baseUrl }) { log('debug', `Generating xAI text with model: ${modelId}`); try { - const client = getClient(apiKey); + const client = getClient(apiKey, baseUrl); const result = await generateText({ - model: client(modelId), // Correct model invocation + model: client(modelId), messages: messages, maxTokens: maxTokens, temperature: temperature @@ -70,6 +71,7 @@ export async function generateXaiText({ * @param {Array} params.messages - The messages array. * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. + * @param {string} [params.baseUrl] - The base URL for the xAI API. * @returns {Promise} The full stream result object from the Vercel AI SDK. * @throws {Error} If the API call fails to initiate the stream. */ @@ -78,18 +80,19 @@ export async function streamXaiText({ modelId, messages, maxTokens, - temperature + temperature, + baseUrl }) { log('debug', `Streaming xAI text with model: ${modelId}`); try { - const client = getClient(apiKey); + const client = getClient(apiKey, baseUrl); const stream = await streamText({ - model: client(modelId), // Correct model invocation + model: client(modelId), messages: messages, maxTokens: maxTokens, temperature: temperature }); - return stream; // Return the full stream object + return stream; } catch (error) { log('error', `xAI streamText failed: ${error.message}`, error.stack); throw error; @@ -110,6 +113,7 @@ export async function streamXaiText({ * @param {number} [params.maxTokens] - Maximum tokens for the response. * @param {number} [params.temperature] - Temperature for generation. * @param {number} [params.maxRetries] - Max retries for validation/generation. + * @param {string} [params.baseUrl] - The base URL for the xAI API. * @returns {Promise} The generated object matching the schema. * @throws {Error} If generation or validation fails. */ @@ -121,16 +125,17 @@ export async function generateXaiObject({ objectName = 'generated_xai_object', maxTokens, temperature, - maxRetries = 3 + maxRetries = 3, + baseUrl }) { log( - 'warn', // Log warning as this is likely unsupported + 'warn', `Attempting to generate xAI object ('${objectName}') with model: ${modelId}. This may not be supported by the provider.` ); try { - const client = getClient(apiKey); + const client = getClient(apiKey, baseUrl); const result = await generateObject({ - model: client(modelId), // Correct model invocation + model: client(modelId), // Note: mode might need adjustment if xAI ever supports object generation differently mode: 'tool', schema: schema, @@ -153,6 +158,6 @@ export async function generateXaiObject({ 'error', `xAI generateObject ('${objectName}') failed: ${error.message}. 
(Likely unsupported by provider)` ); - throw error; // Re-throw the error + throw error; } } diff --git a/tests/unit/ai-services-unified.test.js b/tests/unit/ai-services-unified.test.js index 59e3d32b..4098e75e 100644 --- a/tests/unit/ai-services-unified.test.js +++ b/tests/unit/ai-services-unified.test.js @@ -8,6 +8,7 @@ const mockGetResearchModelId = jest.fn(); const mockGetFallbackProvider = jest.fn(); const mockGetFallbackModelId = jest.fn(); const mockGetParametersForRole = jest.fn(); +const mockGetBaseUrlForRole = jest.fn(); jest.unstable_mockModule('../../scripts/modules/config-manager.js', () => ({ getMainProvider: mockGetMainProvider, @@ -16,7 +17,8 @@ jest.unstable_mockModule('../../scripts/modules/config-manager.js', () => ({ getResearchModelId: mockGetResearchModelId, getFallbackProvider: mockGetFallbackProvider, getFallbackModelId: mockGetFallbackModelId, - getParametersForRole: mockGetParametersForRole + getParametersForRole: mockGetParametersForRole, + getBaseUrlForRole: mockGetBaseUrlForRole })); // Mock AI Provider Modules
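To make the data flow concrete, here is a self-contained sketch (not part of the diff) of the override pattern every provider adopts in this PR: the unified runner resolves a per-role `baseUrl` via `getBaseUrlForRole`, forwards it to the provider function, and the provider's `getClient` spreads it into the SDK factory only when it is set, so the SDK's built-in endpoint applies otherwise. The key and URL values below are placeholders.

```js
import { createAnthropic } from '@ai-sdk/anthropic';

// Mirrors the getClient(apiKey, baseUrl) helpers added in this PR: the
// baseURL key is only included when an override exists, so createAnthropic
// falls back to its default endpoint when .taskmasterconfig omits baseUrl.
function getClient(apiKey, baseUrl) {
	if (!apiKey) {
		throw new Error('Anthropic API key is required.');
	}
	return createAnthropic({
		apiKey: apiKey,
		...(baseUrl && { baseURL: baseUrl })
	});
}

// With an override (e.g. a corporate proxy), requests target that URL:
const proxied = getClient('sk-ant-placeholder', 'https://llm-proxy.internal.example/v1');
// Without one, the provider's standard endpoint is used:
const standard = getClient('sk-ant-placeholder');
```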