// Mirror of https://github.com/musistudio/claude-code-router.git
// Synced 2026-01-30 06:12:06 +00:00 — 318 lines, 9.3 KiB, TypeScript
import { get_encoding } from "tiktoken";
|
||
import { sessionUsageCache, Usage } from "./cache";
|
||
import { readFile, access } from "fs/promises";
|
||
import { opendir, stat } from "fs/promises";
|
||
import { join } from "path";
|
||
import { CLAUDE_PROJECTS_DIR, HOME_DIR } from "@CCR/shared";
|
||
import { LRUCache } from "lru-cache";
|
||
|
||
// Types from @anthropic-ai/sdk
|
||
/**
 * Minimal local shape of a tool definition sent with a request.
 * Only the fields read by the token counter in this file are declared.
 */
interface Tool {
  /** Tool identifier. */
  name: string;
  /** Optional human-readable description of the tool. */
  description?: string;
  /** Schema object for the tool input; it is JSON-serialised when tokens are counted. */
  input_schema: object;
}
|
||
|
||
/**
 * One structured content block of a message.
 * `type` discriminates the block (this file handles "text", "tool_use" and
 * "tool_result"); every other field is left open.
 */
interface ContentBlockParam {
  /** Block discriminator. */
  type: string;
  /** Any additional, type-dependent fields (e.g. `text`, `input`, `content`). */
  [key: string]: any;
}
|
||
|
||
/**
 * A single conversation message as found in `req.body.messages`.
 */
interface MessageParam {
  /** Author of the message. */
  role: string;
  /** Either plain text or a list of structured content blocks. */
  content: string | ContentBlockParam[];
}
|
||
|
||
/**
 * The subset of the request body that the router inspects.
 * All other body fields are accepted through the index signature.
 */
interface MessageCreateParamsBase {
  /** Conversation messages. */
  messages?: MessageParam[];
  /** System prompt: a plain string or a list of blocks. */
  system?: string | any[];
  /** Tool definitions offered to the model. */
  tools?: Tool[];
  /** Any other request field. */
  [key: string]: any;
}
|
||
|
||
// Shared tiktoken encoder for the "cl100k_base" encoding, created once at
// module load and used by calculateTokenCount for every count.
// NOTE(review): presumably only an approximation of the routed models' own
// tokenizers — confirm that an estimate is all the callers need.
const enc = get_encoding("cl100k_base");
|
||
|
||
/**
 * Estimates the number of prompt tokens of a request.
 *
 * Counted parts:
 *  - messages: string content, and for block content the `text` of "text"
 *    blocks, the JSON of `input` for "tool_use" blocks and the (string or
 *    JSON-serialised) `content` of "tool_result" blocks;
 *  - system: a string, or the `text` of each "text" block (string or array
 *    of strings);
 *  - tools: name + description, and the JSON of `input_schema`.
 *
 * Values that are missing or not strings contribute 0 instead of making the
 * encoder throw (e.g. a "tool_result" block without `content`, for which
 * `JSON.stringify` yields `undefined`).
 *
 * @param messages - conversation messages; ignored unless it is an array
 * @param system - system prompt (string or array of blocks); other values are ignored
 * @param tools - tool definitions; ignored unless it is an array
 * @returns the summed token count
 */
export const calculateTokenCount = (
  messages: MessageParam[],
  system: any,
  tools: Tool[]
): number => {
  // Only real strings reach the encoder; everything else counts as 0.
  const countText = (text: unknown): number =>
    typeof text === "string" ? enc.encode(text).length : 0;
  // JSON.stringify returns undefined for undefined input, which countText maps to 0.
  const countJson = (value: unknown): number =>
    countText(JSON.stringify(value));

  let tokenCount = 0;

  if (Array.isArray(messages)) {
    for (const message of messages) {
      const content = message?.content;
      if (typeof content === "string") {
        tokenCount += countText(content);
        continue;
      }
      if (!Array.isArray(content)) continue;

      for (const contentPart of content) {
        switch (contentPart?.type) {
          case "text":
            tokenCount += countText(contentPart.text);
            break;
          case "tool_use":
            tokenCount += countJson(contentPart.input);
            break;
          case "tool_result":
            tokenCount +=
              typeof contentPart.content === "string"
                ? countText(contentPart.content)
                : countJson(contentPart.content);
            break;
          default:
            // Other block types are not counted.
            break;
        }
      }
    }
  }

  if (typeof system === "string") {
    tokenCount += countText(system);
  } else if (Array.isArray(system)) {
    for (const item of system) {
      if (item?.type !== "text") continue;
      if (typeof item.text === "string") {
        tokenCount += countText(item.text);
      } else if (Array.isArray(item.text)) {
        for (const textPart of item.text) {
          tokenCount += countText(textPart);
        }
      }
    }
  }

  if (Array.isArray(tools)) {
    for (const tool of tools) {
      if (!tool) continue;
      // The name is counted even when the tool has no description.
      const name = typeof tool.name === "string" ? tool.name : "";
      const description =
        typeof tool.description === "string" ? tool.description : "";
      tokenCount += countText(name + description);
      if (tool.input_schema) {
        tokenCount += countJson(tool.input_schema);
      }
    }
  }

  return tokenCount;
};
|
||
|
||
/**
 * Reads and parses a JSON config file.
 *
 * The file is read directly, without a preceding `access()` probe: `readFile`
 * already rejects when the file is missing, and a separate existence check
 * only adds a syscall and a check-then-use race.
 *
 * @param filePath - path of the JSON file
 * @returns the parsed JSON value, or `null` when the file does not exist,
 *          cannot be read, or does not contain valid JSON
 */
