feat: implement baseUrls on all AI providers (#521)

Ralph Khreish
2025-05-16 15:34:29 +02:00
committed by GitHub
parent e96734a6cc
commit ed17cb0e0a
12 changed files with 161 additions and 119 deletions

View File

@@ -0,0 +1,5 @@
---
'task-master-ai': minor
---
.taskmasterconfig now supports a baseUrl field per model role (main, research, fallback), allowing endpoint overrides for any provider.
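As an illustration, a single role entry in `.taskmasterconfig` could point at a self-hosted proxy instead of the provider's default endpoint (the URL below is hypothetical and the surrounding file structure is omitted):

    "main": {
      "provider": "anthropic",
      "modelId": "claude-3-7-sonnet-20250219",
      "maxTokens": 64000,
      "temperature": 0.2,
      "baseUrl": "http://localhost:8080/v1"
    }

Roles that omit `baseUrl` keep using the provider's standard endpoint.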

View File

@@ -15,13 +15,15 @@ Taskmaster uses two primary methods for configuration:
       "provider": "anthropic",
       "modelId": "claude-3-7-sonnet-20250219",
       "maxTokens": 64000,
-      "temperature": 0.2
+      "temperature": 0.2,
+      "baseUrl": "https://api.anthropic.com/v1"
     },
     "research": {
       "provider": "perplexity",
       "modelId": "sonar-pro",
       "maxTokens": 8700,
-      "temperature": 0.1
+      "temperature": 0.1,
+      "baseUrl": "https://api.perplexity.ai/v1"
     },
     "fallback": {
       "provider": "anthropic",
@@ -56,8 +58,9 @@ Taskmaster uses two primary methods for configuration:
 - `AZURE_OPENAI_API_KEY`: Your Azure OpenAI API key (also requires `AZURE_OPENAI_ENDPOINT`).
 - `OPENROUTER_API_KEY`: Your OpenRouter API key.
 - `XAI_API_KEY`: Your X-AI API key.
-- **Optional Endpoint Overrides (in .taskmasterconfig):**
-  - `AZURE_OPENAI_ENDPOINT`: Required if using Azure OpenAI key.
+- **Optional Endpoint Overrides:**
+  - **Per-role `baseUrl` in `.taskmasterconfig`:** You can add a `baseUrl` property to any model role (`main`, `research`, `fallback`) to override the default API endpoint for that provider. If omitted, the provider's standard endpoint is used.
+  - `AZURE_OPENAI_ENDPOINT`: Required if using Azure OpenAI key (can also be set as `baseUrl` for the Azure model role).
   - `OLLAMA_BASE_URL`: Override the default Ollama API URL (Default: `http://localhost:11434/api`).
 **Important:** Settings like model ID selections (`main`, `research`, `fallback`), `maxTokens`, `temperature`, `logLevel`, `defaultSubtasks`, `defaultPriority`, and `projectName` are **managed in `.taskmasterconfig`**, not environment variables.

package-lock.json (generated)
View File

@@ -1,12 +1,12 @@
 {
   "name": "task-master-ai",
-  "version": "0.12",
+  "version": "0.13.2",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "task-master-ai",
-      "version": "0.12",
+      "version": "0.13.2",
       "license": "MIT WITH Commons-Clause",
       "dependencies": {
         "@ai-sdk/anthropic": "^1.2.10",

View File

@@ -14,7 +14,8 @@ import {
   getResearchModelId,
   getFallbackProvider,
   getFallbackModelId,
-  getParametersForRole
+  getParametersForRole,
+  getBaseUrlForRole
 } from './config-manager.js';
 import { log, resolveEnvVariable, findProjectRoot } from './utils.js';
@@ -284,7 +285,13 @@ async function _unifiedServiceRunner(serviceType, params) {
     'AI service call failed for all configured roles.';
   for (const currentRole of sequence) {
-    let providerName, modelId, apiKey, roleParams, providerFnSet, providerApiFn;
+    let providerName,
+      modelId,
+      apiKey,
+      roleParams,
+      providerFnSet,
+      providerApiFn,
+      baseUrl;
     try {
       log('info', `New AI service call with role: ${currentRole}`);
@@ -325,6 +332,7 @@ async function _unifiedServiceRunner(serviceType, params) {
       // Pass effectiveProjectRoot to getParametersForRole
       roleParams = getParametersForRole(currentRole, effectiveProjectRoot);
+      baseUrl = getBaseUrlForRole(currentRole, effectiveProjectRoot);
       // 2. Get Provider Function Set
       providerFnSet = PROVIDER_FUNCTIONS[providerName?.toLowerCase()];
@@ -401,6 +409,7 @@ async function _unifiedServiceRunner(serviceType, params) {
         maxTokens: roleParams.maxTokens,
         temperature: roleParams.temperature,
         messages,
+        baseUrl,
         ...(serviceType === 'generateObject' && { schema, objectName }),
         ...restApiParams
       };

View File

@@ -677,6 +677,13 @@ function getAllProviders() {
   return Object.keys(MODEL_MAP || {});
 }
+function getBaseUrlForRole(role, explicitRoot = null) {
+  const roleConfig = getModelConfigForRole(role, explicitRoot);
+  return roleConfig && typeof roleConfig.baseUrl === 'string'
+    ? roleConfig.baseUrl
+    : undefined;
+}
 export {
   // Core config access
   getConfig,
@@ -704,6 +711,7 @@ export {
   getFallbackModelId,
   getFallbackMaxTokens,
   getFallbackTemperature,
+  getBaseUrlForRole,
   // Global setting getters (No env var overrides)
   getLogLevel,
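Taken together with the ai-services-unified.js changes above, the flow is: the unified runner resolves the role's override with `getBaseUrlForRole`, adds it to the provider call params, and each provider's `getClient` spreads it into the SDK factory only when present. A condensed, self-contained sketch of that pattern (simplified shapes and a hypothetical endpoint, not the repo's exact code):

    import { createAnthropic } from '@ai-sdk/anthropic';

    // Hypothetical, already-parsed role config (stand-in for .taskmasterconfig).
    const roles = {
      main: {
        provider: 'anthropic',
        modelId: 'claude-3-7-sonnet-20250219',
        baseUrl: 'http://localhost:8080/v1' // hypothetical proxy
      },
      fallback: { provider: 'anthropic', modelId: 'claude-3-7-sonnet-20250219' } // no override
    };

    // Mirrors getBaseUrlForRole: return the override only when it is a string.
    function getBaseUrlForRole(role) {
      const roleConfig = roles[role];
      return roleConfig && typeof roleConfig.baseUrl === 'string'
        ? roleConfig.baseUrl
        : undefined;
    }

    // Mirrors each provider's getClient: spread baseURL only when an override exists.
    function getClient(apiKey, baseUrl) {
      return createAnthropic({ apiKey, ...(baseUrl && { baseURL: baseUrl }) });
    }

    const mainClient = getClient(process.env.ANTHROPIC_API_KEY, getBaseUrlForRole('main')); // proxy endpoint
    const fallbackClient = getClient(process.env.ANTHROPIC_API_KEY, getBaseUrlForRole('fallback')); // SDK default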

View File

@@ -5,7 +5,7 @@
  * using the Vercel AI SDK.
  */
 import { createAnthropic } from '@ai-sdk/anthropic';
-import { generateText, streamText, generateObject, streamObject } from 'ai';
+import { generateText, streamText, generateObject } from 'ai';
 import { log } from '../../scripts/modules/utils.js'; // Assuming utils is accessible
 // TODO: Implement standardized functions for generateText, streamText, generateObject
@@ -17,7 +17,7 @@ import { log } from '../../scripts/modules/utils.js'; // Assuming utils is acces
 // Remove the global variable and caching logic
 // let anthropicClient;
-function getClient(apiKey) {
+function getClient(apiKey, baseUrl) {
   if (!apiKey) {
     // In a real scenario, this would use the config resolver.
     // Throwing error here if key isn't passed for simplicity.
@@ -30,14 +30,12 @@ function getClient(apiKey) {
   // Create and return a new instance directly with standard version header
   return createAnthropic({
     apiKey: apiKey,
-    baseURL: 'https://api.anthropic.com/v1',
+    ...(baseUrl && { baseURL: baseUrl }),
     // Use standard version header instead of beta
     headers: {
       'anthropic-beta': 'output-128k-2025-02-19'
     }
   });
-  // }
-  // return anthropicClient;
 }
 // --- Standardized Service Function Implementations ---
@@ -51,6 +49,7 @@ function getClient(apiKey) {
  * @param {Array<object>} params.messages - The messages array (e.g., [{ role: 'user', content: '...' }]).
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
+ * @param {string} [params.baseUrl] - The base URL for the Anthropic API.
  * @returns {Promise<string>} The generated text content.
  * @throws {Error} If the API call fails.
  */
@@ -59,11 +58,12 @@ export async function generateAnthropicText({
   modelId,
   messages,
   maxTokens,
-  temperature
+  temperature,
+  baseUrl
 }) {
   log('debug', `Generating Anthropic text with model: ${modelId}`);
   try {
-    const client = getClient(apiKey);
+    const client = getClient(apiKey, baseUrl);
     const result = await generateText({
       model: client(modelId),
       messages: messages,
@@ -93,6 +93,7 @@ export async function generateAnthropicText({
  * @param {Array<object>} params.messages - The messages array.
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
+ * @param {string} [params.baseUrl] - The base URL for the Anthropic API.
  * @returns {Promise<object>} The full stream result object from the Vercel AI SDK.
  * @throws {Error} If the API call fails to initiate the stream.
  */
@@ -101,20 +102,20 @@ export async function streamAnthropicText({
   modelId,
   messages,
   maxTokens,
-  temperature
+  temperature,
+  baseUrl
 }) {
   log('debug', `Streaming Anthropic text with model: ${modelId}`);
   try {
-    const client = getClient(apiKey);
-    // --- DEBUG LOGGING --- >>
+    const client = getClient(apiKey, baseUrl);
     log(
       'debug',
       '[streamAnthropicText] Parameters received by streamText:',
       JSON.stringify(
         {
-          modelId: modelId, // Log modelId being used
-          messages: messages, // Log the messages array
+          modelId: modelId,
+          messages: messages,
           maxTokens: maxTokens,
           temperature: temperature
         },
@@ -122,25 +123,19 @@ export async function streamAnthropicText({
         2
       )
     );
-    // --- << DEBUG LOGGING ---
     const stream = await streamText({
       model: client(modelId),
       messages: messages,
       maxTokens: maxTokens,
       temperature: temperature
-      // Beta header moved to client initialization
       // TODO: Add other relevant parameters
     });
     // *** RETURN THE FULL STREAM OBJECT, NOT JUST stream.textStream ***
     return stream;
   } catch (error) {
-    log(
-      'error',
-      `Anthropic streamText failed: ${error.message}`,
-      error.stack // Log stack trace for more details
-    );
+    log('error', `Anthropic streamText failed: ${error.message}`, error.stack);
     throw error;
   }
 }
@@ -160,6 +155,7 @@ export async function streamAnthropicText({
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
  * @param {number} [params.maxRetries] - Max retries for validation/generation.
+ * @param {string} [params.baseUrl] - The base URL for the Anthropic API.
  * @returns {Promise<object>} The generated object matching the schema.
  * @throws {Error} If generation or validation fails.
  */
@@ -171,24 +167,22 @@ export async function generateAnthropicObject({
   objectName = 'generated_object',
   maxTokens,
   temperature,
-  maxRetries = 3
+  maxRetries = 3,
+  baseUrl
 }) {
   log(
     'debug',
     `Generating Anthropic object ('${objectName}') with model: ${modelId}`
   );
   try {
-    const client = getClient(apiKey);
-    // Log basic debug info
+    const client = getClient(apiKey, baseUrl);
     log(
       'debug',
       `Using maxTokens: ${maxTokens}, temperature: ${temperature}, model: ${modelId}`
     );
     const result = await generateObject({
       model: client(modelId),
-      mode: 'tool', // Anthropic generally uses 'tool' mode for structured output
+      mode: 'tool',
       schema: schema,
       messages: messages,
       tool: {
@@ -199,14 +193,12 @@ export async function generateAnthropicObject({
       temperature: temperature,
       maxRetries: maxRetries
     });
     log(
       'debug',
       `Anthropic generateObject result received. Tokens: ${result.usage.completionTokens}/${result.usage.promptTokens}`
     );
     return result.object;
   } catch (error) {
-    // Simple error logging
     log(
       'error',
       `Anthropic generateObject ('${objectName}') failed: ${error.message}`
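For reference, a call to the updated text helper might look like this (illustrative values; the import path is assumed from the relative paths in the diff, and the endpoint is a hypothetical proxy, omitted to use the SDK default):

    import { generateAnthropicText } from './src/ai-providers/anthropic.js'; // assumed path

    const text = await generateAnthropicText({
      apiKey: process.env.ANTHROPIC_API_KEY,
      modelId: 'claude-3-7-sonnet-20250219',
      messages: [{ role: 'user', content: 'Say hello.' }],
      maxTokens: 1000,
      temperature: 0.2,
      baseUrl: 'http://localhost:8080/v1' // hypothetical override
    });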

View File

@@ -12,6 +12,16 @@ import { log } from '../../scripts/modules/utils.js'; // Import logging utility
 const DEFAULT_MODEL = 'gemini-2.0-pro'; // Or a suitable default
 const DEFAULT_TEMPERATURE = 0.2; // Or a suitable default
+function getClient(apiKey, baseUrl) {
+  if (!apiKey) {
+    throw new Error('Google API key is required.');
+  }
+  return createGoogleGenerativeAI({
+    apiKey: apiKey,
+    ...(baseUrl && { baseURL: baseUrl })
+  });
+}
 /**
  * Generates text using a Google AI model.
  *
@@ -29,7 +39,8 @@ async function generateGoogleText({
   modelId = DEFAULT_MODEL,
   temperature = DEFAULT_TEMPERATURE,
   messages,
-  maxTokens // Note: Vercel SDK might handle this differently, needs verification
+  maxTokens,
+  baseUrl
 }) {
   if (!apiKey) {
     throw new Error('Google API key is required.');
@@ -37,28 +48,21 @@ async function generateGoogleText({
   log('info', `Generating text with Google model: ${modelId}`);
   try {
-    // const google = new GoogleGenerativeAI({ apiKey }); // Incorrect instantiation
-    const googleProvider = createGoogleGenerativeAI({ apiKey }); // Correct instantiation
-    // const model = google.getGenerativeModel({ model: modelId }); // Incorrect model retrieval
-    const model = googleProvider(modelId); // Correct model retrieval
-    // Construct payload suitable for Vercel SDK's generateText
-    // Note: The exact structure might depend on how messages are passed
+    const googleProvider = getClient(apiKey, baseUrl);
+    const model = googleProvider(modelId);
     const result = await generateText({
-      model, // Pass the model instance
-      messages, // Pass the messages array directly
+      model,
+      messages,
       temperature,
-      maxOutputTokens: maxTokens // Map to correct Vercel SDK param if available
+      maxOutputTokens: maxTokens
     });
-    // Assuming result structure provides text directly or within a property
-    return result.text; // Adjust based on actual SDK response
+    return result.text;
   } catch (error) {
     log(
       'error',
       `Error generating text with Google (${modelId}): ${error.message}`
     );
-    throw error; // Re-throw for unified service handler
+    throw error;
   }
 }
@@ -79,7 +83,8 @@ async function streamGoogleText({
   modelId = DEFAULT_MODEL,
   temperature = DEFAULT_TEMPERATURE,
   messages,
-  maxTokens
+  maxTokens,
+  baseUrl
 }) {
   if (!apiKey) {
     throw new Error('Google API key is required.');
@@ -87,19 +92,15 @@ async function streamGoogleText({
   log('info', `Streaming text with Google model: ${modelId}`);
   try {
-    // const google = new GoogleGenerativeAI({ apiKey }); // Incorrect instantiation
-    const googleProvider = createGoogleGenerativeAI({ apiKey }); // Correct instantiation
-    // const model = google.getGenerativeModel({ model: modelId }); // Incorrect model retrieval
-    const model = googleProvider(modelId); // Correct model retrieval
+    const googleProvider = getClient(apiKey, baseUrl);
+    const model = googleProvider(modelId);
     const stream = await streamText({
-      model, // Pass the model instance
+      model,
       messages,
       temperature,
       maxOutputTokens: maxTokens
     });
-    return stream; // Return the stream directly
+    return stream;
   } catch (error) {
     log(
       'error',
@@ -130,7 +131,8 @@ async function generateGoogleObject({
   messages,
   schema,
   objectName, // Note: Vercel SDK might use this differently or not at all
-  maxTokens
+  maxTokens,
+  baseUrl
 }) {
   if (!apiKey) {
     throw new Error('Google API key is required.');
@@ -138,23 +140,16 @@ async function generateGoogleObject({
   log('info', `Generating object with Google model: ${modelId}`);
   try {
-    // const google = new GoogleGenerativeAI({ apiKey }); // Incorrect instantiation
-    const googleProvider = createGoogleGenerativeAI({ apiKey }); // Correct instantiation
-    // const model = google.getGenerativeModel({ model: modelId }); // Incorrect model retrieval
-    const model = googleProvider(modelId); // Correct model retrieval
+    const googleProvider = getClient(apiKey, baseUrl);
+    const model = googleProvider(modelId);
     const { object } = await generateObject({
-      model, // Pass the model instance
+      model,
       schema,
       messages,
       temperature,
       maxOutputTokens: maxTokens
-      // Note: 'objectName' or 'mode' might not be directly applicable here
-      // depending on how `@ai-sdk/google` handles `generateObject`.
-      // Check SDK docs if specific tool calling/JSON mode needs explicit setup.
     });
-    return object; // Return the parsed object
+    return object;
   } catch (error) {
     log(
       'error',

View File

@@ -1,16 +1,26 @@
-import { createOpenAI, openai } from '@ai-sdk/openai'; // Using openai provider from Vercel AI SDK
-import { generateText, streamText, generateObject } from 'ai'; // Import necessary functions from 'ai'
+import { createOpenAI } from '@ai-sdk/openai'; // Using openai provider from Vercel AI SDK
+import { generateObject } from 'ai'; // Import necessary functions from 'ai'
 import { log } from '../../scripts/modules/utils.js';
+function getClient(apiKey, baseUrl) {
+  if (!apiKey) {
+    throw new Error('OpenAI API key is required.');
+  }
+  return createOpenAI({
+    apiKey: apiKey,
+    ...(baseUrl && { baseURL: baseUrl })
+  });
+}
 /**
  * Generates text using OpenAI models via Vercel AI SDK.
  *
- * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature.
+ * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature, baseUrl.
  * @returns {Promise<string>} The generated text content.
  * @throws {Error} If API call fails.
  */
 export async function generateOpenAIText(params) {
-  const { apiKey, modelId, messages, maxTokens, temperature } = params;
+  const { apiKey, modelId, messages, maxTokens, temperature, baseUrl } = params;
   log('debug', `generateOpenAIText called with model: ${modelId}`);
   if (!apiKey) {
@@ -23,18 +33,15 @@ export async function generateOpenAIText(params) {
     throw new Error('Invalid or empty messages array provided for OpenAI.');
   }
-  const openaiClient = createOpenAI({ apiKey });
+  const openaiClient = getClient(apiKey, baseUrl);
   try {
     const result = await openaiClient.chat(messages, {
-      // Updated: Use openaiClient.chat directly
       model: modelId,
       max_tokens: maxTokens,
       temperature
     });
-    // Adjust based on actual Vercel SDK response structure for openaiClient.chat
-    // This might need refinement based on testing the SDK's output.
     const textContent = result?.choices?.[0]?.message?.content?.trim();
     if (!textContent) {
@@ -65,12 +72,12 @@ export async function generateOpenAIText(params) {
 /**
  * Streams text using OpenAI models via Vercel AI SDK.
  *
- * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature.
+ * @param {object} params - Parameters including apiKey, modelId, messages, maxTokens, temperature, baseUrl.
  * @returns {Promise<ReadableStream>} A readable stream of text deltas.
  * @throws {Error} If API call fails.
  */
 export async function streamOpenAIText(params) {
-  const { apiKey, modelId, messages, maxTokens, temperature } = params;
+  const { apiKey, modelId, messages, maxTokens, temperature, baseUrl } = params;
   log('debug', `streamOpenAIText called with model: ${modelId}`);
   if (!apiKey) {
@@ -85,12 +92,10 @@ export async function streamOpenAIText(params) {
     );
   }
-  const openaiClient = createOpenAI({ apiKey });
+  const openaiClient = getClient(apiKey, baseUrl);
   try {
-    // Use the streamText function from Vercel AI SDK core
     const stream = await openaiClient.chat.stream(messages, {
-      // Updated: Use openaiClient.chat.stream
       model: modelId,
       max_tokens: maxTokens,
       temperature
@@ -100,7 +105,6 @@ export async function streamOpenAIText(params) {
       'debug',
       `OpenAI streamText initiated successfully for model: ${modelId}`
     );
-    // The Vercel SDK's streamText should directly return the stream object
     return stream;
   } catch (error) {
     log(
@@ -117,7 +121,7 @@ export async function streamOpenAIText(params) {
 /**
  * Generates structured objects using OpenAI models via Vercel AI SDK.
  *
- * @param {object} params - Parameters including apiKey, modelId, messages, schema, objectName, maxTokens, temperature.
+ * @param {object} params - Parameters including apiKey, modelId, messages, schema, objectName, maxTokens, temperature, baseUrl.
  * @returns {Promise<object>} The generated object matching the schema.
  * @throws {Error} If API call fails or object generation fails.
  */
@@ -129,7 +133,8 @@ export async function generateOpenAIObject(params) {
     schema,
     objectName,
     maxTokens,
-    temperature
+    temperature,
+    baseUrl
   } = params;
   log(
     'debug',
@@ -145,10 +150,9 @@ export async function generateOpenAIObject(params) {
   if (!objectName)
     throw new Error('Object name is required for OpenAI object generation.');
-  const openaiClient = createOpenAI({ apiKey });
+  const openaiClient = getClient(apiKey, baseUrl);
   try {
-    // Use the imported generateObject function from 'ai' package
     const result = await generateObject({
       model: openaiClient(modelId),
       schema: schema,

View File

@@ -2,6 +2,14 @@ import { createOpenRouter } from '@openrouter/ai-sdk-provider';
 import { generateText, streamText, generateObject } from 'ai';
 import { log } from '../../scripts/modules/utils.js'; // Assuming utils.js is in scripts/modules
+function getClient(apiKey, baseUrl) {
+  if (!apiKey) throw new Error('OpenRouter API key is required.');
+  return createOpenRouter({
+    apiKey,
+    ...(baseUrl && { baseURL: baseUrl })
+  });
+}
 /**
  * Generates text using an OpenRouter chat model.
  *
@@ -11,6 +19,7 @@ import { log } from '../../scripts/modules/utils.js'; // Assuming utils.js is in
  * @param {Array<object>} params.messages - Array of message objects (system, user, assistant).
  * @param {number} [params.maxTokens] - Maximum tokens to generate.
  * @param {number} [params.temperature] - Sampling temperature.
+ * @param {string} [params.baseUrl] - Base URL for the OpenRouter API.
  * @returns {Promise<string>} The generated text content.
  * @throws {Error} If the API call fails.
  */
@@ -20,6 +29,7 @@ async function generateOpenRouterText({
   messages,
   maxTokens,
   temperature,
+  baseUrl,
   ...rest // Capture any other Vercel AI SDK compatible parameters
 }) {
   if (!apiKey) throw new Error('OpenRouter API key is required.');
@@ -28,7 +38,7 @@ async function generateOpenRouterText({
     throw new Error('Messages array cannot be empty.');
   try {
-    const openrouter = createOpenRouter({ apiKey });
+    const openrouter = getClient(apiKey, baseUrl);
     const model = openrouter.chat(modelId); // Assuming chat model
     const { text } = await generateText({
@@ -58,6 +68,7 @@ async function generateOpenRouterText({
  * @param {Array<object>} params.messages - Array of message objects (system, user, assistant).
  * @param {number} [params.maxTokens] - Maximum tokens to generate.
  * @param {number} [params.temperature] - Sampling temperature.
+ * @param {string} [params.baseUrl] - Base URL for the OpenRouter API.
  * @returns {Promise<ReadableStream<string>>} A readable stream of text deltas.
  * @throws {Error} If the API call fails.
  */
@@ -67,6 +78,7 @@ async function streamOpenRouterText({
   messages,
   maxTokens,
   temperature,
+  baseUrl,
   ...rest
 }) {
   if (!apiKey) throw new Error('OpenRouter API key is required.');
@@ -75,7 +87,7 @@ async function streamOpenRouterText({
     throw new Error('Messages array cannot be empty.');
   try {
-    const openrouter = createOpenRouter({ apiKey });
+    const openrouter = getClient(apiKey, baseUrl);
     const model = openrouter.chat(modelId);
     // Directly return the stream from the Vercel AI SDK function
@@ -108,6 +120,7 @@ async function streamOpenRouterText({
  * @param {number} [params.maxRetries=3] - Max retries for object generation.
  * @param {number} [params.maxTokens] - Maximum tokens.
  * @param {number} [params.temperature] - Temperature.
+ * @param {string} [params.baseUrl] - Base URL for the OpenRouter API.
  * @returns {Promise<object>} The generated object matching the schema.
  * @throws {Error} If the API call fails or validation fails.
  */
@@ -120,6 +133,7 @@ async function generateOpenRouterObject({
   maxRetries = 3,
   maxTokens,
   temperature,
+  baseUrl,
   ...rest
 }) {
   if (!apiKey) throw new Error('OpenRouter API key is required.');
@@ -129,7 +143,7 @@ async function generateOpenRouterObject({
     throw new Error('Messages array cannot be empty.');
   try {
-    const openrouter = createOpenRouter({ apiKey });
+    const openrouter = getClient(apiKey, baseUrl);
     const model = openrouter.chat(modelId);
     const { object } = await generateObject({

View File

@@ -10,13 +10,13 @@ import { log } from '../../scripts/modules/utils.js';
 // --- Client Instantiation ---
 // Similar to Anthropic, this expects the resolved API key to be passed in.
-function getClient(apiKey) {
+function getClient(apiKey, baseUrl) {
   if (!apiKey) {
     throw new Error('Perplexity API key is required.');
   }
-  // Create and return a new instance directly
   return createPerplexity({
-    apiKey: apiKey
+    apiKey: apiKey,
+    ...(baseUrl && { baseURL: baseUrl })
   });
 }
@@ -31,6 +31,7 @@ function getClient(apiKey) {
  * @param {Array<object>} params.messages - The messages array.
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
+ * @param {string} [params.baseUrl] - Base URL for the Perplexity API.
  * @returns {Promise<string>} The generated text content.
  * @throws {Error} If the API call fails.
  */
@@ -39,11 +40,12 @@ export async function generatePerplexityText({
   modelId,
   messages,
   maxTokens,
-  temperature
+  temperature,
+  baseUrl
 }) {
   log('debug', `Generating Perplexity text with model: ${modelId}`);
   try {
-    const client = getClient(apiKey);
+    const client = getClient(apiKey, baseUrl);
     const result = await generateText({
       model: client(modelId),
       messages: messages,
@@ -70,6 +72,7 @@ export async function generatePerplexityText({
  * @param {Array<object>} params.messages - The messages array.
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
+ * @param {string} [params.baseUrl] - Base URL for the Perplexity API.
  * @returns {Promise<object>} The full stream result object from the Vercel AI SDK.
  * @throws {Error} If the API call fails to initiate the stream.
  */
@@ -78,11 +81,12 @@ export async function streamPerplexityText({
   modelId,
   messages,
   maxTokens,
-  temperature
+  temperature,
+  baseUrl
 }) {
   log('debug', `Streaming Perplexity text with model: ${modelId}`);
   try {
-    const client = getClient(apiKey);
+    const client = getClient(apiKey, baseUrl);
     const stream = await streamText({
       model: client(modelId),
       messages: messages,
@@ -112,6 +116,7 @@ export async function streamPerplexityText({
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
  * @param {number} [params.maxRetries] - Max retries for validation/generation.
+ * @param {string} [params.baseUrl] - Base URL for the Perplexity API.
  * @returns {Promise<object>} The generated object matching the schema.
  * @throws {Error} If generation or validation fails or is unsupported.
  */
@@ -123,7 +128,8 @@ export async function generatePerplexityObject({
   objectName = 'generated_object',
   maxTokens,
   temperature,
-  maxRetries = 1 // Lower retries as support might be limited
+  maxRetries = 1,
+  baseUrl
 }) {
   log(
     'debug',
@@ -134,8 +140,7 @@ export async function generatePerplexityObject({
     'generateObject support for Perplexity might be limited or experimental.'
   );
   try {
-    const client = getClient(apiKey);
-    // Attempt using generateObject, but be prepared for potential issues
+    const client = getClient(apiKey, baseUrl);
     const result = await generateObject({
       model: client(modelId),
       schema: schema,

View File

@@ -9,14 +9,13 @@ import { generateText, streamText, generateObject } from 'ai'; // Only import wh
 import { log } from '../../scripts/modules/utils.js'; // Assuming utils is accessible
 // --- Client Instantiation ---
-function getClient(apiKey) {
+function getClient(apiKey, baseUrl) {
   if (!apiKey) {
     throw new Error('xAI API key is required.');
   }
-  // Create and return a new instance directly
   return createXai({
-    apiKey: apiKey
-    // Add baseURL or other options if needed later
+    apiKey: apiKey,
+    ...(baseUrl && { baseURL: baseUrl })
   });
 }
@@ -31,6 +30,7 @@ function getClient(apiKey) {
  * @param {Array<object>} params.messages - The messages array (e.g., [{ role: 'user', content: '...' }]).
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
+ * @param {string} [params.baseUrl] - The base URL for the xAI API.
  * @returns {Promise<string>} The generated text content.
  * @throws {Error} If the API call fails.
  */
@@ -39,13 +39,14 @@ export async function generateXaiText({
   modelId,
   messages,
   maxTokens,
-  temperature
+  temperature,
+  baseUrl
 }) {
   log('debug', `Generating xAI text with model: ${modelId}`);
   try {
-    const client = getClient(apiKey);
+    const client = getClient(apiKey, baseUrl);
     const result = await generateText({
-      model: client(modelId), // Correct model invocation
+      model: client(modelId),
       messages: messages,
       maxTokens: maxTokens,
       temperature: temperature
@@ -70,6 +71,7 @@ export async function generateXaiText({
  * @param {Array<object>} params.messages - The messages array.
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
+ * @param {string} [params.baseUrl] - The base URL for the xAI API.
  * @returns {Promise<object>} The full stream result object from the Vercel AI SDK.
  * @throws {Error} If the API call fails to initiate the stream.
 */
@@ -78,18 +80,19 @@ export async function streamXaiText({
   modelId,
   messages,
   maxTokens,
-  temperature
+  temperature,
+  baseUrl
 }) {
   log('debug', `Streaming xAI text with model: ${modelId}`);
   try {
-    const client = getClient(apiKey);
+    const client = getClient(apiKey, baseUrl);
     const stream = await streamText({
-      model: client(modelId), // Correct model invocation
+      model: client(modelId),
       messages: messages,
       maxTokens: maxTokens,
       temperature: temperature
     });
-    return stream; // Return the full stream object
+    return stream;
   } catch (error) {
     log('error', `xAI streamText failed: ${error.message}`, error.stack);
     throw error;
@@ -110,6 +113,7 @@ export async function streamXaiText({
  * @param {number} [params.maxTokens] - Maximum tokens for the response.
  * @param {number} [params.temperature] - Temperature for generation.
  * @param {number} [params.maxRetries] - Max retries for validation/generation.
+ * @param {string} [params.baseUrl] - The base URL for the xAI API.
  * @returns {Promise<object>} The generated object matching the schema.
  * @throws {Error} If generation or validation fails.
  */
@@ -121,16 +125,17 @@ export async function generateXaiObject({
   objectName = 'generated_xai_object',
   maxTokens,
   temperature,
-  maxRetries = 3
+  maxRetries = 3,
+  baseUrl
 }) {
   log(
-    'warn', // Log warning as this is likely unsupported
+    'warn',
     `Attempting to generate xAI object ('${objectName}') with model: ${modelId}. This may not be supported by the provider.`
   );
   try {
-    const client = getClient(apiKey);
+    const client = getClient(apiKey, baseUrl);
     const result = await generateObject({
-      model: client(modelId), // Correct model invocation
+      model: client(modelId),
       // Note: mode might need adjustment if xAI ever supports object generation differently
       mode: 'tool',
       schema: schema,
@@ -153,6 +158,6 @@ export async function generateXaiObject({
       'error',
       `xAI generateObject ('${objectName}') failed: ${error.message}. (Likely unsupported by provider)`
     );
-    throw error; // Re-throw the error
+    throw error;
   }
 }

View File

@@ -8,6 +8,7 @@ const mockGetResearchModelId = jest.fn();
 const mockGetFallbackProvider = jest.fn();
 const mockGetFallbackModelId = jest.fn();
 const mockGetParametersForRole = jest.fn();
+const mockGetBaseUrlForRole = jest.fn();
 jest.unstable_mockModule('../../scripts/modules/config-manager.js', () => ({
   getMainProvider: mockGetMainProvider,
@@ -16,7 +17,8 @@ jest.unstable_mockModule('../../scripts/modules/config-manager.js', () => ({
   getResearchModelId: mockGetResearchModelId,
   getFallbackProvider: mockGetFallbackProvider,
   getFallbackModelId: mockGetFallbackModelId,
-  getParametersForRole: mockGetParametersForRole
+  getParametersForRole: mockGetParametersForRole,
+  getBaseUrlForRole: mockGetBaseUrlForRole
 }));
 // Mock AI Provider Modules