Feat/add.azure.and.other.providers (#607)
* fix: claude-4 not having the right max_tokens * feat: add bedrock support * chore: fix package-lock.json * fix: rename baseUrl to baseURL * feat: add azure support * fix: final touches of azure integration * feat: add google vertex provider * chore: fix tests and refactor task-manager.test.js * chore: move task 92 to 94
This commit is contained in:
@@ -19,18 +19,41 @@ import {
|
||||
MODEL_MAP,
|
||||
getDebugFlag,
|
||||
getBaseUrlForRole,
|
||||
isApiKeySet
|
||||
isApiKeySet,
|
||||
getOllamaBaseURL,
|
||||
getAzureBaseURL,
|
||||
getVertexProjectId,
|
||||
getVertexLocation
|
||||
} from './config-manager.js';
|
||||
import { log, findProjectRoot, resolveEnvVariable } from './utils.js';
|
||||
|
||||
import * as anthropic from '../../src/ai-providers/anthropic.js';
|
||||
import * as perplexity from '../../src/ai-providers/perplexity.js';
|
||||
import * as google from '../../src/ai-providers/google.js';
|
||||
import * as openai from '../../src/ai-providers/openai.js';
|
||||
import * as xai from '../../src/ai-providers/xai.js';
|
||||
import * as openrouter from '../../src/ai-providers/openrouter.js';
|
||||
import * as ollama from '../../src/ai-providers/ollama.js';
|
||||
// TODO: Import other provider modules when implemented (ollama, etc.)
|
||||
// Import provider classes
|
||||
import {
|
||||
AnthropicAIProvider,
|
||||
PerplexityAIProvider,
|
||||
GoogleAIProvider,
|
||||
OpenAIProvider,
|
||||
XAIProvider,
|
||||
OpenRouterAIProvider,
|
||||
OllamaAIProvider,
|
||||
BedrockAIProvider,
|
||||
AzureProvider,
|
||||
VertexAIProvider
|
||||
} from '../../src/ai-providers/index.js';
|
||||
|
||||
// Create provider instances
|
||||
const PROVIDERS = {
|
||||
anthropic: new AnthropicAIProvider(),
|
||||
perplexity: new PerplexityAIProvider(),
|
||||
google: new GoogleAIProvider(),
|
||||
openai: new OpenAIProvider(),
|
||||
xai: new XAIProvider(),
|
||||
openrouter: new OpenRouterAIProvider(),
|
||||
ollama: new OllamaAIProvider(),
|
||||
bedrock: new BedrockAIProvider(),
|
||||
azure: new AzureProvider(),
|
||||
vertex: new VertexAIProvider()
|
||||
};
|
||||
|
||||
// Helper function to get cost for a specific model
|
||||
function _getCostForModel(providerName, modelId) {
|
||||
@@ -62,51 +85,6 @@ function _getCostForModel(providerName, modelId) {
|
||||
};
|
||||
}
|
||||
|
||||
// --- Provider Function Map ---
|
||||
// Maps provider names (lowercase) to their respective service functions
|
||||
const PROVIDER_FUNCTIONS = {
|
||||
anthropic: {
|
||||
generateText: anthropic.generateAnthropicText,
|
||||
streamText: anthropic.streamAnthropicText,
|
||||
generateObject: anthropic.generateAnthropicObject
|
||||
},
|
||||
perplexity: {
|
||||
generateText: perplexity.generatePerplexityText,
|
||||
streamText: perplexity.streamPerplexityText,
|
||||
generateObject: perplexity.generatePerplexityObject
|
||||
},
|
||||
google: {
|
||||
// Add Google entry
|
||||
generateText: google.generateGoogleText,
|
||||
streamText: google.streamGoogleText,
|
||||
generateObject: google.generateGoogleObject
|
||||
},
|
||||
openai: {
|
||||
// ADD: OpenAI entry
|
||||
generateText: openai.generateOpenAIText,
|
||||
streamText: openai.streamOpenAIText,
|
||||
generateObject: openai.generateOpenAIObject
|
||||
},
|
||||
xai: {
|
||||
// ADD: xAI entry
|
||||
generateText: xai.generateXaiText,
|
||||
streamText: xai.streamXaiText,
|
||||
generateObject: xai.generateXaiObject // Note: Object generation might be unsupported
|
||||
},
|
||||
openrouter: {
|
||||
// ADD: OpenRouter entry
|
||||
generateText: openrouter.generateOpenRouterText,
|
||||
streamText: openrouter.streamOpenRouterText,
|
||||
generateObject: openrouter.generateOpenRouterObject
|
||||
},
|
||||
ollama: {
|
||||
generateText: ollama.generateOllamaText,
|
||||
streamText: ollama.streamOllamaText,
|
||||
generateObject: ollama.generateOllamaObject
|
||||
}
|
||||
// TODO: Add entries for ollama, etc. when implemented
|
||||
};
|
||||
|
||||
// --- Configuration for Retries ---
|
||||
const MAX_RETRIES = 2;
|
||||
const INITIAL_RETRY_DELAY_MS = 1000;
|
||||
@@ -191,7 +169,9 @@ function _resolveApiKey(providerName, session, projectRoot = null) {
|
||||
azure: 'AZURE_OPENAI_API_KEY',
|
||||
openrouter: 'OPENROUTER_API_KEY',
|
||||
xai: 'XAI_API_KEY',
|
||||
ollama: 'OLLAMA_API_KEY'
|
||||
ollama: 'OLLAMA_API_KEY',
|
||||
bedrock: 'AWS_ACCESS_KEY_ID',
|
||||
vertex: 'GOOGLE_API_KEY'
|
||||
};
|
||||
|
||||
const envVarName = keyMap[providerName];
|
||||
@@ -203,12 +183,11 @@ function _resolveApiKey(providerName, session, projectRoot = null) {
|
||||
|
||||
const apiKey = resolveEnvVariable(envVarName, session, projectRoot);
|
||||
|
||||
// Special handling for Ollama - API key is optional
|
||||
if (providerName === 'ollama') {
|
||||
// Special handling for providers that can use alternative auth
|
||||
if (providerName === 'ollama' || providerName === 'bedrock') {
|
||||
return apiKey || null;
|
||||
}
|
||||
|
||||
// For all other providers, API key is required
|
||||
if (!apiKey) {
|
||||
throw new Error(
|
||||
`Required API key ${envVarName} for provider '${providerName}' is not set in environment, session, or .env file.`
|
||||
@@ -229,14 +208,15 @@ function _resolveApiKey(providerName, session, projectRoot = null) {
|
||||
* @throws {Error} If the call fails after all retries.
|
||||
*/
|
||||
async function _attemptProviderCallWithRetries(
|
||||
providerApiFn,
|
||||
provider,
|
||||
serviceType,
|
||||
callParams,
|
||||
providerName,
|
||||
modelId,
|
||||
attemptRole
|
||||
) {
|
||||
let retries = 0;
|
||||
const fnName = providerApiFn.name;
|
||||
const fnName = serviceType;
|
||||
|
||||
while (retries <= MAX_RETRIES) {
|
||||
try {
|
||||
@@ -247,8 +227,8 @@ async function _attemptProviderCallWithRetries(
|
||||
);
|
||||
}
|
||||
|
||||
// Call the specific provider function directly
|
||||
const result = await providerApiFn(callParams);
|
||||
// Call the appropriate method on the provider instance
|
||||
const result = await provider[serviceType](callParams);
|
||||
|
||||
if (getDebugFlag()) {
|
||||
log(
|
||||
@@ -350,9 +330,8 @@ async function _unifiedServiceRunner(serviceType, params) {
|
||||
modelId,
|
||||
apiKey,
|
||||
roleParams,
|
||||
providerFnSet,
|
||||
providerApiFn,
|
||||
baseUrl,
|
||||
provider,
|
||||
baseURL,
|
||||
providerResponse,
|
||||
telemetryData = null;
|
||||
|
||||
@@ -391,7 +370,20 @@ async function _unifiedServiceRunner(serviceType, params) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if API key is set for the current provider and role (excluding 'ollama')
|
||||
// Get provider instance
|
||||
provider = PROVIDERS[providerName?.toLowerCase()];
|
||||
if (!provider) {
|
||||
log(
|
||||
'warn',
|
||||
`Skipping role '${currentRole}': Provider '${providerName}' not supported.`
|
||||
);
|
||||
lastError =
|
||||
lastError ||
|
||||
new Error(`Unsupported provider configured: ${providerName}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check API key if needed
|
||||
if (providerName?.toLowerCase() !== 'ollama') {
|
||||
if (!isApiKeySet(providerName, session, effectiveProjectRoot)) {
|
||||
log(
|
||||
@@ -407,40 +399,70 @@ async function _unifiedServiceRunner(serviceType, params) {
|
||||
}
|
||||
}
|
||||
|
||||
// Get base URL if configured (optional for most providers)
|
||||
baseURL = getBaseUrlForRole(currentRole, effectiveProjectRoot);
|
||||
|
||||
// For Azure, use the global Azure base URL if role-specific URL is not configured
|
||||
if (providerName?.toLowerCase() === 'azure' && !baseURL) {
|
||||
baseURL = getAzureBaseURL(effectiveProjectRoot);
|
||||
log('debug', `Using global Azure base URL: ${baseURL}`);
|
||||
} else if (providerName?.toLowerCase() === 'ollama' && !baseURL) {
|
||||
// For Ollama, use the global Ollama base URL if role-specific URL is not configured
|
||||
baseURL = getOllamaBaseURL(effectiveProjectRoot);
|
||||
log('debug', `Using global Ollama base URL: ${baseURL}`);
|
||||
}
|
||||
|
||||
// Get AI parameters for the current role
|
||||
roleParams = getParametersForRole(currentRole, effectiveProjectRoot);
|
||||
baseUrl = getBaseUrlForRole(currentRole, effectiveProjectRoot);
|
||||
providerFnSet = PROVIDER_FUNCTIONS[providerName?.toLowerCase()];
|
||||
if (!providerFnSet) {
|
||||
log(
|
||||
'warn',
|
||||
`Skipping role '${currentRole}': Provider '${providerName}' not supported or map entry missing.`
|
||||
);
|
||||
lastError =
|
||||
lastError ||
|
||||
new Error(`Unsupported provider configured: ${providerName}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
providerApiFn = providerFnSet[serviceType];
|
||||
if (typeof providerApiFn !== 'function') {
|
||||
log(
|
||||
'warn',
|
||||
`Skipping role '${currentRole}': Service type '${serviceType}' not implemented for provider '${providerName}'.`
|
||||
);
|
||||
lastError =
|
||||
lastError ||
|
||||
new Error(
|
||||
`Service '${serviceType}' not implemented for provider ${providerName}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
apiKey = _resolveApiKey(
|
||||
providerName?.toLowerCase(),
|
||||
session,
|
||||
effectiveProjectRoot
|
||||
);
|
||||
|
||||
// Prepare provider-specific configuration
|
||||
let providerSpecificParams = {};
|
||||
|
||||
// Handle Vertex AI specific configuration
|
||||
if (providerName?.toLowerCase() === 'vertex') {
|
||||
// Get Vertex project ID and location
|
||||
const projectId =
|
||||
getVertexProjectId(effectiveProjectRoot) ||
|
||||
resolveEnvVariable(
|
||||
'VERTEX_PROJECT_ID',
|
||||
session,
|
||||
effectiveProjectRoot
|
||||
);
|
||||
|
||||
const location =
|
||||
getVertexLocation(effectiveProjectRoot) ||
|
||||
resolveEnvVariable(
|
||||
'VERTEX_LOCATION',
|
||||
session,
|
||||
effectiveProjectRoot
|
||||
) ||
|
||||
'us-central1';
|
||||
|
||||
// Get credentials path if available
|
||||
const credentialsPath = resolveEnvVariable(
|
||||
'GOOGLE_APPLICATION_CREDENTIALS',
|
||||
session,
|
||||
effectiveProjectRoot
|
||||
);
|
||||
|
||||
// Add Vertex-specific parameters
|
||||
providerSpecificParams = {
|
||||
projectId,
|
||||
location,
|
||||
...(credentialsPath && { credentials: { credentialsFromEnv: true } })
|
||||
};
|
||||
|
||||
log(
|
||||
'debug',
|
||||
`Using Vertex AI configuration: Project ID=${projectId}, Location=${location}`
|
||||
);
|
||||
}
|
||||
|
||||
const messages = [];
|
||||
if (systemPrompt) {
|
||||
messages.push({ role: 'system', content: systemPrompt });
|
||||
@@ -476,13 +498,15 @@ async function _unifiedServiceRunner(serviceType, params) {
|
||||
maxTokens: roleParams.maxTokens,
|
||||
temperature: roleParams.temperature,
|
||||
messages,
|
||||
baseUrl,
|
||||
...(baseURL && { baseURL }),
|
||||
...(serviceType === 'generateObject' && { schema, objectName }),
|
||||
...providerSpecificParams,
|
||||
...restApiParams
|
||||
};
|
||||
|
||||
providerResponse = await _attemptProviderCallWithRetries(
|
||||
providerApiFn,
|
||||
provider,
|
||||
serviceType,
|
||||
callParams,
|
||||
providerName,
|
||||
modelId,
|
||||
|
||||
Reference in New Issue
Block a user