const readConfigFile = async (filePath: string) => {
  try {
    const content = await readFile(filePath, "utf8");
    return JSON.parse(content);
  } catch (error: unknown) {
    // A missing file is the normal "no override configured" case and stays
    // silent. Anything else (bad JSON, permissions, ...) is worth a warning.
    const isMissing =
      typeof error === "object" &&
      error !== null &&
      (error as { code?: unknown }).code === "ENOENT";
    if (!isMissing) {
      const reason = error instanceof Error ? error.message : String(error);
      console.warn(`Ignoring config file ${filePath}: ${reason}`);
    }
    return null;
  }
};
|
||
|
||
/**
 * Looks for a Router override that applies to the request's session.
 *
 * The session is first mapped to a project folder name via
 * `searchProjectBySession`. Two files under `HOME_DIR/<project>/` are then
 * consulted, in this order:
 *  1. `<sessionId>.json` — session-level config
 *  2. `config.json`      — project-level config
 * The `Router` entry of the first file that has one is returned.
 *
 * @param req - request object; only `req.sessionId` is read
 * @returns the overriding Router, or `undefined` when the original
 *          configuration should be used
 */
const getProjectSpecificRouter = async (req: any) => {
  if (!req.sessionId) {
    return undefined;
  }

  const project = await searchProjectBySession(req.sessionId);
  if (!project) {
    return undefined;
  }

  // The session-level file takes precedence over the project-level one.
  const sessionConfigPath = join(HOME_DIR, project, `${req.sessionId}.json`);
  const sessionConfig = await readConfigFile(sessionConfigPath);
  if (sessionConfig?.Router) {
    return sessionConfig.Router;
  }

  const projectConfigPath = join(HOME_DIR, project, "config.json");
  const projectConfig = await readConfigFile(projectConfigPath);
  if (projectConfig?.Router) {
    return projectConfig.Router;
  }

  // No override found: the caller falls back to the original configuration.
  return undefined;
};
|
||
|
||
/**
 * Chooses the model a request is routed to.
 *
 * Rules are evaluated in order; the first match wins:
 *  1. `req.body.model` of the form "provider,model": provider and model are
 *     matched case-insensitively against `config.Providers` and returned with
 *     the configured spelling; an unknown pair is returned unchanged.
 *  2. `Router.longContext` when the token count exceeds the threshold
 *     (`Router.longContextThreshold`, default 60000), or when the previous
 *     message's input tokens exceeded it and the current count is above 20000.
 *  3. The model named in a `<CCR-SUBAGENT-MODEL>` tag at the start of the
 *     second system block; the tag is removed from that block's text.
 *  4. `config.Router.background` for model names containing both "claude"
 *     and "haiku".
 *  5. `Router.webSearch` when any tool's `type` starts with "web_search".
 *  6. `Router.think` when `req.body.thinking` is set.
 *  7. `Router.default`.
 *
 * `Router` is the project/session override when one exists, otherwise
 * `config.Router`.
 *
 * @param req - request; `req.body` is read (and `system[1].text` may be
 *              rewritten by rule 3), `req.log` is used for logging
 * @param tokenCount - estimated token count of the request
 * @param config - global configuration (`Providers`, `Router`)
 * @param lastUsage - usage of the session's previous message, if known
 * @returns the routing target taken from the matching rule
 */
