refactor: improve ClaudeProvider query execution and message handling

- Enhanced the executeQuery method to better handle conversation history and user messages, ensuring compliance with SDK requirements.
- Introduced a default tools array for allowedTools, simplifying the options setup.
- Updated the getAvailableModels method to use type assertions for model tiers and ensured proper return type with TypeScript's satisfies operator.
- Added error handling during query execution to log and propagate errors effectively.
This commit is contained in:
Cody Seibert
2025-12-13 21:53:55 -05:00
parent 6446dd5d3a
commit c21a298e07

View File

@@ -7,7 +7,10 @@
import { query, type Options } from "@anthropic-ai/claude-agent-sdk"; import { query, type Options } from "@anthropic-ai/claude-agent-sdk";
import { BaseProvider } from "./base-provider.js"; import { BaseProvider } from "./base-provider.js";
import { convertHistoryToMessages, normalizeContentBlocks } from "../lib/conversation-utils.js"; import {
convertHistoryToMessages,
normalizeContentBlocks,
} from "../lib/conversation-utils.js";
import type { import type {
ExecuteOptions, ExecuteOptions,
ProviderMessage, ProviderMessage,
@@ -23,7 +26,9 @@ export class ClaudeProvider extends BaseProvider {
/** /**
* Execute a query using Claude Agent SDK * Execute a query using Claude Agent SDK
*/ */
async *executeQuery(options: ExecuteOptions): AsyncGenerator<ProviderMessage> { async *executeQuery(
options: ExecuteOptions
): AsyncGenerator<ProviderMessage> {
const { const {
prompt, prompt,
model, model,
@@ -36,12 +41,7 @@ export class ClaudeProvider extends BaseProvider {
} = options; } = options;
// Build Claude SDK options // Build Claude SDK options
const sdkOptions: Options = { const defaultTools = [
model,
systemPrompt,
maxTurns,
cwd,
allowedTools: allowedTools || [
"Read", "Read",
"Write", "Write",
"Edit", "Edit",
@@ -50,7 +50,15 @@ export class ClaudeProvider extends BaseProvider {
"Bash", "Bash",
"WebSearch", "WebSearch",
"WebFetch", "WebFetch",
], ];
const toolsToUse = allowedTools || defaultTools;
const sdkOptions: Options = {
model,
systemPrompt,
maxTurns,
cwd,
allowedTools: toolsToUse,
permissionMode: "acceptEdits", permissionMode: "acceptEdits",
sandbox: { sandbox: {
enabled: true, enabled: true,
@@ -60,32 +68,68 @@ export class ClaudeProvider extends BaseProvider {
}; };
// Build prompt payload with conversation history // Build prompt payload with conversation history
let promptPayload: string | AsyncGenerator<any, void, unknown>; let promptPayload: string | AsyncGenerator<any, void, unknown> | Array<any>;
if (conversationHistory && conversationHistory.length > 0) { if (conversationHistory && conversationHistory.length > 0) {
// Multi-turn conversation with history // Multi-turn conversation with history
promptPayload = (async function* () { // Convert history to SDK message format
// Yield history messages using utility // Note: When using async generator, SDK only accepts SDKUserMessage (type: 'user')
// So we filter to only include user messages to avoid SDK errors
const historyMessages = convertHistoryToMessages(conversationHistory); const historyMessages = convertHistoryToMessages(conversationHistory);
const hasAssistantMessages = historyMessages.some(
(msg) => msg.type === "assistant"
);
if (hasAssistantMessages) {
// If we have assistant messages, use async generator but filter to only user messages
// This maintains conversation flow while respecting SDK type constraints
promptPayload = (async function* () {
// Filter to only user messages - SDK async generator only accepts SDKUserMessage
const userHistoryMessages = historyMessages.filter(
(msg) => msg.type === "user"
);
for (const msg of userHistoryMessages) {
yield msg;
}
// Yield current prompt
const normalizedPrompt = normalizeContentBlocks(prompt);
const currentPrompt = {
type: "user" as const,
session_id: "",
message: {
role: "user" as const,
content: normalizedPrompt,
},
parent_tool_use_id: null,
};
yield currentPrompt;
})();
} else {
// Only user messages in history - can use async generator normally
promptPayload = (async function* () {
for (const msg of historyMessages) { for (const msg of historyMessages) {
yield msg; yield msg;
} }
// Yield current prompt // Yield current prompt
yield { const normalizedPrompt = normalizeContentBlocks(prompt);
const currentPrompt = {
type: "user" as const, type: "user" as const,
session_id: "", session_id: "",
message: { message: {
role: "user" as const, role: "user" as const,
content: normalizeContentBlocks(prompt), content: normalizedPrompt,
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
}; };
yield currentPrompt;
})(); })();
}
} else if (Array.isArray(prompt)) { } else if (Array.isArray(prompt)) {
// Multi-part prompt (with images) - no history // Multi-part prompt (with images) - no history
promptPayload = (async function* () { promptPayload = (async function* () {
yield { const multiPartPrompt = {
type: "user" as const, type: "user" as const,
session_id: "", session_id: "",
message: { message: {
@@ -94,6 +138,7 @@ export class ClaudeProvider extends BaseProvider {
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
}; };
yield multiPartPrompt;
})(); })();
} else { } else {
// Simple text prompt - no history // Simple text prompt - no history
@@ -101,12 +146,20 @@ export class ClaudeProvider extends BaseProvider {
} }
// Execute via Claude Agent SDK // Execute via Claude Agent SDK
try {
const stream = query({ prompt: promptPayload, options: sdkOptions }); const stream = query({ prompt: promptPayload, options: sdkOptions });
// Stream messages directly - they're already in the correct format // Stream messages directly - they're already in the correct format
for await (const msg of stream) { for await (const msg of stream) {
yield msg as ProviderMessage; yield msg as ProviderMessage;
} }
} catch (error) {
console.error(
"[ClaudeProvider] executeQuery() error during execution:",
error
);
throw error;
}
} }
/** /**
@@ -114,22 +167,25 @@ export class ClaudeProvider extends BaseProvider {
*/ */
async detectInstallation(): Promise<InstallationStatus> { async detectInstallation(): Promise<InstallationStatus> {
// Claude SDK is always available since it's a dependency // Claude SDK is always available since it's a dependency
const hasApiKey = const hasAnthropicKey = !!process.env.ANTHROPIC_API_KEY;
!!process.env.ANTHROPIC_API_KEY || !!process.env.CLAUDE_CODE_OAUTH_TOKEN; const hasOAuthToken = !!process.env.CLAUDE_CODE_OAUTH_TOKEN;
const hasApiKey = hasAnthropicKey || hasOAuthToken;
return { const status: InstallationStatus = {
installed: true, installed: true,
method: "sdk", method: "sdk",
hasApiKey, hasApiKey,
authenticated: hasApiKey, authenticated: hasApiKey,
}; };
return status;
} }
/** /**
* Get available Claude models * Get available Claude models
*/ */
getAvailableModels(): ModelDefinition[] { getAvailableModels(): ModelDefinition[] {
return [ const models = [
{ {
id: "claude-opus-4-5-20251101", id: "claude-opus-4-5-20251101",
name: "Claude Opus 4.5", name: "Claude Opus 4.5",
@@ -140,7 +196,7 @@ export class ClaudeProvider extends BaseProvider {
maxOutputTokens: 16000, maxOutputTokens: 16000,
supportsVision: true, supportsVision: true,
supportsTools: true, supportsTools: true,
tier: "premium", tier: "premium" as const,
default: true, default: true,
}, },
{ {
@@ -153,7 +209,7 @@ export class ClaudeProvider extends BaseProvider {
maxOutputTokens: 16000, maxOutputTokens: 16000,
supportsVision: true, supportsVision: true,
supportsTools: true, supportsTools: true,
tier: "standard", tier: "standard" as const,
}, },
{ {
id: "claude-3-5-sonnet-20241022", id: "claude-3-5-sonnet-20241022",
@@ -165,7 +221,7 @@ export class ClaudeProvider extends BaseProvider {
maxOutputTokens: 8000, maxOutputTokens: 8000,
supportsVision: true, supportsVision: true,
supportsTools: true, supportsTools: true,
tier: "standard", tier: "standard" as const,
}, },
{ {
id: "claude-3-5-haiku-20241022", id: "claude-3-5-haiku-20241022",
@@ -177,9 +233,10 @@ export class ClaudeProvider extends BaseProvider {
maxOutputTokens: 8000, maxOutputTokens: 8000,
supportsVision: true, supportsVision: true,
supportsTools: true, supportsTools: true,
tier: "basic", tier: "basic" as const,
}, },
]; ] satisfies ModelDefinition[];
return models;
} }
/** /**