From cba0536c459208bddacfc2e559956dc95ebb43f2 Mon Sep 17 00:00:00 2001 From: "jinhui.li" Date: Mon, 23 Jun 2025 06:05:58 +0800 Subject: [PATCH] Refactor plugin --- README.md | 83 ++++++++---- plugins/gemini.js | 33 +++++ plugins/notebook-tools-filter.js | 17 ++- plugins/toolcall-improvement.js | 16 ++- src/index.ts | 27 +++- src/middlewares/formatRequest.ts | 40 ++---- src/middlewares/plugin.ts | 106 +++++++++++++++ src/middlewares/rewriteBody.ts | 45 ------- src/middlewares/router.ts | 6 +- src/utils/index.ts | 5 + src/utils/log.ts | 1 + src/utils/stream.ts | 215 ++++++++++++++++++++++++------- 12 files changed, 432 insertions(+), 162 deletions(-) create mode 100644 plugins/gemini.js create mode 100644 src/middlewares/plugin.ts delete mode 100644 src/middlewares/rewriteBody.ts diff --git a/README.md b/README.md index 857a72f..0bbe3bf 100644 --- a/README.md +++ b/README.md @@ -90,36 +90,75 @@ ccr code - [ ] More detailed logs ## Plugins -You can modify or enhance Claude Code’s functionality by installing plugins. The mechanism works by using middleware to modify request parameters — this allows you to rewrite prompts or add/remove tools. -To use a plugin, place it in the ~/.claude-code-router/plugins/ directory and specify the plugin name in config.js using the `usePlugins` option.like this +You can modify or enhance Claude Code’s functionality by installing plugins. + +### Plugin Mechanism + +Plugins are loaded from the `~/.claude-code-router/plugins/` directory. Each plugin is a JavaScript file that exports functions corresponding to specific "hooks" in the request lifecycle. The system overrides Node.js's module loading to allow plugins to import a special `claude-code-router` module, providing access to utilities like `streamOpenAIResponse`, `log`, and `createClient`. + +### Plugin Hooks + +Plugins can implement various hooks to modify behavior at different stages: + +- `beforeRouter`: Executed before routing. +- `afterRouter`: Executed after routing. 
+- `beforeTransformRequest`: Executed before transforming the request. +- `afterTransformRequest`: Executed after transforming the request. +- `beforeTransformResponse`: Executed before transforming the response. +- `afterTransformResponse`: Executed after transforming the response. + +### Enabling Plugins + +To use a plugin: + +1. Place your plugin's JavaScript file (e.g., `my-plugin.js`) in the `~/.claude-code-router/plugins/` directory. +2. Specify the plugin name (without the `.js` extension) in your `~/.claude-code-router/config.json` file using the `usePlugins` option: + ```json -// ~/.claud-code-router/config.json +// ~/.claude-code-router/config.json { ..., - "usePlugins": ["notebook-tools-filter", "toolcall-improvement"] + "usePlugins": ["my-plugin", "another-plugin"], + + // or use plugins for a specific provider + "Providers": [ + { + "name": "gemini", + "api_base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", + "api_key": "xxx", + "models": ["gemini-2.5-flash"], + "usePlugins": ["gemini"] + } + ] } ``` +### Available Plugins + Currently, the following plugins are available: +- **notebook-tools-filter** + This plugin filters out tool calls related to Jupyter notebooks (.ipynb files). You can use it if your work does not involve Jupyter. -- **notebook-tools-filter** -This plugin filters out tool calls related to Jupyter notebooks (.ipynb files). You can use it if your work does not involve Jupyter. +- **gemini** + Add support for the Google Gemini API endpoint: `https://generativelanguage.googleapis.com/v1beta/openai/`. +- **toolcall-improvement** + If your LLM doesn’t handle tool usage well (for example, always returning code as plain text instead of modifying files — such as with deepseek-v3), you can use this plugin. + This plugin simply adds the following system prompt. If you have a better prompt, you can modify it. 
-- **toolcall-improvement** -If your LLM doesn’t handle tool usage well (for example, always returning code as plain text instead of modifying files — such as with deepseek-v3), you can use this plugin. -This plugin simply adds the following system prompt. If you have a better prompt, you can modify it. ```markdown -## **Important Instruction:** +## **Important Instruction:** + You must use tools as frequently and accurately as possible to help the user solve their problem. Prioritize tool usage whenever it can enhance accuracy, efficiency, or the quality of the response. ``` - ## Github Actions + You just need to install `Claude Code Actions` in your repository according to the [official documentation](https://docs.anthropic.com/en/docs/claude-code/github-actions). For `ANTHROPIC_API_KEY`, you can use any string. Then, modify your `.github/workflows/claude.yaml` file to include claude-code-router, like this: + ```yaml name: Claude Code @@ -151,7 +190,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 1 - + - name: Prepare Environment run: | curl -fsSL https://bun.sh/install | bash @@ -165,7 +204,7 @@ jobs: } EOF shell: bash - + - name: Start Claude Code Router run: | nohup ~/.bun/bin/bunx @musistudio/claude-code-router@1.0.8 start & @@ -179,6 +218,7 @@ jobs: with: anthropic_api_key: "test" ``` + You can modify the contents of `$HOME/.claude-code-router/config.json` as needed. GitHub Actions support allows you to trigger Claude Code at specific times, which opens up some interesting possibilities. @@ -190,7 +230,6 @@ For example, between 00:30 and 08:30 Beijing Time, using the official DeepSeek A So maybe in the future, I’ll describe detailed tasks for Claude Code ahead of time and let it run during these discounted hours to reduce costs? - ## Some tips: Now you can use deepseek-v3 models directly without using any plugins. 
@@ -209,7 +248,7 @@ Some interesting points: Based on my testing, including a lot of context informa ## Buy me a coffee -If you find this project helpful, you can choose to sponsor the author with a cup of coffee. Please provide your GitHub information so I can add you to the sponsor list below. +If you find this project helpful, you can choose to sponsor the author with a cup of coffee. Please provide your GitHub information so I can add you to the sponsor list below. [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/F1F31GN2GM) @@ -224,10 +263,10 @@ If you find this project helpful, you can choose to sponsor the author with a cu Thanks to the following sponsors: -@Simon Leischnig (If you see this, feel free to contact me and I can update it with your GitHub information) -[@duanshuaimin](https://github.com/duanshuaimin) -[@vrgitadmin](https://github.com/vrgitadmin) -@*o (可通过主页邮箱联系我修改github用户名) -@**聪 (可通过主页邮箱联系我修改github用户名) -@*说 (可通过主页邮箱联系我修改github用户名) -@*更 (可通过主页邮箱联系我修改github用户名) +@Simon Leischnig (If you see this, feel free to contact me and I can update it with your GitHub information) +[@duanshuaimin](https://github.com/duanshuaimin) +[@vrgitadmin](https://github.com/vrgitadmin) +@*o (可通过主页邮箱联系我修改 github 用户名) +@\*\*聪 (可通过主页邮箱联系我修改 github 用户名) +@*说 (可通过主页邮箱联系我修改 github 用户名) +@\*更 (可通过主页邮箱联系我修改 github 用户名) diff --git a/plugins/gemini.js b/plugins/gemini.js new file mode 100644 index 0000000..71bfc50 --- /dev/null +++ b/plugins/gemini.js @@ -0,0 +1,33 @@ +module.exports = { + afterTransformRequest(req, res) { + if (Array.isArray(req.body.tools)) { + // rewrite tools definition + req.body.tools.forEach((tool) => { + if (tool.function.name === "BatchTool") { + // HACK: Gemini does not support objects with empty properties + tool.function.parameters.properties.invocations.items.properties.input.type = + "number"; + return; + } + Object.keys(tool.function.parameters.properties).forEach((key) => { + const prop = 
tool.function.parameters.properties[key]; + if ( + prop.type === "string" && + !["enum", "date-time"].includes(prop.format) + ) { + delete prop.format; + } + }); + }); + } + if (req.body?.messages?.length) { + req.body.messages.forEach((message) => { + if (message.content === null) { + if (message.tool_calls) { + message.content = JSON.stringify(message.tool_calls); + } + } + }); + } + }, +}; diff --git a/plugins/notebook-tools-filter.js b/plugins/notebook-tools-filter.js index 64ab528..0a88754 100644 --- a/plugins/notebook-tools-filter.js +++ b/plugins/notebook-tools-filter.js @@ -1,7 +1,12 @@ -module.exports = async function handle(req, res) { - if (req?.body?.tools?.length) { - req.body.tools = req.body.tools.filter( - (tool) => !["NotebookRead", "NotebookEdit", "mcp__ide__executeCode"].includes(tool.name) - ); - } +module.exports = { + beforeRouter(req, res) { + if (req?.body?.tools?.length) { + req.body.tools = req.body.tools.filter( + (tool) => + !["NotebookRead", "NotebookEdit", "mcp__ide__executeCode"].includes( + tool.name + ) + ); + } + }, }; diff --git a/plugins/toolcall-improvement.js b/plugins/toolcall-improvement.js index 27203a5..1502304 100644 --- a/plugins/toolcall-improvement.js +++ b/plugins/toolcall-improvement.js @@ -1,8 +1,10 @@ -module.exports = async function handle(req, res) { - if (req?.body?.tools?.length) { - req.body.system.push({ - type: "text", - text: `## **Important Instruction:** \nYou must use tools as frequently and accurately as possible to help the user solve their problem.\nPrioritize tool usage whenever it can enhance accuracy, efficiency, or the quality of the response.` - }) - } +module.exports = { + afterTransformRequest(req, res) { + if (req?.body?.tools?.length) { + req.body.messages.push({ + role: "system", + content: `## **Important Instruction:** \nYou must use tools as frequently and accurately as possible to help the user solve their problem.\nPrioritize tool usage whenever it can enhance accuracy, efficiency, or 
the quality of the response. `, + }); + } + }, }; diff --git a/src/index.ts b/src/index.ts index 40cfdae..01b15e0 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,7 +3,6 @@ import { writeFile } from "fs/promises"; import { getOpenAICommonOptions, initConfig, initDir } from "./utils"; import { createServer } from "./server"; import { formatRequest } from "./middlewares/formatRequest"; -import { rewriteBody } from "./middlewares/rewriteBody"; import { router } from "./middlewares/router"; import OpenAI from "openai"; import { streamOpenAIResponse } from "./utils/stream"; @@ -14,6 +13,11 @@ import { } from "./utils/processCheck"; import { LRUCache } from "lru-cache"; import { log } from "./utils/log"; +import { + loadPlugins, + PLUGINS, + usePluginMiddleware, +} from "./middlewares/plugin"; async function initializeClaudeConfig() { const homeDir = process.env.HOME; @@ -44,6 +48,7 @@ interface ModelProvider { api_base_url: string; api_key: string; models: string[]; + usePlugins?: string[]; } async function run(options: RunOptions = {}) { @@ -56,6 +61,7 @@ async function run(options: RunOptions = {}) { await initializeClaudeConfig(); await initDir(); const config = await initConfig(); + await loadPlugins(config.usePlugins || []); const Providers = new Map(); const providerCache = new LRUCache({ @@ -63,7 +69,7 @@ async function run(options: RunOptions = {}) { ttl: 2 * 60 * 60 * 1000, }); - function getProviderInstance(providerName: string): OpenAI { + async function getProviderInstance(providerName: string): Promise { const provider: ModelProvider | undefined = Providers.get(providerName); if (provider === undefined) { throw new Error(`Provider ${providerName} not found`); @@ -77,6 +83,10 @@ async function run(options: RunOptions = {}) { }); providerCache.set(provider.name, openai); } + const plugins = provider.usePlugins || []; + if (plugins.length > 0) { + await loadPlugins(plugins.map((name) => `${providerName},${name}`)); + } return openai; } @@ -130,7 +140,7 @@ 
async function run(options: RunOptions = {}) { req.config = config; next(); }); - server.useMiddleware(rewriteBody); + server.useMiddleware(usePluginMiddleware("beforeRouter")); if ( config.Router?.background && config.Router?.think && @@ -144,15 +154,22 @@ async function run(options: RunOptions = {}) { next(); }); } + server.useMiddleware(usePluginMiddleware("afterRouter")); + server.useMiddleware(usePluginMiddleware("beforeTransformRequest")); server.useMiddleware(formatRequest); + server.useMiddleware(usePluginMiddleware("afterTransformRequest")); server.app.post("/v1/messages", async (req, res) => { try { - const provider = getProviderInstance(req.provider || "default"); + const provider = await getProviderInstance(req.provider || "default"); + log("final request body:", req.body); const completion: any = await provider.chat.completions.create(req.body); - await streamOpenAIResponse(res, completion, req.body.model, req.body); + await streamOpenAIResponse(req, res, completion); } catch (e) { log("Error in OpenAI API call:", e); + res.status(500).json({ + error: e.message, + }); } }); server.start(); diff --git a/src/middlewares/formatRequest.ts b/src/middlewares/formatRequest.ts index 236204e..6bc1e9b 100644 --- a/src/middlewares/formatRequest.ts +++ b/src/middlewares/formatRequest.ts @@ -1,7 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages"; import OpenAI from "openai"; -import { streamOpenAIResponse } from "../utils/stream"; import { log } from "../utils/log"; export const formatRequest = async ( @@ -19,7 +18,7 @@ export const formatRequest = async ( tools, stream, }: MessageCreateParamsBase = req.body; - log("formatRequest: ", req.body); + log("beforeTransformRequest: ", req.body); try { // @ts-ignore const openAIMessages = Array.isArray(messages) @@ -50,6 +49,7 @@ export const formatRequest = async ( anthropicMessage.content.forEach((contentPart) => { if 
(contentPart.type === "text") { + if (contentPart.text.includes("(no content)")) return; textContent += (typeof contentPart.text === "string" ? contentPart.text @@ -112,17 +112,18 @@ export const formatRequest = async ( }); const trimmedUserText = userTextMessageContent.trim(); + // @ts-ignore + openAiMessagesFromThisAnthropicMessage.push( + // @ts-ignore + ...subsequentToolMessages + ); + if (trimmedUserText.length > 0) { openAiMessagesFromThisAnthropicMessage.push({ role: "user", content: trimmedUserText, }); } - // @ts-ignore - openAiMessagesFromThisAnthropicMessage.push( - // @ts-ignore - ...subsequentToolMessages - ); } else { // Fallback for other roles (e.g. system, or custom roles if they were to appear here with array content) // This will combine all text parts into a single message for that role. @@ -180,30 +181,9 @@ export const formatRequest = async ( res.setHeader("Cache-Control", "no-cache"); res.setHeader("Connection", "keep-alive"); req.body = data; - console.log(JSON.stringify(data.messages, null, 2)); + log("afterTransformRequest: ", req.body); } catch (error) { - console.error("Error in request processing:", error); - const errorCompletion: AsyncIterable = - { - async *[Symbol.asyncIterator]() { - yield { - id: `error_${Date.now()}`, - created: Math.floor(Date.now() / 1000), - model, - object: "chat.completion.chunk", - choices: [ - { - index: 0, - delta: { - content: `Error: ${(error as Error).message}`, - }, - finish_reason: "stop", - }, - ], - }; - }, - }; - await streamOpenAIResponse(res, errorCompletion, model, req.body); + log("Error in TransformRequest:", error); } next(); }; diff --git a/src/middlewares/plugin.ts b/src/middlewares/plugin.ts new file mode 100644 index 0000000..d0be4da --- /dev/null +++ b/src/middlewares/plugin.ts @@ -0,0 +1,106 @@ +import Module from "node:module"; +import { streamOpenAIResponse } from "../utils/stream"; +import { log } from "../utils/log"; +import { PLUGINS_DIR } from "../constants"; +import path from 
"node:path"; +import { access } from "node:fs/promises"; +import { OpenAI } from "openai"; +import { createClient } from "../utils"; +import { Response } from "express"; + +// @ts-ignore +const originalLoad = Module._load; +// @ts-ignore +Module._load = function (request, parent, isMain) { + if (request === "claude-code-router") { + return { + streamOpenAIResponse, + log, + OpenAI, + createClient, + }; + } + return originalLoad.call(this, request, parent, isMain); +}; + +export type PluginHook = + | "beforeRouter" + | "afterRouter" + | "beforeTransformRequest" + | "afterTransformRequest" + | "beforeTransformResponse" + | "afterTransformResponse"; + +export interface Plugin { + beforeRouter?: (req: any, res: Response) => Promise; + afterRouter?: (req: any, res: Response) => Promise; + + beforeTransformRequest?: (req: any, res: Response) => Promise; + afterTransformRequest?: (req: any, res: Response) => Promise; + + beforeTransformResponse?: ( + req: any, + res: Response, + data?: { completion: any } + ) => Promise; + afterTransformResponse?: ( + req: any, + res: Response, + data?: { completion: any; transformedCompletion: any } + ) => Promise; +} + +export const PLUGINS = new Map(); + +const loadPlugin = async (pluginName: string) => { + const filePath = pluginName.split(",").pop(); + const pluginPath = path.join(PLUGINS_DIR, `${filePath}.js`); + try { + await access(pluginPath); + const plugin = require(pluginPath); + if ( + [ + "beforeRouter", + "afterRouter", + "beforeTransformRequest", + "afterTransformRequest", + "beforeTransformResponse", + "afterTransformResponse", + ].some((key) => key in plugin) + ) { + PLUGINS.set(pluginName, plugin); + log(`Plugin ${pluginName} loaded successfully.`); + } else { + throw new Error(`Plugin ${pluginName} does not export a function.`); + } + } catch (e) { + console.error(`Failed to load plugin ${pluginName}:`, e); + throw e; + } +}; + +export const loadPlugins = async (pluginNames: string[]) => { + console.log("Loading 
plugins:", pluginNames); + for (const file of pluginNames) { + await loadPlugin(file); + } +}; + +export const usePluginMiddleware = (type: PluginHook) => { + return async (req: any, res: Response, next: any) => { + for (const [name, plugin] of PLUGINS.entries()) { + if (name.includes(",") && !name.startsWith(`${req.provider},`)) { + continue; + } + if (plugin[type]) { + try { + await plugin[type](req, res); + log(`Plugin ${name} executed hook: ${type}`); + } catch (error) { + log(`Error in plugin ${name} during hook ${type}:`, error); + } + } + } + next(); + }; +}; diff --git a/src/middlewares/rewriteBody.ts b/src/middlewares/rewriteBody.ts deleted file mode 100644 index f64a0e4..0000000 --- a/src/middlewares/rewriteBody.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { Request, Response, NextFunction } from "express"; -import Module from "node:module"; -import { streamOpenAIResponse } from "../utils/stream"; -import { log } from "../utils/log"; -import { PLUGINS_DIR } from "../constants"; -import path from "node:path"; -import { access } from "node:fs/promises"; -import { OpenAI } from "openai"; -import { createClient } from "../utils"; - -// @ts-ignore -const originalLoad = Module._load; -// @ts-ignore -Module._load = function (request, parent, isMain) { - if (request === "claude-code-router") { - return { - streamOpenAIResponse, - log, - OpenAI, - createClient, - }; - } - return originalLoad.call(this, request, parent, isMain); -}; - -export const rewriteBody = async ( - req: Request, - res: Response, - next: NextFunction -) => { - if (!req.config.usePlugins) { - return next(); - } - for (const plugin of req.config.usePlugins) { - const pluginPath = path.join(PLUGINS_DIR, `${plugin.trim()}.js`); - try { - await access(pluginPath); - const rewritePlugin = require(pluginPath); - await rewritePlugin(req, res); - } catch (e) { - console.error(e); - } - } - next(); -}; diff --git a/src/middlewares/router.ts b/src/middlewares/router.ts index d2883e9..9db5f87 100644 --- 
a/src/middlewares/router.ts +++ b/src/middlewares/router.ts @@ -41,9 +41,11 @@ const getUseModel = (req: Request, tokenCount: number) => { model, }; } + const [defaultProvider, defaultModel] = + req.config.Router?.default?.split(",") ?? []; // Router/default are optional: fall back to [] so destructuring cannot throw return { - provider: "default", - model: req.config.OPENAI_MODEL, + provider: defaultProvider || "default", + model: defaultModel || req.config.OPENAI_MODEL, }; }; diff --git a/src/utils/index.ts b/src/utils/index.ts index c2d1c46..febcb8d 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -8,6 +8,7 @@ import { HOME_DIR, PLUGINS_DIR, } from "../constants"; +import crypto from "node:crypto"; export function getOpenAICommonOptions(): ClientOptions { const options: ClientOptions = {}; @@ -90,3 +91,7 @@ export const createClient = (options: ClientOptions) => { }); return client; }; + +export const sha256 = (data: string | Buffer): string => { + return crypto.createHash("sha256").update(data).digest("hex"); +}; diff --git a/src/utils/log.ts b/src/utils/log.ts index 6999726..8034fa1 100644 --- a/src/utils/log.ts +++ b/src/utils/log.ts @@ -11,6 +11,7 @@ if (!fs.existsSync(HOME_DIR)) { export function log(...args: any[]) { // Check if logging is enabled via environment variable + // console.log(...args); // Log to console for immediate feedback const isLogEnabled = process.env.LOG === "true"; if (!isLogEnabled) { diff --git a/src/utils/stream.ts b/src/utils/stream.ts index d502947..4670236 100644 --- a/src/utils/stream.ts +++ b/src/utils/stream.ts @@ -1,6 +1,13 @@ -import { Response } from "express"; -import { OpenAI } from "openai"; +import { Request, Response } from "express"; import { log } from "./log"; +import { PLUGINS } from "../middlewares/plugin"; +import { sha256 } from "."; + +declare module "express" { + interface Request { + provider?: string; + } +} interface ContentBlock { type: string; @@ -42,28 +49,61 @@ interface MessageEvent { } export async function streamOpenAIResponse( + req: Request, res: Response, -
completion: any, - model: string, - body: any + _completion: any ) { - const write = (data: string) => { - log("response: ", data); - res.write(data); + let completion = _completion; + res.locals.completion = completion; + + for (const [name, plugin] of PLUGINS.entries()) { + if (name.includes(",") && !name.startsWith(`${req.provider},`)) { + continue; + } + if (plugin.beforeTransformResponse) { + const result = await plugin.beforeTransformResponse(req, res, { + completion, + }); + if (result) { + completion = result; + } + } + } + const write = async (data: string) => { + let eventData = data; + for (const [name, plugin] of PLUGINS.entries()) { + if (name.includes(",") && !name.startsWith(`${req.provider},`)) { + continue; + } + if (plugin.afterTransformResponse) { + const hookResult = await plugin.afterTransformResponse(req, res, { + completion: res.locals.completion, + transformedCompletion: eventData, + }); + if (typeof hookResult === "string") { + eventData = hookResult; + } + } + } + if (eventData) { + log("response: ", eventData); + res.write(eventData); + } }; const messageId = "msg_" + Date.now(); - if (!body.stream) { + if (!req.body.stream) { let content: any = []; if (completion.choices[0].message.content) { - content = [ { text: completion.choices[0].message.content, type: "text" } ]; - } - else if (completion.choices[0].message.tool_calls) { + content = [{ text: completion.choices[0].message.content, type: "text" }]; + } else if (completion.choices[0].message.tool_calls) { content = completion.choices[0].message.tool_calls.map((item: any) => { return { - type: 'tool_use', + type: "tool_use", id: item.id, name: item.function?.name, - input: item.function?.arguments ? JSON.parse(item.function.arguments) : {}, + input: item.function?.arguments + ? 
JSON.parse(item.function.arguments) + : {}, }; }); } @@ -74,11 +114,29 @@ export async function streamOpenAIResponse( role: "assistant", // @ts-ignore content: content, - stop_reason: completion.choices[0].finish_reason === 'tool_calls' ? "tool_use" : "end_turn", + stop_reason: + completion.choices[0].finish_reason === "tool_calls" + ? "tool_use" + : "end_turn", stop_sequence: null, }; try { - res.json(result); + res.locals.transformedCompletion = result; + for (const [name, plugin] of PLUGINS.entries()) { + if (name.includes(",") && !name.startsWith(`${req.provider},`)) { + continue; + } + if (plugin.afterTransformResponse) { + const hookResult = await plugin.afterTransformResponse(req, res, { + completion: res.locals.completion, + transformedCompletion: res.locals.transformedCompletion, + }); + if (hookResult) { + res.locals.transformedCompletion = hookResult; + } + } + } + res.json(res.locals.transformedCompletion); res.end(); return; } catch (error) { @@ -98,7 +156,7 @@ export async function streamOpenAIResponse( type: "message", role: "assistant", content: [], - model, + model: req.body.model, stop_reason: null, stop_sequence: null, usage: { input_tokens: 1, output_tokens: 1 }, @@ -118,11 +176,13 @@ export async function streamOpenAIResponse( const delta = chunk.choices[0].delta; if (delta.tool_calls && delta.tool_calls.length > 0) { - for (const toolCall of delta.tool_calls) { - const toolCallId = toolCall.id; - - // Check if this is a new tool call by ID - if (toolCallId && toolCallId !== currentToolCallId) { + // Handle each tool call in the current chunk + for (const [index, toolCall] of delta.tool_calls.entries()) { + // Generate a stable ID for this tool call position + const toolCallId = toolCall.id || `tool_${index}`; + + // If this position doesn't have an active tool call, start a new one + if (!toolCallJsonMap.has(`${index}`)) { // End previous tool call if one was active if (isToolUse && currentToolCallId) { const contentBlockStop: MessageEvent = { 
@@ -138,13 +198,13 @@ export async function streamOpenAIResponse( // Start new tool call block isToolUse = true; - currentToolCallId = toolCallId; + currentToolCallId = `${index}`; contentBlockIndex++; - toolCallJsonMap.set(toolCallId, ""); // Initialize JSON accumulator for this tool call + toolCallJsonMap.set(`${index}`, ""); // Initialize JSON accumulator for this tool call const toolBlock: ContentBlock = { type: "tool_use", - id: toolCallId, + id: toolCallId, // Use the original ID if available name: toolCall.function?.name, input: {}, }; @@ -164,8 +224,8 @@ export async function streamOpenAIResponse( ); } - // Stream tool call JSON - if (toolCall.function?.arguments && currentToolCallId) { + // Stream tool call JSON for this position + if (toolCall.function?.arguments) { const jsonDelta: MessageEvent = { type: "content_block_delta", index: contentBlockIndex, @@ -175,27 +235,39 @@ export async function streamOpenAIResponse( }, }; - // Accumulate JSON for this specific tool call - const currentJson = toolCallJsonMap.get(currentToolCallId) || ""; - toolCallJsonMap.set(currentToolCallId, currentJson + toolCall.function.arguments); - toolUseJson = toolCallJsonMap.get(currentToolCallId) || ""; + // Accumulate JSON for this specific tool call position + const currentJson = toolCallJsonMap.get(`${index}`) || ""; + const newJson = currentJson + toolCall.function.arguments; + toolCallJsonMap.set(`${index}`, newJson); - try { - const parsedJson = JSON.parse(toolUseJson); - currentContentBlocks[contentBlockIndex].input = parsedJson; - } catch (e) { - log("JSON parsing error (continuing to accumulate):", e); - // JSON not yet complete, continue accumulating + // Try to parse accumulated JSON + if (isValidJson(newJson)) { + try { + const parsedJson = JSON.parse(newJson); + const blockIndex = currentContentBlocks.findIndex( + (block) => block.type === "tool_use" && block.id === toolCallId + ); + if (blockIndex !== -1) { + currentContentBlocks[blockIndex].input = parsedJson; 
+ } + } catch (e) { + log("JSON parsing error (continuing to accumulate):", e); + } } write( - `event: content_block_delta\ndata: ${JSON.stringify(jsonDelta)}\n\n` + `event: content_block_delta\ndata: ${JSON.stringify( + jsonDelta + )}\n\n` ); } } - } else if (delta.content) { - // Handle regular text content - if (isToolUse) { + } else if (delta.content || chunk.choices[0].finish_reason) { + // Handle regular text content or completion + if ( + isToolUse && + (delta.content || chunk.choices[0].finish_reason === "tool_calls") + ) { log("Tool call ended here:", delta); // End previous tool call block const contentBlockStop: MessageEvent = { @@ -214,8 +286,6 @@ export async function streamOpenAIResponse( toolUseJson = ""; // Reset for safety } - if (!delta.content) continue; - // If text block not yet started, send content_block_start if (!hasStartedTextBlock) { const textBlock: ContentBlock = { @@ -317,18 +387,34 @@ export async function streamOpenAIResponse( ); } + res.locals.transformedCompletion = currentContentBlocks; + for (const [name, plugin] of PLUGINS.entries()) { + if (name.includes(",") && !name.startsWith(`${req.provider},`)) { + continue; + } + if (plugin.afterTransformResponse) { + const hookResult = await plugin.afterTransformResponse(req, res, { + completion: res.locals.completion, + transformedCompletion: res.locals.transformedCompletion, + }); + if (hookResult) { + res.locals.transformedCompletion = hookResult; + } + } + } + // Send message_delta event with appropriate stop_reason const messageDelta: MessageEvent = { type: "message_delta", delta: { stop_reason: isToolUse ? 
"tool_use" : "end_turn", stop_sequence: null, - content: currentContentBlocks, + content: res.locals.transformedCompletion, }, usage: { input_tokens: 100, output_tokens: 150 }, }; if (!isToolUse) { - log("body: ", body, "messageDelta: ", messageDelta); + log("body: ", req.body, "messageDelta: ", messageDelta); } write(`event: message_delta\ndata: ${JSON.stringify(messageDelta)}\n\n`); @@ -341,3 +427,42 @@ export async function streamOpenAIResponse( write(`event: message_stop\ndata: ${JSON.stringify(messageStop)}\n\n`); res.end(); } + +// Helper: cheap structural pre-check that braces/brackets are balanced before JSON.parse is attempted (does not skip string literals) +function isValidJson(str: string): boolean { + // Check if the string contains both opening and closing braces/brackets + const hasOpenBrace = str.includes("{"); + const hasCloseBrace = str.includes("}"); + const hasOpenBracket = str.includes("["); + const hasCloseBracket = str.includes("]"); + + // Check if we have matching pairs + if ((hasOpenBrace && !hasCloseBrace) || (!hasOpenBrace && hasCloseBrace)) { + return false; + } + if ( + (hasOpenBracket && !hasCloseBracket) || + (!hasOpenBracket && hasCloseBracket) + ) { + return false; + } + + // Count nested braces/brackets + let braceCount = 0; + let bracketCount = 0; + + for (const char of str) { + if (char === "{") braceCount++; + if (char === "}") braceCount--; + if (char === "[") bracketCount++; + if (char === "]") bracketCount--; + + // If we ever go negative, the JSON is invalid + if (braceCount < 0 || bracketCount < 0) { + return false; + } + } + + // All braces/brackets should be matched + return braceCount === 0 && bracketCount === 0; +}