const getUseModel = async (
  req: any,
  tokenCount: number,
  config: any,
  lastUsage?: Usage | undefined
) => {
  const projectSpecificRouter = await getProjectSpecificRouter(req);
  const Router = projectSpecificRouter || config.Router;

  const requestedModel: unknown = req.body.model;

  // Rule 1: an explicit "provider,model" request bypasses all routing rules.
  if (typeof requestedModel === "string" && requestedModel.includes(",")) {
    const [provider, model] = requestedModel.split(",");
    // Both sides are lower-cased so that the match is case-insensitive.
    const wantedProvider = provider.toLowerCase();
    const wantedModel = model.toLowerCase();
    const finalProvider = config.Providers?.find(
      (p: any) =>
        typeof p?.name === "string" && p.name.toLowerCase() === wantedProvider
    );
    const finalModel = finalProvider?.models?.find(
      (m: any) => typeof m === "string" && m.toLowerCase() === wantedModel
    );
    if (finalProvider && finalModel) {
      return `${finalProvider.name},${finalModel}`;
    }
    return requestedModel;
  }

  // Rule 2: if tokenCount is greater than the configured threshold, use the
  // long context model.
  const longContextThreshold = Router.longContextThreshold || 60000;
  const lastUsageThreshold =
    lastUsage &&
    lastUsage.input_tokens > longContextThreshold &&
    tokenCount > 20000;
  const tokenCountThreshold = tokenCount > longContextThreshold;
  if ((lastUsageThreshold || tokenCountThreshold) && Router.longContext) {
    req.log.info(
      `Using long context model due to token count: ${tokenCount}, threshold: ${longContextThreshold}`
    );
    return Router.longContext;
  }

  // Rule 3: a sub-agent may name its model in a tag at the start of the
  // second system block.
  if (
    req.body?.system?.length > 1 &&
    req.body?.system[1]?.text?.startsWith("<CCR-SUBAGENT-MODEL>")
  ) {
    const tagMatch = req.body.system[1].text.match(
      /<CCR-SUBAGENT-MODEL>(.*?)<\/CCR-SUBAGENT-MODEL>/s
    );
    if (tagMatch) {
      // tagMatch[0] is the whole tag including its content; remove it so the
      // marker does not reach the model.
      req.body.system[1].text = req.body.system[1].text.replace(
        tagMatch[0],
        ""
      );
      return tagMatch[1];
    }
  }

  // Rule 4: use the background model for any Claude Haiku variant.
  // NOTE(review): this reads the global config.Router, not the
  // project-specific Router used by the other rules — confirm it is intended.
  if (
    typeof requestedModel === "string" &&
    requestedModel.includes("claude") &&
    requestedModel.includes("haiku") &&
    config.Router.background
  ) {
    req.log.info(`Using background model for ${requestedModel}`);
    return config.Router.background;
  }

  // Rule 5: the priority of websearch must be higher than thinking.
  if (
    Array.isArray(req.body.tools) &&
    req.body.tools.some((tool: any) => tool.type?.startsWith("web_search")) &&
    Router.webSearch
  ) {
    return Router.webSearch;
  }

  // Rule 6: if thinking is requested, use the think model.
  if (req.body.thinking && Router.think) {
    // `thinking` is an object; serialise it so the log does not read
    // "[object Object]".
    req.log.info(
      `Using think model for ${JSON.stringify(req.body.thinking)}`
    );
    return Router.think;
  }

  // Rule 7: nothing matched.
  return Router.default;
};
|
||
|
||
/**
 * Router middleware: decides which model handles the request and writes the
 * decision to `req.body.model`.
 *
 * Steps:
 *  1. Derive `req.sessionId` from `metadata.user_id` (the part after
 *     "_session_").
 *  2. Optionally replace the text before `<env>` in the second system block
 *     with the contents of the file at `config.REWRITE_SYSTEM_PROMPT`.
 *  3. Count the request's tokens, ask the custom router
 *     (`config.CUSTOM_ROUTER_PATH`) if one is configured, and otherwise (or
 *     when it returns nothing) fall back to the built-in rules.
 *
 * Every step is best-effort. A failing prompt rewrite or custom router is
 * logged and skipped, and any other error selects `config.Router.default`.
 *
 * @param req - request; `req.body`, `req.sessionId`, `req.tokenCount` and
 *              `req.log` are used
 * @param _res - response object (unused)
 * @param context - provides `config` and `event`
 */
