Files
claude-code-router/src/utils/router.ts
2025-07-21 10:46:53 +08:00

94 lines
3.2 KiB
TypeScript

import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages";
import { get_encoding } from "tiktoken";
import { log } from "./log";
// Module-level tiktoken encoder for the "cl100k_base" encoding. It is created
// once when this module is loaded and reused by `router` for every token count.
const enc = get_encoding("cl100k_base");
/** Requests whose estimated token count exceeds this use the long-context route. */
const LONG_CONTEXT_TOKEN_THRESHOLD = 1000 * 60;

/**
 * Picks the model a request should be sent to.
 *
 * Rules are evaluated in order; the first match wins:
 *  1. The requested model contains a "," -> returned unchanged.
 *  2. `tokenCount` is above 60K and `config.Router.longContext` is set.
 *  3. The requested model starts with "claude-3-5-haiku" and
 *     `config.Router.background` is set.
 *  4. The request has a truthy `thinking` field and `config.Router.think` is set.
 *  5. The request declares a tool whose `type` starts with "web_search" and
 *     `config.Router.webSearch` is set.
 *  6. Otherwise `config.Router.default`.
 *
 * @param req - Incoming request; only `req.body.model`, `req.body.thinking`
 *   and `req.body.tools` are read.
 * @param tokenCount - Estimated token count of the request.
 * @param config - Configuration object carrying the `Router` table.
 * @returns The selected model value.
 */
const getUseModel = (req: any, tokenCount: number, config: any) => {
  const requestedModel: unknown = req.body.model;
  // A missing or non-string model must fall through to the other rules
  // instead of throwing on `.includes` / `.startsWith`.
  const modelName = typeof requestedModel === "string" ? requestedModel : undefined;

  if (modelName !== undefined && modelName.includes(",")) {
    return requestedModel;
  }

  if (tokenCount > LONG_CONTEXT_TOKEN_THRESHOLD && config.Router.longContext) {
    log("Using long context model due to token count:", tokenCount);
    return config.Router.longContext;
  }

  if (
    modelName !== undefined &&
    modelName.startsWith("claude-3-5-haiku") &&
    config.Router.background
  ) {
    log("Using background model for ", req.body.model);
    return config.Router.background;
  }

  if (req.body.thinking && config.Router.think) {
    log("Using think model for ", req.body.thinking);
    return config.Router.think;
  }

  const declaredTools: unknown = req.body.tools;
  const usesWebSearch =
    Array.isArray(declaredTools) &&
    declaredTools.some(
      (tool: { type?: unknown } | null | undefined) =>
        typeof tool?.type === "string" && tool.type.startsWith("web_search")
    );
  if (usesWebSearch && config.Router.webSearch) {
    return config.Router.webSearch;
  }

  return config.Router.default;
};
/**
 * Number of tokens `text` encodes to with the module-level `enc` encoder.
 * Anything that is not a string counts as 0 instead of throwing.
 *
 * NOTE(review): tiktoken's `encode` may reject text that contains special-token
 * markers by default — confirm; if so, such a request ends up in `router`'s
 * catch branch and gets the default model.
 */
const countTokens = (text: unknown): number =>
  typeof text === "string" ? enc.encode(text).length : 0;

/**
 * JSON-serializes `value` for counting. `JSON.stringify` returns `undefined`
 * for values it cannot represent (e.g. `undefined`); those count as "".
 */
const toJsonText = (value: unknown): string => JSON.stringify(value) ?? "";

/** Token count of one block inside an array-valued message `content`. */
const countContentPartTokens = (
  part:
    | { type?: unknown; text?: unknown; input?: unknown; content?: unknown }
    | null
    | undefined
): number => {
  switch (part?.type) {
    case "text":
      return countTokens(part.text);
    case "tool_use":
      return countTokens(toJsonText(part.input));
    case "tool_result":
      // String results are counted verbatim, structured results as JSON.
      return countTokens(
        typeof part.content === "string" ? part.content : toJsonText(part.content)
      );
    default:
      // Other block types are not counted.
      return 0;
  }
};

/**
 * Estimates the token count of the request (messages, system prompt and tool
 * definitions), asks `getUseModel` for the model to use, and writes the
 * result to `req.body.model`.
 *
 * If anything throws while counting or selecting, the error message is logged
 * and `req.body.model` is set to `config.Router.default`.
 *
 * @param req - Incoming request; `req.body` is read and `req.body.model` is overwritten.
 * @param res - Unused here.
 * @param config - Configuration object carrying the `Router` table.
 */
export const router = async (req: any, res: any, config: any): Promise<void> => {
  const { messages, system = [], tools }: MessageCreateParamsBase = req.body;
  try {
    let tokenCount = 0;

    // Conversation messages: plain string content or an array of blocks.
    if (Array.isArray(messages)) {
      for (const message of messages) {
        const content: unknown = message?.content;
        if (typeof content === "string") {
          tokenCount += countTokens(content);
        } else if (Array.isArray(content)) {
          for (const part of content) {
            tokenCount += countContentPartTokens(part);
          }
        }
      }
    }

    // System prompt: a single string or an array of text blocks.
    const systemPrompt: unknown = system;
    if (typeof systemPrompt === "string") {
      tokenCount += countTokens(systemPrompt);
    } else if (Array.isArray(systemPrompt)) {
      for (const item of systemPrompt) {
        if (item?.type !== "text") continue;
        if (Array.isArray(item.text)) {
          for (const textPart of item.text) {
            tokenCount += countTokens(textPart || "");
          }
        } else {
          tokenCount += countTokens(item.text);
        }
      }
    }

    // Tool definitions: name + description (only when a description exists)
    // and the JSON form of the input schema.
    const toolList: unknown = tools;
    if (Array.isArray(toolList)) {
      for (const tool of toolList) {
        if (tool?.description) {
          tokenCount += countTokens(tool.name + tool.description);
        }
        if (tool?.input_schema) {
          tokenCount += countTokens(toJsonText(tool.input_schema));
        }
      }
    }

    req.body.model = getUseModel(req, tokenCount, config);
  } catch (error: unknown) {
    const detail = error instanceof Error ? error.message : String(error);
    log("Error in router middleware:", detail);
    req.body.model = config.Router.default;
  }
};