Feat/add.azure.and.other.providers (#607)

* fix: claude-4 not having the right max_tokens

* feat: add bedrock support

* chore: fix package-lock.json

* fix: rename baseUrl to baseURL

* feat: add azure support

* fix: final touches of azure integration

* feat: add google vertex provider

* chore: fix tests and refactor task-manager.test.js

* chore: move task 92 to 94
This commit is contained in:
Ralph Khreish
2025-05-28 00:42:31 +02:00
committed by GitHub
parent 80735f9e60
commit 6a8a68e1a3
49 changed files with 12785 additions and 5015 deletions

View File

@@ -19,18 +19,41 @@ import {
MODEL_MAP,
getDebugFlag,
getBaseUrlForRole,
isApiKeySet
isApiKeySet,
getOllamaBaseURL,
getAzureBaseURL,
getVertexProjectId,
getVertexLocation
} from './config-manager.js';
import { log, findProjectRoot, resolveEnvVariable } from './utils.js';
import * as anthropic from '../../src/ai-providers/anthropic.js';
import * as perplexity from '../../src/ai-providers/perplexity.js';
import * as google from '../../src/ai-providers/google.js';
import * as openai from '../../src/ai-providers/openai.js';
import * as xai from '../../src/ai-providers/xai.js';
import * as openrouter from '../../src/ai-providers/openrouter.js';
import * as ollama from '../../src/ai-providers/ollama.js';
// TODO: Import other provider modules when implemented (ollama, etc.)
// Import provider classes
import {
AnthropicAIProvider,
PerplexityAIProvider,
GoogleAIProvider,
OpenAIProvider,
XAIProvider,
OpenRouterAIProvider,
OllamaAIProvider,
BedrockAIProvider,
AzureProvider,
VertexAIProvider
} from '../../src/ai-providers/index.js';
// Create provider instances.
// Singleton registry of provider adapters, keyed by lowercase provider name.
// _unifiedServiceRunner resolves instances via PROVIDERS[providerName?.toLowerCase()],
// and _attemptProviderCallWithRetries dispatches by calling provider[serviceType](callParams)
// on the resolved instance (serviceType is e.g. 'generateText' / 'streamText' / 'generateObject').
// NOTE(review): instances are constructed eagerly at module load — assumes the provider
// constructors perform no I/O or credential checks; confirm in src/ai-providers/.
const PROVIDERS = {
  anthropic: new AnthropicAIProvider(),
  perplexity: new PerplexityAIProvider(),
  google: new GoogleAIProvider(),
  openai: new OpenAIProvider(),
  xai: new XAIProvider(),
  openrouter: new OpenRouterAIProvider(),
  ollama: new OllamaAIProvider(),
  bedrock: new BedrockAIProvider(),
  azure: new AzureProvider(),
  vertex: new VertexAIProvider()
};
// Helper function to get cost for a specific model
function _getCostForModel(providerName, modelId) {
@@ -62,51 +85,6 @@ function _getCostForModel(providerName, modelId) {
};
}
// --- Provider Function Map ---
// Maps provider names (lowercase) to their respective service functions.
// Each entry exposes the same three service types (generateText, streamText,
// generateObject) so callers can dispatch uniformly by provider name + service type.
const PROVIDER_FUNCTIONS = {
  anthropic: {
    generateText: anthropic.generateAnthropicText,
    streamText: anthropic.streamAnthropicText,
    generateObject: anthropic.generateAnthropicObject
  },
  perplexity: {
    generateText: perplexity.generatePerplexityText,
    streamText: perplexity.streamPerplexityText,
    generateObject: perplexity.generatePerplexityObject
  },
  google: {
    // Google (Gemini via AI SDK) entry
    generateText: google.generateGoogleText,
    streamText: google.streamGoogleText,
    generateObject: google.generateGoogleObject
  },
  openai: {
    // OpenAI entry
    generateText: openai.generateOpenAIText,
    streamText: openai.streamOpenAIText,
    generateObject: openai.generateOpenAIObject
  },
  xai: {
    // xAI entry
    generateText: xai.generateXaiText,
    streamText: xai.streamXaiText,
    generateObject: xai.generateXaiObject // Note: Object generation might be unsupported
  },
  openrouter: {
    // OpenRouter entry
    generateText: openrouter.generateOpenRouterText,
    streamText: openrouter.streamOpenRouterText,
    generateObject: openrouter.generateOpenRouterObject
  },
  ollama: {
    generateText: ollama.generateOllamaText,
    streamText: ollama.streamOllamaText,
    generateObject: ollama.generateOllamaObject
  }
  // TODO: Add entries for azure, bedrock, and vertex when implemented
  // (they already appear in _resolveApiKey's keyMap but have no service functions here).
};
// --- Configuration for Retries ---
// Number of retries after the initial attempt; the retry loop runs
// `while (retries <= MAX_RETRIES)`, so a value of 2 allows up to 3 total calls.
const MAX_RETRIES = 2;
// Delay (ms) before the first retry. NOTE(review): the name suggests an
// exponential/backoff schedule — confirm how the retry loop scales this value.
const INITIAL_RETRY_DELAY_MS = 1000;
@@ -191,7 +169,9 @@ function _resolveApiKey(providerName, session, projectRoot = null) {
azure: 'AZURE_OPENAI_API_KEY',
openrouter: 'OPENROUTER_API_KEY',
xai: 'XAI_API_KEY',
ollama: 'OLLAMA_API_KEY'
ollama: 'OLLAMA_API_KEY',
bedrock: 'AWS_ACCESS_KEY_ID',
vertex: 'GOOGLE_API_KEY'
};
const envVarName = keyMap[providerName];
@@ -203,12 +183,11 @@ function _resolveApiKey(providerName, session, projectRoot = null) {
const apiKey = resolveEnvVariable(envVarName, session, projectRoot);
// Special handling for Ollama - API key is optional
if (providerName === 'ollama') {
// Special handling for providers that can use alternative auth
if (providerName === 'ollama' || providerName === 'bedrock') {
return apiKey || null;
}
// For all other providers, API key is required
if (!apiKey) {
throw new Error(
`Required API key ${envVarName} for provider '${providerName}' is not set in environment, session, or .env file.`
@@ -229,14 +208,15 @@ function _resolveApiKey(providerName, session, projectRoot = null) {
* @throws {Error} If the call fails after all retries.
*/
async function _attemptProviderCallWithRetries(
providerApiFn,
provider,
serviceType,
callParams,
providerName,
modelId,
attemptRole
) {
let retries = 0;
const fnName = providerApiFn.name;
const fnName = serviceType;
while (retries <= MAX_RETRIES) {
try {
@@ -247,8 +227,8 @@ async function _attemptProviderCallWithRetries(
);
}
// Call the specific provider function directly
const result = await providerApiFn(callParams);
// Call the appropriate method on the provider instance
const result = await provider[serviceType](callParams);
if (getDebugFlag()) {
log(
@@ -350,9 +330,8 @@ async function _unifiedServiceRunner(serviceType, params) {
modelId,
apiKey,
roleParams,
providerFnSet,
providerApiFn,
baseUrl,
provider,
baseURL,
providerResponse,
telemetryData = null;
@@ -391,7 +370,20 @@ async function _unifiedServiceRunner(serviceType, params) {
continue;
}
// Check if API key is set for the current provider and role (excluding 'ollama')
// Get provider instance
provider = PROVIDERS[providerName?.toLowerCase()];
if (!provider) {
log(
'warn',
`Skipping role '${currentRole}': Provider '${providerName}' not supported.`
);
lastError =
lastError ||
new Error(`Unsupported provider configured: ${providerName}`);
continue;
}
// Check API key if needed
if (providerName?.toLowerCase() !== 'ollama') {
if (!isApiKeySet(providerName, session, effectiveProjectRoot)) {
log(
@@ -407,40 +399,70 @@ async function _unifiedServiceRunner(serviceType, params) {
}
}
// Get base URL if configured (optional for most providers)
baseURL = getBaseUrlForRole(currentRole, effectiveProjectRoot);
// For Azure, use the global Azure base URL if role-specific URL is not configured
if (providerName?.toLowerCase() === 'azure' && !baseURL) {
baseURL = getAzureBaseURL(effectiveProjectRoot);
log('debug', `Using global Azure base URL: ${baseURL}`);
} else if (providerName?.toLowerCase() === 'ollama' && !baseURL) {
// For Ollama, use the global Ollama base URL if role-specific URL is not configured
baseURL = getOllamaBaseURL(effectiveProjectRoot);
log('debug', `Using global Ollama base URL: ${baseURL}`);
}
// Get AI parameters for the current role
roleParams = getParametersForRole(currentRole, effectiveProjectRoot);
baseUrl = getBaseUrlForRole(currentRole, effectiveProjectRoot);
providerFnSet = PROVIDER_FUNCTIONS[providerName?.toLowerCase()];
if (!providerFnSet) {
log(
'warn',
`Skipping role '${currentRole}': Provider '${providerName}' not supported or map entry missing.`
);
lastError =
lastError ||
new Error(`Unsupported provider configured: ${providerName}`);
continue;
}
providerApiFn = providerFnSet[serviceType];
if (typeof providerApiFn !== 'function') {
log(
'warn',
`Skipping role '${currentRole}': Service type '${serviceType}' not implemented for provider '${providerName}'.`
);
lastError =
lastError ||
new Error(
`Service '${serviceType}' not implemented for provider ${providerName}`
);
continue;
}
apiKey = _resolveApiKey(
providerName?.toLowerCase(),
session,
effectiveProjectRoot
);
// Prepare provider-specific configuration
let providerSpecificParams = {};
// Handle Vertex AI specific configuration
if (providerName?.toLowerCase() === 'vertex') {
// Get Vertex project ID and location
const projectId =
getVertexProjectId(effectiveProjectRoot) ||
resolveEnvVariable(
'VERTEX_PROJECT_ID',
session,
effectiveProjectRoot
);
const location =
getVertexLocation(effectiveProjectRoot) ||
resolveEnvVariable(
'VERTEX_LOCATION',
session,
effectiveProjectRoot
) ||
'us-central1';
// Get credentials path if available
const credentialsPath = resolveEnvVariable(
'GOOGLE_APPLICATION_CREDENTIALS',
session,
effectiveProjectRoot
);
// Add Vertex-specific parameters
providerSpecificParams = {
projectId,
location,
...(credentialsPath && { credentials: { credentialsFromEnv: true } })
};
log(
'debug',
`Using Vertex AI configuration: Project ID=${projectId}, Location=${location}`
);
}
const messages = [];
if (systemPrompt) {
messages.push({ role: 'system', content: systemPrompt });
@@ -476,13 +498,15 @@ async function _unifiedServiceRunner(serviceType, params) {
maxTokens: roleParams.maxTokens,
temperature: roleParams.temperature,
messages,
baseUrl,
...(baseURL && { baseURL }),
...(serviceType === 'generateObject' && { schema, objectName }),
...providerSpecificParams,
...restApiParams
};
providerResponse = await _attemptProviderCallWithRetries(
providerApiFn,
provider,
serviceType,
callParams,
providerName,
modelId,