export const router = async (
  req: any,
  _res: any,
  context: any
): Promise<void> => {
  const { config, event } = context;

  // Thrown values are not guaranteed to be Error instances.
  const messageOf = (err: unknown): string => {
    const message = (err as { message?: unknown } | null | undefined)?.message;
    return typeof message === "string" ? message : String(err);
  };

  // Parse sessionId from metadata.user_id
  const userId: unknown = req.body.metadata?.user_id;
  if (typeof userId === "string") {
    const parts = userId.split("_session_");
    if (parts.length > 1) {
      req.sessionId = parts[1];
    }
  }

  const lastMessageUsage = sessionUsageCache.get(req.sessionId);
  const { messages, system = [], tools }: MessageCreateParamsBase = req.body;

  if (
    config.REWRITE_SYSTEM_PROMPT &&
    system.length > 1 &&
    system[1]?.text?.includes("<env>")
  ) {
    try {
      const prompt = await readFile(config.REWRITE_SYSTEM_PROMPT, "utf-8");
      // Keep only what follows the last "<env>" marker and prepend the
      // replacement prompt.
      system[1].text = `${prompt}<env>${system[1].text.split("<env>").pop()}`;
    } catch (error: unknown) {
      // An unreadable prompt file must not fail the request; the original
      // system prompt is kept.
      req.log.error(`failed to rewrite system prompt: ${messageOf(error)}`);
    }
  }

  try {
    const tokenCount = calculateTokenCount(
      messages as MessageParam[],
      system,
      tools as Tool[]
    );

    let model;
    if (config.CUSTOM_ROUTER_PATH) {
      try {
        const customRouter = require(config.CUSTOM_ROUTER_PATH);
        req.tokenCount = tokenCount; // Pass token count to custom router
        model = await customRouter(req, config, {
          event,
        });
      } catch (e: unknown) {
        req.log.error(`failed to load custom router: ${messageOf(e)}`);
      }
    }
    // The built-in rules apply when no custom router is configured, when it
    // failed, or when it returned nothing.
    if (!model) {
      model = await getUseModel(req, tokenCount, config, lastMessageUsage);
    }
    req.body.model = model;
  } catch (error: unknown) {
    req.log.error(`Error in router middleware: ${messageOf(error)}`);
    req.body.model = config.Router!.default;
  }
};
|
||
|
||
// In-memory cache mapping a sessionId to the name of its project folder.
// An empty string marks a session that was already looked up without finding
// a project (the value type is `string`, so `null` cannot be stored).
// LRU eviction keeps the cache at no more than 1000 entries.
const sessionProjectCache = new LRUCache<string, string>({
  max: 1000,
});
|
||
|
||
/**
 * Finds the project folder under `CLAUDE_PROJECTS_DIR` that contains the
 * session file `<sessionId>.jsonl`.
 *
 * Results are cached in `sessionProjectCache`, including misses and failed
 * lookups (stored as an empty string), so each session triggers at most one
 * directory scan while its entry stays in the cache.
 *
 * @param sessionId - session identifier; it is used as a file name, so ids
 *                    that are empty or contain a path separator are rejected
 * @returns the name of the first matching project folder (in directory
 *          listing order), or `null` when none matches or the lookup fails
 */
export const searchProjectBySession = async (
  sessionId: string
): Promise<string | null> => {
  // Cache value recording "already searched, nothing found".
  const NOT_FOUND = "";

  // The id originates from the request and is joined into file paths below.
  // Refuse anything that could point outside the project folder.
  if (!sessionId || /[\\/]/.test(sessionId)) {
    return null;
  }

  // Check the cache first.
  const cached = sessionProjectCache.get(sessionId);
  if (cached !== undefined) {
    return cached === NOT_FOUND ? null : cached;
  }

  try {
    // Collect the names of all project folders.
    const dir = await opendir(CLAUDE_PROJECTS_DIR);
    const folderNames: string[] = [];
    for await (const dirent of dir) {
      if (dirent.isDirectory()) {
        folderNames.push(dirent.name);
      }
    }

    // Probe all folders concurrently for the session file.
    const matches = await Promise.all(
      folderNames.map(async (folderName) => {
        const sessionFilePath = join(
          CLAUDE_PROJECTS_DIR,
          folderName,
          `${sessionId}.jsonl`
        );
        try {
          const fileStat = await stat(sessionFilePath);
          return fileStat.isFile() ? folderName : null;
        } catch {
          // No such file in this folder.
          return null;
        }
      })
    );

    // The first folder that holds the session file wins.
    const project = matches.find((name) => Boolean(name)) ?? null;

    // Cache hits and misses alike.
    sessionProjectCache.set(sessionId, project ?? NOT_FOUND);
    return project;
  } catch (error) {
    console.error("Error searching for project by session:", error);
    // Failures are cached as misses too, to avoid repeating the same error.
    sessionProjectCache.set(sessionId, NOT_FOUND);
    return null;
  }
};
|