/*
 * Snapshot header captured from a Git web viewer (not part of the source code):
 *
 * Files
 * claude-code-router/packages/server/src/utils/router.ts
 * 2025-12-25 23:00:24 +08:00
 *
 * 318 lines
 * 9.3 KiB
 * TypeScript
 * Raw Blame History
 *
 * This file contains ambiguous Unicode characters
 *
 * This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
 */

import { get_encoding } from "tiktoken";
import { sessionUsageCache, Usage } from "./cache";
import { readFile, access } from "fs/promises";
import { opendir, stat } from "fs/promises";
import { join } from "path";
import { CLAUDE_PROJECTS_DIR, HOME_DIR } from "@CCR/shared";
import { LRUCache } from "lru-cache";
// Minimal local shapes for the request types from @anthropic-ai/sdk.
// Only the fields this module reads are typed; the index signatures
// (`[key: string]: any`) let other request fields pass through untouched.

// A tool definition as sent in the request's `tools` array.
interface Tool {
name: string;
description?: string;
input_schema: object;
}
// One block of a message's array content; `type` selects which other
// fields are present (this module reads text/tool_use/tool_result blocks).
interface ContentBlockParam {
type: string;
[key: string]: any;
}
// A single conversation message; content is either plain text or blocks.
interface MessageParam {
role: string;
content: string | ContentBlockParam[];
}
// The subset of the request body that the router inspects.
interface MessageCreateParamsBase {
messages?: MessageParam[];
system?: string | any[];
tools?: Tool[];
[key: string]: any;
}
// Shared tokenizer used for all token estimates in this module.
// NOTE(review): cl100k_base is a tiktoken (OpenAI) encoding, so counts are
// presumably an approximation of the upstream model's real token usage.
const enc = get_encoding("cl100k_base");
/**
 * Token count of a single text fragment. Anything that is not a string (a
 * missing `text` field, `JSON.stringify(undefined)`, a stray object) counts
 * as zero instead of being handed to the encoder, where it would throw.
 */
const countTextTokens = (text: unknown): number =>
  typeof text === "string" ? enc.encode(text).length : 0;

/**
 * Approximates the prompt size of an Anthropic-style request using the
 * module-level `enc` encoder. The result drives long-context routing.
 *
 * Counted:
 * - string message contents and `text` blocks;
 * - `tool_use` inputs and `tool_result` contents (JSON-serialized unless
 *   already a string);
 * - the system prompt, given as a string or as `{ type: "text", text }`
 *   blocks whose `text` may itself be an array of strings;
 * - each tool's name + description and its JSON input schema.
 *
 * Other block types (e.g. images) are ignored. Malformed entries (null
 * blocks, missing text, non-array collections) are skipped rather than
 * throwing.
 *
 * @param messages conversation messages
 * @param system system prompt: a string or an array of blocks
 * @param tools tool definitions
 * @returns approximate number of tokens
 */
export const calculateTokenCount = (
  messages: MessageParam[],
  system: any,
  tools: Tool[]
) => {
  let tokenCount = 0;

  if (Array.isArray(messages)) {
    for (const message of messages) {
      const content = message?.content;
      if (typeof content === "string") {
        tokenCount += countTextTokens(content);
      } else if (Array.isArray(content)) {
        for (const contentPart of content) {
          if (contentPart?.type === "text") {
            tokenCount += countTextTokens(contentPart.text);
          } else if (contentPart?.type === "tool_use") {
            tokenCount += countTextTokens(JSON.stringify(contentPart.input));
          } else if (contentPart?.type === "tool_result") {
            tokenCount += countTextTokens(
              typeof contentPart.content === "string"
                ? contentPart.content
                : JSON.stringify(contentPart.content)
            );
          }
        }
      }
    }
  }

  if (typeof system === "string") {
    tokenCount += countTextTokens(system);
  } else if (Array.isArray(system)) {
    for (const item of system) {
      if (item?.type !== "text") continue;
      if (typeof item.text === "string") {
        tokenCount += countTextTokens(item.text);
      } else if (Array.isArray(item.text)) {
        for (const textPart of item.text) {
          tokenCount += countTextTokens(textPart);
        }
      }
    }
  }

  if (Array.isArray(tools)) {
    for (const tool of tools) {
      if (!tool) continue;
      // The name is sent with every tool, so it is counted even when the
      // tool has no description (previously it was skipped in that case).
      tokenCount += countTextTokens(
        `${tool.name ?? ""}${tool.description ?? ""}`
      );
      if (tool.input_schema) {
        tokenCount += countTextTokens(JSON.stringify(tool.input_schema));
      }
    }
  }

  return tokenCount;
};
/**
 * Reads and parses a JSON config file.
 *
 * @param filePath path of the JSON file
 * @returns the parsed value, or null when the file is missing, unreadable,
 *          or not valid JSON. Invalid JSON is also reported via console.warn,
 *          so a broken override is not silently mistaken for an absent one.
 */
