Compare commits
1 Commits
main
...
feature/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cba0536c45 |
57
README.md
57
README.md
@@ -90,36 +90,75 @@ ccr code
|
||||
- [ ] More detailed logs
|
||||
|
||||
## Plugins
|
||||
You can modify or enhance Claude Code’s functionality by installing plugins. The mechanism works by using middleware to modify request parameters — this allows you to rewrite prompts or add/remove tools.
|
||||
|
||||
To use a plugin, place it in the ~/.claude-code-router/plugins/ directory and specify the plugin name in config.js using the `usePlugins` option.like this
|
||||
You can modify or enhance Claude Code’s functionality by installing plugins.
|
||||
|
||||
### Plugin Mechanism
|
||||
|
||||
Plugins are loaded from the `~/.claude-code-router/plugins/` directory. Each plugin is a JavaScript file that exports functions corresponding to specific "hooks" in the request lifecycle. The system overrides Node.js's module loading to allow plugins to import a special `claude-code-router` module, providing access to utilities like `streamOpenAIResponse`, `log`, and `createClient`.
|
||||
|
||||
### Plugin Hooks
|
||||
|
||||
Plugins can implement various hooks to modify behavior at different stages:
|
||||
|
||||
- `beforeRouter`: Executed before routing.
|
||||
- `afterRouter`: Executed after routing.
|
||||
- `beforeTransformRequest`: Executed before transforming the request.
|
||||
- `afterTransformRequest`: Executed after transforming the request.
|
||||
- `beforeTransformResponse`: Executed before transforming the response.
|
||||
- `afterTransformResponse`: Executed after transforming the response.
|
||||
|
||||
### Enabling Plugins
|
||||
|
||||
To use a plugin:
|
||||
|
||||
1. Place your plugin's JavaScript file (e.g., `my-plugin.js`) in the `~/.claude-code-router/plugins/` directory.
|
||||
2. Specify the plugin name (without the `.js` extension) in your `~/.claude-code-router/config.json` file using the `usePlugins` option:
|
||||
|
||||
```json
|
||||
// ~/.claud-code-router/config.json
|
||||
// ~/.claude-code-router/config.json
|
||||
{
|
||||
...,
|
||||
"usePlugins": ["notebook-tools-filter", "toolcall-improvement"]
|
||||
"usePlugins": ["my-plugin", "another-plugin"],
|
||||
|
||||
// or use plugins for a specific provider
|
||||
"Providers": [
|
||||
{
|
||||
"name": "gemini",
|
||||
"api_base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"api_key": "xxx",
|
||||
"models": ["gemini-2.5-flash"],
|
||||
"usePlugins": ["gemini"]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Currently, the following plugins are available:
|
||||
### Available Plugins
|
||||
|
||||
Currently, the following plugins are available:
|
||||
|
||||
- **notebook-tools-filter**
|
||||
This plugin filters out tool calls related to Jupyter notebooks (.ipynb files). You can use it if your work does not involve Jupyter.
|
||||
|
||||
- **gemini**
|
||||
Add support for the Google Gemini API endpoint: `https://generativelanguage.googleapis.com/v1beta/openai/`.
|
||||
|
||||
- **toolcall-improvement**
|
||||
If your LLM doesn’t handle tool usage well (for example, always returning code as plain text instead of modifying files — such as with deepseek-v3), you can use this plugin.
|
||||
This plugin simply adds the following system prompt. If you have a better prompt, you can modify it.
|
||||
|
||||
```markdown
|
||||
## **Important Instruction:**
|
||||
|
||||
You must use tools as frequently and accurately as possible to help the user solve their problem.
|
||||
Prioritize tool usage whenever it can enhance accuracy, efficiency, or the quality of the response.
|
||||
```
|
||||
|
||||
|
||||
## Github Actions
|
||||
|
||||
You just need to install `Claude Code Actions` in your repository according to the [official documentation](https://docs.anthropic.com/en/docs/claude-code/github-actions). For `ANTHROPIC_API_KEY`, you can use any string. Then, modify your `.github/workflows/claude.yaml` file to include claude-code-router, like this:
|
||||
|
||||
```yaml
|
||||
name: Claude Code
|
||||
|
||||
@@ -179,6 +218,7 @@ jobs:
|
||||
with:
|
||||
anthropic_api_key: "test"
|
||||
```
|
||||
|
||||
You can modify the contents of `$HOME/.claude-code-router/config.json` as needed.
|
||||
GitHub Actions support allows you to trigger Claude Code at specific times, which opens up some interesting possibilities.
|
||||
|
||||
@@ -190,7 +230,6 @@ For example, between 00:30 and 08:30 Beijing Time, using the official DeepSeek A
|
||||
|
||||
So maybe in the future, I’ll describe detailed tasks for Claude Code ahead of time and let it run during these discounted hours to reduce costs?
|
||||
|
||||
|
||||
## Some tips:
|
||||
|
||||
Now you can use deepseek-v3 models directly without using any plugins.
|
||||
@@ -228,6 +267,6 @@ Thanks to the following sponsors:
|
||||
[@duanshuaimin](https://github.com/duanshuaimin)
|
||||
[@vrgitadmin](https://github.com/vrgitadmin)
|
||||
@*o (可通过主页邮箱联系我修改 github 用户名)
|
||||
@**聪 (可通过主页邮箱联系我修改github用户名)
|
||||
@\*\*聪 (可通过主页邮箱联系我修改 github 用户名)
|
||||
@*说 (可通过主页邮箱联系我修改 github 用户名)
|
||||
@*更 (可通过主页邮箱联系我修改github用户名)
|
||||
@\*更 (可通过主页邮箱联系我修改 github 用户名)
|
||||
|
||||
33
plugins/gemini.js
Normal file
33
plugins/gemini.js
Normal file
@@ -0,0 +1,33 @@
|
||||
module.exports = {
|
||||
afterTransformRequest(req, res) {
|
||||
if (Array.isArray(req.body.tools)) {
|
||||
// rewrite tools definition
|
||||
req.body.tools.forEach((tool) => {
|
||||
if (tool.function.name === "BatchTool") {
|
||||
// HACK: Gemini does not support objects with empty properties
|
||||
tool.function.parameters.properties.invocations.items.properties.input.type =
|
||||
"number";
|
||||
return;
|
||||
}
|
||||
Object.keys(tool.function.parameters.properties).forEach((key) => {
|
||||
const prop = tool.function.parameters.properties[key];
|
||||
if (
|
||||
prop.type === "string" &&
|
||||
!["enum", "date-time"].includes(prop.format)
|
||||
) {
|
||||
delete prop.format;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
if (req.body?.messages?.length) {
|
||||
req.body.messages.forEach((message) => {
|
||||
if (message.content === null) {
|
||||
if (message.tool_calls) {
|
||||
message.content = JSON.stringify(message.tool_calls);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
@@ -1,7 +1,12 @@
|
||||
module.exports = async function handle(req, res) {
|
||||
module.exports = {
|
||||
beforeRouter(req, res) {
|
||||
if (req?.body?.tools?.length) {
|
||||
req.body.tools = req.body.tools.filter(
|
||||
(tool) => !["NotebookRead", "NotebookEdit", "mcp__ide__executeCode"].includes(tool.name)
|
||||
(tool) =>
|
||||
!["NotebookRead", "NotebookEdit", "mcp__ide__executeCode"].includes(
|
||||
tool.name
|
||||
)
|
||||
);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
module.exports = async function handle(req, res) {
|
||||
module.exports = {
|
||||
afterTransformRequest(req, res) {
|
||||
if (req?.body?.tools?.length) {
|
||||
req.body.system.push({
|
||||
type: "text",
|
||||
text: `## **Important Instruction:** \nYou must use tools as frequently and accurately as possible to help the user solve their problem.\nPrioritize tool usage whenever it can enhance accuracy, efficiency, or the quality of the response.`
|
||||
})
|
||||
req.body.messages.push({
|
||||
role: "system",
|
||||
content: `## **Important Instruction:** \nYou must use tools as frequently and accurately as possible to help the user solve their problem.\nPrioritize tool usage whenever it can enhance accuracy, efficiency, or the quality of the response. `,
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
27
src/index.ts
27
src/index.ts
@@ -3,7 +3,6 @@ import { writeFile } from "fs/promises";
|
||||
import { getOpenAICommonOptions, initConfig, initDir } from "./utils";
|
||||
import { createServer } from "./server";
|
||||
import { formatRequest } from "./middlewares/formatRequest";
|
||||
import { rewriteBody } from "./middlewares/rewriteBody";
|
||||
import { router } from "./middlewares/router";
|
||||
import OpenAI from "openai";
|
||||
import { streamOpenAIResponse } from "./utils/stream";
|
||||
@@ -14,6 +13,11 @@ import {
|
||||
} from "./utils/processCheck";
|
||||
import { LRUCache } from "lru-cache";
|
||||
import { log } from "./utils/log";
|
||||
import {
|
||||
loadPlugins,
|
||||
PLUGINS,
|
||||
usePluginMiddleware,
|
||||
} from "./middlewares/plugin";
|
||||
|
||||
async function initializeClaudeConfig() {
|
||||
const homeDir = process.env.HOME;
|
||||
@@ -44,6 +48,7 @@ interface ModelProvider {
|
||||
api_base_url: string;
|
||||
api_key: string;
|
||||
models: string[];
|
||||
usePlugins?: string[];
|
||||
}
|
||||
|
||||
async function run(options: RunOptions = {}) {
|
||||
@@ -56,6 +61,7 @@ async function run(options: RunOptions = {}) {
|
||||
await initializeClaudeConfig();
|
||||
await initDir();
|
||||
const config = await initConfig();
|
||||
await loadPlugins(config.usePlugins || []);
|
||||
|
||||
const Providers = new Map<string, ModelProvider>();
|
||||
const providerCache = new LRUCache<string, OpenAI>({
|
||||
@@ -63,7 +69,7 @@ async function run(options: RunOptions = {}) {
|
||||
ttl: 2 * 60 * 60 * 1000,
|
||||
});
|
||||
|
||||
function getProviderInstance(providerName: string): OpenAI {
|
||||
async function getProviderInstance(providerName: string): Promise<OpenAI> {
|
||||
const provider: ModelProvider | undefined = Providers.get(providerName);
|
||||
if (provider === undefined) {
|
||||
throw new Error(`Provider ${providerName} not found`);
|
||||
@@ -77,6 +83,10 @@ async function run(options: RunOptions = {}) {
|
||||
});
|
||||
providerCache.set(provider.name, openai);
|
||||
}
|
||||
const plugins = provider.usePlugins || [];
|
||||
if (plugins.length > 0) {
|
||||
await loadPlugins(plugins.map((name) => `${providerName},${name}`));
|
||||
}
|
||||
return openai;
|
||||
}
|
||||
|
||||
@@ -130,7 +140,7 @@ async function run(options: RunOptions = {}) {
|
||||
req.config = config;
|
||||
next();
|
||||
});
|
||||
server.useMiddleware(rewriteBody);
|
||||
server.useMiddleware(usePluginMiddleware("beforeRouter"));
|
||||
if (
|
||||
config.Router?.background &&
|
||||
config.Router?.think &&
|
||||
@@ -144,15 +154,22 @@ async function run(options: RunOptions = {}) {
|
||||
next();
|
||||
});
|
||||
}
|
||||
server.useMiddleware(usePluginMiddleware("afterRouter"));
|
||||
server.useMiddleware(usePluginMiddleware("beforeTransformRequest"));
|
||||
server.useMiddleware(formatRequest);
|
||||
server.useMiddleware(usePluginMiddleware("afterTransformRequest"));
|
||||
|
||||
server.app.post("/v1/messages", async (req, res) => {
|
||||
try {
|
||||
const provider = getProviderInstance(req.provider || "default");
|
||||
const provider = await getProviderInstance(req.provider || "default");
|
||||
log("final request body:", req.body);
|
||||
const completion: any = await provider.chat.completions.create(req.body);
|
||||
await streamOpenAIResponse(res, completion, req.body.model, req.body);
|
||||
await streamOpenAIResponse(req, res, completion);
|
||||
} catch (e) {
|
||||
log("Error in OpenAI API call:", e);
|
||||
res.status(500).json({
|
||||
error: e.message,
|
||||
});
|
||||
}
|
||||
});
|
||||
server.start();
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages";
|
||||
import OpenAI from "openai";
|
||||
import { streamOpenAIResponse } from "../utils/stream";
|
||||
import { log } from "../utils/log";
|
||||
|
||||
export const formatRequest = async (
|
||||
@@ -19,7 +18,7 @@ export const formatRequest = async (
|
||||
tools,
|
||||
stream,
|
||||
}: MessageCreateParamsBase = req.body;
|
||||
log("formatRequest: ", req.body);
|
||||
log("beforeTransformRequest: ", req.body);
|
||||
try {
|
||||
// @ts-ignore
|
||||
const openAIMessages = Array.isArray(messages)
|
||||
@@ -50,6 +49,7 @@ export const formatRequest = async (
|
||||
|
||||
anthropicMessage.content.forEach((contentPart) => {
|
||||
if (contentPart.type === "text") {
|
||||
if (contentPart.text.includes("(no content)")) return;
|
||||
textContent +=
|
||||
(typeof contentPart.text === "string"
|
||||
? contentPart.text
|
||||
@@ -112,17 +112,18 @@ export const formatRequest = async (
|
||||
});
|
||||
|
||||
const trimmedUserText = userTextMessageContent.trim();
|
||||
// @ts-ignore
|
||||
openAiMessagesFromThisAnthropicMessage.push(
|
||||
// @ts-ignore
|
||||
...subsequentToolMessages
|
||||
);
|
||||
|
||||
if (trimmedUserText.length > 0) {
|
||||
openAiMessagesFromThisAnthropicMessage.push({
|
||||
role: "user",
|
||||
content: trimmedUserText,
|
||||
});
|
||||
}
|
||||
// @ts-ignore
|
||||
openAiMessagesFromThisAnthropicMessage.push(
|
||||
// @ts-ignore
|
||||
...subsequentToolMessages
|
||||
);
|
||||
} else {
|
||||
// Fallback for other roles (e.g. system, or custom roles if they were to appear here with array content)
|
||||
// This will combine all text parts into a single message for that role.
|
||||
@@ -180,30 +181,9 @@ export const formatRequest = async (
|
||||
res.setHeader("Cache-Control", "no-cache");
|
||||
res.setHeader("Connection", "keep-alive");
|
||||
req.body = data;
|
||||
console.log(JSON.stringify(data.messages, null, 2));
|
||||
log("afterTransformRequest: ", req.body);
|
||||
} catch (error) {
|
||||
console.error("Error in request processing:", error);
|
||||
const errorCompletion: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> =
|
||||
{
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: `error_${Date.now()}`,
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model,
|
||||
object: "chat.completion.chunk",
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {
|
||||
content: `Error: ${(error as Error).message}`,
|
||||
},
|
||||
finish_reason: "stop",
|
||||
},
|
||||
],
|
||||
};
|
||||
},
|
||||
};
|
||||
await streamOpenAIResponse(res, errorCompletion, model, req.body);
|
||||
log("Error in TransformRequest:", error);
|
||||
}
|
||||
next();
|
||||
};
|
||||
|
||||
106
src/middlewares/plugin.ts
Normal file
106
src/middlewares/plugin.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import Module from "node:module";
|
||||
import { streamOpenAIResponse } from "../utils/stream";
|
||||
import { log } from "../utils/log";
|
||||
import { PLUGINS_DIR } from "../constants";
|
||||
import path from "node:path";
|
||||
import { access } from "node:fs/promises";
|
||||
import { OpenAI } from "openai";
|
||||
import { createClient } from "../utils";
|
||||
import { Response } from "express";
|
||||
|
||||
// @ts-ignore
|
||||
const originalLoad = Module._load;
|
||||
// @ts-ignore
|
||||
Module._load = function (request, parent, isMain) {
|
||||
if (request === "claude-code-router") {
|
||||
return {
|
||||
streamOpenAIResponse,
|
||||
log,
|
||||
OpenAI,
|
||||
createClient,
|
||||
};
|
||||
}
|
||||
return originalLoad.call(this, request, parent, isMain);
|
||||
};
|
||||
|
||||
export type PluginHook =
|
||||
| "beforeRouter"
|
||||
| "afterRouter"
|
||||
| "beforeTransformRequest"
|
||||
| "afterTransformRequest"
|
||||
| "beforeTransformResponse"
|
||||
| "afterTransformResponse";
|
||||
|
||||
export interface Plugin {
|
||||
beforeRouter?: (req: any, res: Response) => Promise<any>;
|
||||
afterRouter?: (req: any, res: Response) => Promise<any>;
|
||||
|
||||
beforeTransformRequest?: (req: any, res: Response) => Promise<any>;
|
||||
afterTransformRequest?: (req: any, res: Response) => Promise<any>;
|
||||
|
||||
beforeTransformResponse?: (
|
||||
req: any,
|
||||
res: Response,
|
||||
data?: { completion: any }
|
||||
) => Promise<any>;
|
||||
afterTransformResponse?: (
|
||||
req: any,
|
||||
res: Response,
|
||||
data?: { completion: any; transformedCompletion: any }
|
||||
) => Promise<any>;
|
||||
}
|
||||
|
||||
export const PLUGINS = new Map<string, Plugin>();
|
||||
|
||||
const loadPlugin = async (pluginName: string) => {
|
||||
const filePath = pluginName.split(",").pop();
|
||||
const pluginPath = path.join(PLUGINS_DIR, `${filePath}.js`);
|
||||
try {
|
||||
await access(pluginPath);
|
||||
const plugin = require(pluginPath);
|
||||
if (
|
||||
[
|
||||
"beforeRouter",
|
||||
"afterRouter",
|
||||
"beforeTransformRequest",
|
||||
"afterTransformRequest",
|
||||
"beforeTransformResponse",
|
||||
"afterTransformResponse",
|
||||
].some((key) => key in plugin)
|
||||
) {
|
||||
PLUGINS.set(pluginName, plugin);
|
||||
log(`Plugin ${pluginName} loaded successfully.`);
|
||||
} else {
|
||||
throw new Error(`Plugin ${pluginName} does not export a function.`);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(`Failed to load plugin ${pluginName}:`, e);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
export const loadPlugins = async (pluginNames: string[]) => {
|
||||
console.log("Loading plugins:", pluginNames);
|
||||
for (const file of pluginNames) {
|
||||
await loadPlugin(file);
|
||||
}
|
||||
};
|
||||
|
||||
export const usePluginMiddleware = (type: PluginHook) => {
|
||||
return async (req: any, res: Response, next: any) => {
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin[type]) {
|
||||
try {
|
||||
await plugin[type](req, res);
|
||||
log(`Plugin ${name} executed hook: ${type}`);
|
||||
} catch (error) {
|
||||
log(`Error in plugin ${name} during hook ${type}:`, error);
|
||||
}
|
||||
}
|
||||
}
|
||||
next();
|
||||
};
|
||||
};
|
||||
@@ -1,45 +0,0 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import Module from "node:module";
|
||||
import { streamOpenAIResponse } from "../utils/stream";
|
||||
import { log } from "../utils/log";
|
||||
import { PLUGINS_DIR } from "../constants";
|
||||
import path from "node:path";
|
||||
import { access } from "node:fs/promises";
|
||||
import { OpenAI } from "openai";
|
||||
import { createClient } from "../utils";
|
||||
|
||||
// @ts-ignore
|
||||
const originalLoad = Module._load;
|
||||
// @ts-ignore
|
||||
Module._load = function (request, parent, isMain) {
|
||||
if (request === "claude-code-router") {
|
||||
return {
|
||||
streamOpenAIResponse,
|
||||
log,
|
||||
OpenAI,
|
||||
createClient,
|
||||
};
|
||||
}
|
||||
return originalLoad.call(this, request, parent, isMain);
|
||||
};
|
||||
|
||||
export const rewriteBody = async (
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
) => {
|
||||
if (!req.config.usePlugins) {
|
||||
return next();
|
||||
}
|
||||
for (const plugin of req.config.usePlugins) {
|
||||
const pluginPath = path.join(PLUGINS_DIR, `${plugin.trim()}.js`);
|
||||
try {
|
||||
await access(pluginPath);
|
||||
const rewritePlugin = require(pluginPath);
|
||||
await rewritePlugin(req, res);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
next();
|
||||
};
|
||||
@@ -41,9 +41,11 @@ const getUseModel = (req: Request, tokenCount: number) => {
|
||||
model,
|
||||
};
|
||||
}
|
||||
const [defaultProvider, defaultModel] =
|
||||
req.config.Router!.default?.split(",");
|
||||
return {
|
||||
provider: "default",
|
||||
model: req.config.OPENAI_MODEL,
|
||||
provider: defaultProvider || "default",
|
||||
model: defaultModel || req.config.OPENAI_MODEL,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
HOME_DIR,
|
||||
PLUGINS_DIR,
|
||||
} from "../constants";
|
||||
import crypto from "node:crypto";
|
||||
|
||||
export function getOpenAICommonOptions(): ClientOptions {
|
||||
const options: ClientOptions = {};
|
||||
@@ -90,3 +91,7 @@ export const createClient = (options: ClientOptions) => {
|
||||
});
|
||||
return client;
|
||||
};
|
||||
|
||||
export const sha256 = (data: string | Buffer): string => {
|
||||
return crypto.createHash("sha256").update(data).digest("hex");
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@ if (!fs.existsSync(HOME_DIR)) {
|
||||
|
||||
export function log(...args: any[]) {
|
||||
// Check if logging is enabled via environment variable
|
||||
// console.log(...args); // Log to console for immediate feedback
|
||||
const isLogEnabled = process.env.LOG === "true";
|
||||
|
||||
if (!isLogEnabled) {
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
import { Response } from "express";
|
||||
import { OpenAI } from "openai";
|
||||
import { Request, Response } from "express";
|
||||
import { log } from "./log";
|
||||
import { PLUGINS } from "../middlewares/plugin";
|
||||
import { sha256 } from ".";
|
||||
|
||||
declare module "express" {
|
||||
interface Request {
|
||||
provider?: string;
|
||||
}
|
||||
}
|
||||
|
||||
interface ContentBlock {
|
||||
type: string;
|
||||
@@ -42,28 +49,61 @@ interface MessageEvent {
|
||||
}
|
||||
|
||||
export async function streamOpenAIResponse(
|
||||
req: Request,
|
||||
res: Response,
|
||||
completion: any,
|
||||
model: string,
|
||||
body: any
|
||||
_completion: any
|
||||
) {
|
||||
const write = (data: string) => {
|
||||
log("response: ", data);
|
||||
res.write(data);
|
||||
let completion = _completion;
|
||||
res.locals.completion = completion;
|
||||
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.beforeTransformResponse) {
|
||||
const result = await plugin.beforeTransformResponse(req, res, {
|
||||
completion,
|
||||
});
|
||||
if (result) {
|
||||
completion = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
const write = async (data: string) => {
|
||||
let eventData = data;
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.afterTransformResponse) {
|
||||
const hookResult = await plugin.afterTransformResponse(req, res, {
|
||||
completion: res.locals.completion,
|
||||
transformedCompletion: eventData,
|
||||
});
|
||||
if (typeof hookResult === "string") {
|
||||
eventData = hookResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (eventData) {
|
||||
log("response: ", eventData);
|
||||
res.write(eventData);
|
||||
}
|
||||
};
|
||||
const messageId = "msg_" + Date.now();
|
||||
if (!body.stream) {
|
||||
if (!req.body.stream) {
|
||||
let content: any = [];
|
||||
if (completion.choices[0].message.content) {
|
||||
content = [{ text: completion.choices[0].message.content, type: "text" }];
|
||||
}
|
||||
else if (completion.choices[0].message.tool_calls) {
|
||||
} else if (completion.choices[0].message.tool_calls) {
|
||||
content = completion.choices[0].message.tool_calls.map((item: any) => {
|
||||
return {
|
||||
type: 'tool_use',
|
||||
type: "tool_use",
|
||||
id: item.id,
|
||||
name: item.function?.name,
|
||||
input: item.function?.arguments ? JSON.parse(item.function.arguments) : {},
|
||||
input: item.function?.arguments
|
||||
? JSON.parse(item.function.arguments)
|
||||
: {},
|
||||
};
|
||||
});
|
||||
}
|
||||
@@ -74,11 +114,29 @@ export async function streamOpenAIResponse(
|
||||
role: "assistant",
|
||||
// @ts-ignore
|
||||
content: content,
|
||||
stop_reason: completion.choices[0].finish_reason === 'tool_calls' ? "tool_use" : "end_turn",
|
||||
stop_reason:
|
||||
completion.choices[0].finish_reason === "tool_calls"
|
||||
? "tool_use"
|
||||
: "end_turn",
|
||||
stop_sequence: null,
|
||||
};
|
||||
try {
|
||||
res.json(result);
|
||||
res.locals.transformedCompletion = result;
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.afterTransformResponse) {
|
||||
const hookResult = await plugin.afterTransformResponse(req, res, {
|
||||
completion: res.locals.completion,
|
||||
transformedCompletion: res.locals.transformedCompletion,
|
||||
});
|
||||
if (hookResult) {
|
||||
res.locals.transformedCompletion = hookResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
res.json(res.locals.transformedCompletion);
|
||||
res.end();
|
||||
return;
|
||||
} catch (error) {
|
||||
@@ -98,7 +156,7 @@ export async function streamOpenAIResponse(
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [],
|
||||
model,
|
||||
model: req.body.model,
|
||||
stop_reason: null,
|
||||
stop_sequence: null,
|
||||
usage: { input_tokens: 1, output_tokens: 1 },
|
||||
@@ -118,11 +176,13 @@ export async function streamOpenAIResponse(
|
||||
const delta = chunk.choices[0].delta;
|
||||
|
||||
if (delta.tool_calls && delta.tool_calls.length > 0) {
|
||||
for (const toolCall of delta.tool_calls) {
|
||||
const toolCallId = toolCall.id;
|
||||
// Handle each tool call in the current chunk
|
||||
for (const [index, toolCall] of delta.tool_calls.entries()) {
|
||||
// Generate a stable ID for this tool call position
|
||||
const toolCallId = toolCall.id || `tool_${index}`;
|
||||
|
||||
// Check if this is a new tool call by ID
|
||||
if (toolCallId && toolCallId !== currentToolCallId) {
|
||||
// If this position doesn't have an active tool call, start a new one
|
||||
if (!toolCallJsonMap.has(`${index}`)) {
|
||||
// End previous tool call if one was active
|
||||
if (isToolUse && currentToolCallId) {
|
||||
const contentBlockStop: MessageEvent = {
|
||||
@@ -138,13 +198,13 @@ export async function streamOpenAIResponse(
|
||||
|
||||
// Start new tool call block
|
||||
isToolUse = true;
|
||||
currentToolCallId = toolCallId;
|
||||
currentToolCallId = `${index}`;
|
||||
contentBlockIndex++;
|
||||
toolCallJsonMap.set(toolCallId, ""); // Initialize JSON accumulator for this tool call
|
||||
toolCallJsonMap.set(`${index}`, ""); // Initialize JSON accumulator for this tool call
|
||||
|
||||
const toolBlock: ContentBlock = {
|
||||
type: "tool_use",
|
||||
id: toolCallId,
|
||||
id: toolCallId, // Use the original ID if available
|
||||
name: toolCall.function?.name,
|
||||
input: {},
|
||||
};
|
||||
@@ -164,8 +224,8 @@ export async function streamOpenAIResponse(
|
||||
);
|
||||
}
|
||||
|
||||
// Stream tool call JSON
|
||||
if (toolCall.function?.arguments && currentToolCallId) {
|
||||
// Stream tool call JSON for this position
|
||||
if (toolCall.function?.arguments) {
|
||||
const jsonDelta: MessageEvent = {
|
||||
type: "content_block_delta",
|
||||
index: contentBlockIndex,
|
||||
@@ -175,27 +235,39 @@ export async function streamOpenAIResponse(
|
||||
},
|
||||
};
|
||||
|
||||
// Accumulate JSON for this specific tool call
|
||||
const currentJson = toolCallJsonMap.get(currentToolCallId) || "";
|
||||
toolCallJsonMap.set(currentToolCallId, currentJson + toolCall.function.arguments);
|
||||
toolUseJson = toolCallJsonMap.get(currentToolCallId) || "";
|
||||
// Accumulate JSON for this specific tool call position
|
||||
const currentJson = toolCallJsonMap.get(`${index}`) || "";
|
||||
const newJson = currentJson + toolCall.function.arguments;
|
||||
toolCallJsonMap.set(`${index}`, newJson);
|
||||
|
||||
// Try to parse accumulated JSON
|
||||
if (isValidJson(newJson)) {
|
||||
try {
|
||||
const parsedJson = JSON.parse(toolUseJson);
|
||||
currentContentBlocks[contentBlockIndex].input = parsedJson;
|
||||
const parsedJson = JSON.parse(newJson);
|
||||
const blockIndex = currentContentBlocks.findIndex(
|
||||
(block) => block.type === "tool_use" && block.id === toolCallId
|
||||
);
|
||||
if (blockIndex !== -1) {
|
||||
currentContentBlocks[blockIndex].input = parsedJson;
|
||||
}
|
||||
} catch (e) {
|
||||
log("JSON parsing error (continuing to accumulate):", e);
|
||||
// JSON not yet complete, continue accumulating
|
||||
}
|
||||
}
|
||||
|
||||
write(
|
||||
`event: content_block_delta\ndata: ${JSON.stringify(jsonDelta)}\n\n`
|
||||
`event: content_block_delta\ndata: ${JSON.stringify(
|
||||
jsonDelta
|
||||
)}\n\n`
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if (delta.content) {
|
||||
// Handle regular text content
|
||||
if (isToolUse) {
|
||||
} else if (delta.content || chunk.choices[0].finish_reason) {
|
||||
// Handle regular text content or completion
|
||||
if (
|
||||
isToolUse &&
|
||||
(delta.content || chunk.choices[0].finish_reason === "tool_calls")
|
||||
) {
|
||||
log("Tool call ended here:", delta);
|
||||
// End previous tool call block
|
||||
const contentBlockStop: MessageEvent = {
|
||||
@@ -214,8 +286,6 @@ export async function streamOpenAIResponse(
|
||||
toolUseJson = ""; // Reset for safety
|
||||
}
|
||||
|
||||
if (!delta.content) continue;
|
||||
|
||||
// If text block not yet started, send content_block_start
|
||||
if (!hasStartedTextBlock) {
|
||||
const textBlock: ContentBlock = {
|
||||
@@ -317,18 +387,34 @@ export async function streamOpenAIResponse(
|
||||
);
|
||||
}
|
||||
|
||||
res.locals.transformedCompletion = currentContentBlocks;
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.afterTransformResponse) {
|
||||
const hookResult = await plugin.afterTransformResponse(req, res, {
|
||||
completion: res.locals.completion,
|
||||
transformedCompletion: res.locals.transformedCompletion,
|
||||
});
|
||||
if (hookResult) {
|
||||
res.locals.transformedCompletion = hookResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send message_delta event with appropriate stop_reason
|
||||
const messageDelta: MessageEvent = {
|
||||
type: "message_delta",
|
||||
delta: {
|
||||
stop_reason: isToolUse ? "tool_use" : "end_turn",
|
||||
stop_sequence: null,
|
||||
content: currentContentBlocks,
|
||||
content: res.locals.transformedCompletion,
|
||||
},
|
||||
usage: { input_tokens: 100, output_tokens: 150 },
|
||||
};
|
||||
if (!isToolUse) {
|
||||
log("body: ", body, "messageDelta: ", messageDelta);
|
||||
log("body: ", req.body, "messageDelta: ", messageDelta);
|
||||
}
|
||||
|
||||
write(`event: message_delta\ndata: ${JSON.stringify(messageDelta)}\n\n`);
|
||||
@@ -341,3 +427,42 @@ export async function streamOpenAIResponse(
|
||||
write(`event: message_stop\ndata: ${JSON.stringify(messageStop)}\n\n`);
|
||||
res.end();
|
||||
}
|
||||
|
||||
// Add helper function at the top of the file
|
||||
function isValidJson(str: string): boolean {
|
||||
// Check if the string contains both opening and closing braces/brackets
|
||||
const hasOpenBrace = str.includes("{");
|
||||
const hasCloseBrace = str.includes("}");
|
||||
const hasOpenBracket = str.includes("[");
|
||||
const hasCloseBracket = str.includes("]");
|
||||
|
||||
// Check if we have matching pairs
|
||||
if ((hasOpenBrace && !hasCloseBrace) || (!hasOpenBrace && hasCloseBrace)) {
|
||||
return false;
|
||||
}
|
||||
if (
|
||||
(hasOpenBracket && !hasCloseBracket) ||
|
||||
(!hasOpenBracket && hasCloseBracket)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Count nested braces/brackets
|
||||
let braceCount = 0;
|
||||
let bracketCount = 0;
|
||||
|
||||
for (const char of str) {
|
||||
if (char === "{") braceCount++;
|
||||
if (char === "}") braceCount--;
|
||||
if (char === "[") bracketCount++;
|
||||
if (char === "]") bracketCount--;
|
||||
|
||||
// If we ever go negative, the JSON is invalid
|
||||
if (braceCount < 0 || bracketCount < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// All braces/brackets should be matched
|
||||
return braceCount === 0 && bracketCount === 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user