const readConfigFile = async (filePath: string) => {
  let content: string;
  try {
    await access(filePath);
    content = await readFile(filePath, "utf8");
  } catch {
    // Missing or unreadable file: no override from this file.
    return null;
  }
  try {
    return JSON.parse(content);
  } catch (error: any) {
    console.warn(
      `Ignoring invalid JSON in config file ${filePath}: ${error?.message ?? error}`
    );
    return null;
  }
};
/**
 * Looks up a Router override for the request's session.
 *
 * When `req.sessionId` maps to a project folder, the `Router` section is
 * taken from, in order of preference:
 *   1. `<HOME_DIR>/<project>/<sessionId>.json` (session-specific config)
 *   2. `<HOME_DIR>/<project>/config.json`      (project-wide config)
 *
 * @returns the first truthy `Router` found, or undefined to mean "use the
 *          global configuration".
 */
const getProjectSpecificRouter = async (req: any) => {
  const { sessionId } = req;
  if (!sessionId) {
    return undefined;
  }

  const projectName = await searchProjectBySession(sessionId);
  if (!projectName) {
    return undefined;
  }

  // Checked in order; the session file takes precedence over the project file.
  const candidatePaths = [
    join(HOME_DIR, projectName, `${sessionId}.json`),
    join(HOME_DIR, projectName, "config.json"),
  ];
  for (const candidatePath of candidatePaths) {
    const parsedConfig = await readConfigFile(candidatePath);
    if (parsedConfig?.Router) {
      return parsedConfig.Router;
    }
  }

  return undefined;
};
/**
 * Chooses the model route ("provider,model") for a request.
 *
 * Rules, first match wins:
 *   1. An explicit "provider,model" in `req.body.model` is honoured. Provider
 *      and model are matched case-insensitively against `config.Providers`
 *      and rewritten to the configured casing. Unknown pairs pass through
 *      unchanged.
 *   2. The long-context route is used when this request exceeds
 *      `Router.longContextThreshold` (default 60000). It is also used when
 *      the session's previous request exceeded that threshold and this one
 *      is over 20000 tokens.
 *   3. A `<CCR-SUBAGENT-MODEL>…</CCR-SUBAGENT-MODEL>` tag at the start of the
 *      second system block selects that model. The tag is removed from the
 *      prompt, which mutates `req.body.system`.
 *   4. Claude Haiku requests use the background route.
 *   5. Requests carrying a `web_search*` tool use the web-search route. This
 *      deliberately takes precedence over thinking.
 *   6. Requests with `thinking` set use the think route.
 *   7. Otherwise `Router.default` is used.
 *
 * A session/project Router from getProjectSpecificRouter, when present,
 * replaces `config.Router` for every rule above.
 */
const getUseModel = async (
  req: any,
  tokenCount: number,
  config: any,
  lastUsage?: Usage | undefined
) => {
  const projectSpecificRouter = await getProjectSpecificRouter(req);
  const Router = projectSpecificRouter || config.Router;
  const requestedModel: string =
    typeof req.body.model === "string" ? req.body.model : "";

  if (requestedModel.includes(",")) {
    const [provider = "", model = ""] = requestedModel.split(",");
    // Lowercase both sides; previously only the configured names were lowercased.
    const wantedProvider = provider.toLowerCase();
    const wantedModel = model.toLowerCase();
    const finalProvider = config.Providers?.find(
      (p: any) => p.name?.toLowerCase() === wantedProvider
    );
    const finalModel = finalProvider?.models?.find(
      (m: any) => m.toLowerCase() === wantedModel
    );
    if (finalProvider && finalModel) {
      return `${finalProvider.name},${finalModel}`;
    }
    return req.body.model;
  }

  // Long context: either this request is large, or the previous one in the
  // session was over the threshold and this one is still sizeable.
  const longContextThreshold = Router.longContextThreshold || 60000;
  const lastUsageThreshold =
    lastUsage &&
    lastUsage.input_tokens > longContextThreshold &&
    tokenCount > 20000;
  const tokenCountThreshold = tokenCount > longContextThreshold;
  if ((lastUsageThreshold || tokenCountThreshold) && Router.longContext) {
    req.log.info(
      `Using long context model due to token count: ${tokenCount}, threshold: ${longContextThreshold}`
    );
    return Router.longContext;
  }

  const subagentBlock = Array.isArray(req.body?.system)
    ? req.body.system[1]
    : undefined;
  if (
    typeof subagentBlock?.text === "string" &&
    subagentBlock.text.startsWith("<CCR-SUBAGENT-MODEL>")
  ) {
    const subagentMatch = subagentBlock.text.match(
      /<CCR-SUBAGENT-MODEL>(.*?)<\/CCR-SUBAGENT-MODEL>/s
    );
    if (subagentMatch) {
      // Strip the routing tag so it is not forwarded to the model.
      subagentBlock.text = subagentBlock.text.replace(subagentMatch[0], "");
      return subagentMatch[1];
    }
  }

  // Use the background model for any Claude Haiku variant. This honours the
  // session/project Router like every other rule.
  if (
    requestedModel.includes("claude") &&
    requestedModel.includes("haiku") &&
    Router.background
  ) {
    req.log.info(`Using background model for ${req.body.model}`);
    return Router.background;
  }

  // The priority of websearch must be higher than thinking.
  if (
    Array.isArray(req.body.tools) &&
    req.body.tools.some((tool: any) => tool.type?.startsWith("web_search")) &&
    Router.webSearch
  ) {
    return Router.webSearch;
  }

  if (req.body.thinking && Router.think) {
    // `thinking` is an object; serialize it so the log is not "[object Object]".
    req.log.info(
      `Using think model for ${JSON.stringify(req.body.thinking)}`
    );
    return Router.think;
  }

  return Router.default;
};
/**
 * Request middleware that sets `req.body.model` to the selected route.
 *
 * 1. Derives `req.sessionId` from `metadata.user_id`, taking the text after
 *    "_session_".
 * 2. If `config.REWRITE_SYSTEM_PROMPT` is set and the second system block
 *    contains "<env>", replaces everything before the last "<env>" with that
 *    file's contents. A failure here is logged and routing continues with
 *    the original prompt.
 * 3. Counts prompt tokens, then tries the custom router at
 *    `config.CUSTOM_ROUTER_PATH` (if set), falling back to getUseModel.
 *
 * Any error during step 3 falls back to `config.Router.default`.
 */
export const router = async (req: any, _res: any, context: any) => {
  const { config, event } = context;

  // NOTE(review): user_id comes from the client request, and the derived
  // sessionId is later used to build file paths; consider validating its format.
  const userId = req.body.metadata?.user_id;
  if (typeof userId === "string") {
    const parts = userId.split("_session_");
    if (parts.length > 1) {
      req.sessionId = parts[1];
    }
  }

  const lastMessageUsage = sessionUsageCache.get(req.sessionId);
  const { messages, system = [], tools }: MessageCreateParamsBase = req.body;

  if (
    config.REWRITE_SYSTEM_PROMPT &&
    system.length > 1 &&
    system[1]?.text?.includes("<env>")
  ) {
    // Previously this read ran outside any try/catch, so a missing prompt
    // file rejected the whole middleware instead of just skipping the rewrite.
    try {
      const prompt = await readFile(config.REWRITE_SYSTEM_PROMPT, "utf-8");
      system[1].text = `${prompt}<env>${system[1].text.split("<env>").pop()}`;
    } catch (error: any) {
      req.log.error(`failed to rewrite system prompt: ${error?.message ?? error}`);
    }
  }

  try {
    const tokenCount = calculateTokenCount(
      messages as MessageParam[],
      system,
      tools as Tool[]
    );

    let model: any;
    if (config.CUSTOM_ROUTER_PATH) {
      try {
        const customRouter = require(config.CUSTOM_ROUTER_PATH);
        req.tokenCount = tokenCount; // Pass token count to custom router
        model = await customRouter(req, config, {
          event,
        });
      } catch (e: any) {
        req.log.error(`failed to load custom router: ${e.message}`);
      }
    }
    // A falsy result from the custom router means "use the built-in rules".
    if (!model) {
      model = await getUseModel(req, tokenCount, config, lastMessageUsage);
    }
    req.body.model = model;
  } catch (error: any) {
    req.log.error(`Error in router middleware: ${error.message}`);
    req.body.model = config.Router.default;
  }
};
// In-memory LRU cache (at most 1000 entries) mapping a sessionId to the name
// of the project folder that contains its transcript.
// An empty string (rather than null) records "looked up, no project found".
// Negative entries expire after NEGATIVE_CACHE_TTL_MS. Sessions whose
// transcript appears later, and lookups that failed on a transient error, are
// therefore retried. Positive entries stay until evicted by the LRU limit.
const sessionProjectCache = new LRUCache<string, string>({
  max: 1000,
});
const NEGATIVE_CACHE_TTL_MS = 60 * 1000;

/**
 * Finds the project folder (a direct subdirectory of CLAUDE_PROJECTS_DIR)
 * that contains `<sessionId>.jsonl`.
 *
 * @param sessionId session identifier derived from the request
 * @returns the folder name, or null when no folder holds the transcript or
 *          the projects directory could not be read (the error is logged)
 */
export const searchProjectBySession = async (
  sessionId: string
): Promise<string | null> => {
  if (sessionProjectCache.has(sessionId)) {
    // "" is the cached "not found" marker.
    return sessionProjectCache.get(sessionId) || null;
  }

  try {
    const dir = await opendir(CLAUDE_PROJECTS_DIR);
    const folderNames: string[] = [];
    for await (const dirent of dir) {
      if (dirent.isDirectory()) {
        folderNames.push(dirent.name);
      }
    }

    // Check every project folder for the transcript concurrently.
    const results = await Promise.all(
      folderNames.map(async (folderName) => {
        const sessionFilePath = join(
          CLAUDE_PROJECTS_DIR,
          folderName,
          `${sessionId}.jsonl`
        );
        try {
          const fileStat = await stat(sessionFilePath);
          return fileStat.isFile() ? folderName : null;
        } catch {
          // Missing or unreadable: not this project.
          return null;
        }
      })
    );

    // First match in directory-listing order wins.
    const found = results.find((name) => name !== null);
    if (found) {
      sessionProjectCache.set(sessionId, found);
      return found;
    }

    sessionProjectCache.set(sessionId, "", { ttl: NEGATIVE_CACHE_TTL_MS });
    return null;
  } catch (error) {
    console.error("Error searching for project by session:", error);
    // Cache the failure briefly to avoid hammering the filesystem, but not
    // permanently, since the error may be transient.
    sessionProjectCache.set(sessionId, "", { ttl: NEGATIVE_CACHE_TTL_MS });
    return null;
  }
};