use the router to dispatch different models: background, longContext and think

Author: jinhui.li
Date: 2025-06-14 19:48:29 +08:00
Parent: 7a5d712444
Commit: 9a89250d79
12 changed files with 300 additions and 145 deletions
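Note on the new config shape: run() in src/index.ts now reads a Providers array (matching the ModelProvider interface below) and a Router section whose background, think, and longContext values are each a "provider,model" string. A minimal sketch of such a config file as read by initConfig; the provider name, URL, key, and model IDs are hypothetical placeholders, not part of this commit:

{
  "Providers": [
    {
      "name": "openrouter",
      "api_base_url": "https://openrouter.ai/api/v1",
      "api_key": "sk-xxx",
      "models": ["anthropic/claude-3.5-haiku", "google/gemini-1.5-pro"]
    }
  ],
  "Router": {
    "background": "openrouter,anthropic/claude-3.5-haiku",
    "think": "openrouter,anthropic/claude-3.5-sonnet",
    "longContext": "openrouter,google/gemini-1.5-pro"
  }
}

All three Router keys must be set for the router middleware to be installed; otherwise requests fall through to the "default" provider.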

src/cli.ts

@@ -3,7 +3,7 @@ import { run } from "./index";
import { closeService } from "./utils/close";
import { showStatus } from "./utils/status";
import { executeCodeCommand } from "./utils/codeCommand";
-import { isServiceRunning } from "./utils/processCheck";
+import { cleanupPidFile, isServiceRunning } from "./utils/processCheck";
import { version } from "../package.json";
const command = process.argv[2];
@@ -44,6 +44,8 @@ async function waitForService(
}
import { spawn } from "child_process";
+import { PID_FILE, REFERENCE_COUNT_FILE } from "./constants";
+import { existsSync, readFileSync, unlinkSync } from "fs";
async function main() {
switch (command) {
@@ -51,7 +53,26 @@ async function main() {
run();
break;
case "stop":
-      await closeService();
+      try {
+        const pid = parseInt(readFileSync(PID_FILE, "utf-8"));
+        process.kill(pid);
+        cleanupPidFile();
+        if (existsSync(REFERENCE_COUNT_FILE)) {
+          try {
+            unlinkSync(REFERENCE_COUNT_FILE);
+          } catch (e) {
+            // Ignore cleanup errors
+          }
+        }
+        console.log(
+          "claude code router service has been successfully stopped."
+        );
+      } catch (e) {
+        console.log(
+          "Failed to stop the service. It may have already been stopped."
+        );
+        cleanupPidFile();
+      }
break;
case "status":
showStatus();

src/constants.ts

@@ -9,6 +9,8 @@ export const PLUGINS_DIR = `${HOME_DIR}/plugins`;
export const PID_FILE = path.join(HOME_DIR, '.claude-code-router.pid');
+export const REFERENCE_COUNT_FILE = '/tmp/claude-code-reference-count.txt';
export const DEFAULT_CONFIG = {
+  log: false,

src/index.ts

@@ -4,10 +4,16 @@ import { getOpenAICommonOptions, initConfig, initDir } from "./utils";
import { createServer } from "./server";
import { formatRequest } from "./middlewares/formatRequest";
import { rewriteBody } from "./middlewares/rewriteBody";
+import { router } from "./middlewares/router";
import OpenAI from "openai";
import { streamOpenAIResponse } from "./utils/stream";
-import { cleanupPidFile, isServiceRunning, savePid } from "./utils/processCheck";
+import {
+  cleanupPidFile,
+  isServiceRunning,
+  savePid,
+} from "./utils/processCheck";
+import { LRUCache } from "lru-cache";
+import { log } from "./utils/log";
async function initializeClaudeConfig() {
const homeDir = process.env.HOME;
@@ -33,9 +39,14 @@ interface RunOptions {
port?: number;
}
-async function run(options: RunOptions = {}) {
-  const port = options.port || 3456;
+interface ModelProvider {
+  name: string;
+  api_base_url: string;
+  api_key: string;
+  models: string[];
+}
+async function run(options: RunOptions = {}) {
// Check if service is already running
if (isServiceRunning()) {
console.log("✅ Service is already running in the background.");
@@ -44,20 +55,67 @@ async function run(options: RunOptions = {}) {
await initializeClaudeConfig();
await initDir();
-  await initConfig();
+  const config = await initConfig();
+  const Providers = new Map<string, ModelProvider>();
+  const providerCache = new LRUCache<string, OpenAI>({
+    max: 10,
+    ttl: 2 * 60 * 60 * 1000,
+  });
+  function getProviderInstance(providerName: string): OpenAI {
+    const provider: ModelProvider | undefined = Providers.get(providerName);
+    if (provider === undefined) {
+      throw new Error(`Provider ${providerName} not found`);
+    }
+    let openai = providerCache.get(provider.name);
+    if (!openai) {
+      openai = new OpenAI({
+        baseURL: provider.api_base_url,
+        apiKey: provider.api_key,
+        ...getOpenAICommonOptions(),
+      });
+      providerCache.set(provider.name, openai);
+    }
+    return openai;
+  }
+  if (Array.isArray(config.Providers)) {
+    config.Providers.forEach((provider) => {
+      try {
+        Providers.set(provider.name, provider);
+      } catch (error) {
+        console.error("Failed to parse model provider:", error);
+      }
+    });
+  }
+  if (config.OPENAI_API_KEY && config.OPENAI_BASE_URL && config.OPENAI_MODEL) {
+    const defaultProvider = {
+      name: "default",
+      api_base_url: config.OPENAI_BASE_URL,
+      api_key: config.OPENAI_API_KEY,
+      models: [config.OPENAI_MODEL],
+    };
+    Providers.set("default", defaultProvider);
+  } else if (Providers.size > 0) {
+    const defaultProvider = Providers.values().next().value!;
+    Providers.set("default", defaultProvider);
+  }
+  const port = options.port || 3456;
// Save the PID of the background process
savePid(process.pid);
// Handle SIGINT (Ctrl+C) to clean up PID file
-  process.on('SIGINT', () => {
+  process.on("SIGINT", () => {
console.log("Received SIGINT, cleaning up...");
cleanupPidFile();
process.exit(0);
});
// Handle SIGTERM to clean up PID file
-  process.on('SIGTERM', () => {
+  process.on("SIGTERM", () => {
cleanupPidFile();
process.exit(0);
});
@@ -67,21 +125,27 @@ async function run(options: RunOptions = {}) {
? parseInt(process.env.SERVICE_PORT)
: port;
-  const server = createServer(servicePort);
-  server.useMiddleware(formatRequest);
-  server.useMiddleware(rewriteBody);
-  const openai = new OpenAI({
-    apiKey: process.env.OPENAI_API_KEY,
-    baseURL: process.env.OPENAI_BASE_URL,
-    ...getOpenAICommonOptions(),
+  const server = await createServer(servicePort);
+  server.useMiddleware((req, res, next) => {
+    console.log("Middleware triggered for request:", req.body.model);
+    req.config = config;
+    next();
});
+  if (
+    config.Router?.background &&
+    config.Router?.think &&
+    config.Router?.longContext
+  ) {
+    log("Using custom router middleware");
+    server.useMiddleware(router);
+  }
+  server.useMiddleware(rewriteBody);
+  server.useMiddleware(formatRequest);
server.app.post("/v1/messages", async (req, res) => {
try {
-      if (process.env.OPENAI_MODEL) {
-        req.body.model = process.env.OPENAI_MODEL;
-      }
-      const completion: any = await openai.chat.completions.create(req.body);
+      const provider = getProviderInstance(req.provider || "default");
+      const completion: any = await provider.chat.completions.create(req.body);
await streamOpenAIResponse(res, completion, req.body.model, req.body);
} catch (e) {
console.error("Error in OpenAI API call:", e);
@@ -92,3 +156,4 @@ async function run(options: RunOptions = {}) {
}
export { run };
// run();
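Side note on the provider cache above: getProviderInstance() builds one OpenAI client per provider and keeps it in the LRU cache for up to two hours (at most 10 entries), so repeated requests reuse the same client. A small sketch of that behavior, assuming the run() scope above and a hypothetical "openrouter" provider entry:

const a = getProviderInstance("openrouter"); // constructs and caches the client
const b = getProviderInstance("openrouter"); // cache hit within the 2h TTL
console.assert(a === b); // same instance; unknown provider names throw instead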

src/middlewares/formatRequest.ts

@@ -1,5 +1,4 @@
import { Request, Response, NextFunction } from "express";
-import { ContentBlockParam } from "@anthropic-ai/sdk/resources";
import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages";
import OpenAI from "openai";
import { streamOpenAIResponse } from "../utils/stream";
@@ -181,6 +180,7 @@ export const formatRequest = async (
res.setHeader("Cache-Control", "no-cache");
res.setHeader("Connection", "keep-alive");
req.body = data;
+    console.log(JSON.stringify(data.messages, null, 2));
} catch (error) {
console.error("Error in request processing:", error);
const errorCompletion: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> =
@@ -189,7 +189,7 @@ export const formatRequest = async (
yield {
id: `error_${Date.now()}`,
created: Math.floor(Date.now() / 1000),
model: "gpt-3.5-turbo",
model,
object: "chat.completion.chunk",
choices: [
{

src/middlewares/router.ts (new file, 110 lines)

@@ -0,0 +1,110 @@
import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages";
import { Request, Response, NextFunction } from "express";
import { get_encoding } from "tiktoken";
import { log } from "../utils/log";
const enc = get_encoding("cl100k_base");
const getUseModel = (req: Request, tokenCount: number) => {
// if tokenCount is greater than 32K, use the long context model
if (tokenCount > 1000 * 32) {
log("Using long context model due to token count:", tokenCount);
const [provider, model] = req.config.Router!.longContext.split(",");
return {
provider,
model,
};
}
// If the model is claude-3-5-haiku, use the background model
if (req.body.model?.startsWith("claude-3-5-haiku")) {
log("Using background model for ", req.body.model);
const [provider, model] = req.config.Router!.background.split(",");
return {
provider,
model,
};
}
// if thinking is enabled, use the think model
if (req.body.thinking) {
log("Using think model for ", req.body.thinking);
const [provider, model] = req.config.Router!.think.split(",");
return {
provider,
model,
};
}
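  // otherwise honor an explicit "provider,model" pair sent as the model name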
const [provider, model] = req.body.model.split(",");
if (provider && model) {
return {
provider,
model,
};
}
return {
provider: "default",
model: req.config.OPENAI_MODEL,
};
};
export const router = async (
req: Request,
res: Response,
next: NextFunction
) => {
const { messages, system = [], tools }: MessageCreateParamsBase = req.body;
try {
let tokenCount = 0;
if (Array.isArray(messages)) {
messages.forEach((message) => {
if (typeof message.content === "string") {
tokenCount += enc.encode(message.content).length;
} else if (Array.isArray(message.content)) {
message.content.forEach((contentPart) => {
if (contentPart.type === "text") {
tokenCount += enc.encode(contentPart.text).length;
} else if (contentPart.type === "tool_use") {
tokenCount += enc.encode(
JSON.stringify(contentPart.input)
).length;
} else if (contentPart.type === "tool_result") {
tokenCount += enc.encode(contentPart.content || "").length;
}
});
}
});
}
if (typeof system === "string") {
tokenCount += enc.encode(system).length;
} else if (Array.isArray(system)) {
system.forEach((item) => {
if (item.type !== "text") return;
if (typeof item.text === "string") {
tokenCount += enc.encode(item.text).length;
} else if (Array.isArray(item.text)) {
item.text.forEach((textPart) => {
tokenCount += enc.encode(textPart || "").length;
});
}
});
}
if (tools) {
tools.forEach((tool) => {
if (tool.description) {
tokenCount += enc.encode(tool.name + tool.description).length;
}
if (tool.input_schema) {
tokenCount += enc.encode(JSON.stringify(tool.input_schema)).length;
}
});
}
const { provider, model } = getUseModel(req, tokenCount);
req.provider = provider;
req.body.model = model;
} catch (error) {
log("Error in router middleware:", error.message);
req.provider = "default";
req.body.model = req.config.OPENAI_MODEL;
} finally {
next();
}
};
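A minimal sketch (not part of the commit) of what this middleware does to a request; the provider and model strings are hypothetical and mirror the config example near the top of this page:

const req: any = {
  body: {
    model: "claude-3-5-haiku-20241022", // matches the background route prefix
    messages: [{ role: "user", content: "ping" }],
  },
  config: {
    Router: {
      background: "openrouter,anthropic/claude-3.5-haiku",
      think: "openrouter,anthropic/claude-3.5-sonnet",
      longContext: "openrouter,google/gemini-1.5-pro",
    },
  },
};
await router(req, {} as any, () => {
  // after routing: req.provider === "openrouter"
  // and req.body.model === "anthropic/claude-3.5-haiku"
});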

src/server.ts

@@ -6,7 +6,7 @@ interface Server {
start: () => void;
}
-export const createServer = (port: number): Server => {
+export const createServer = async (port: number): Promise<Server> => {
const app = express();
app.use(express.json({ limit: "500mb" }));
return {

src/utils/codeCommand.ts

@@ -1,39 +1,42 @@
-import { spawn } from 'child_process';
-import { isServiceRunning, incrementReferenceCount, decrementReferenceCount } from './processCheck';
-import { closeService } from './close';
+import { spawn } from "child_process";
+import {
+  incrementReferenceCount,
+  decrementReferenceCount,
+} from "./processCheck";
+import { closeService } from "./close";
export async function executeCodeCommand(args: string[] = []) {
// Service check is now handled in cli.ts
-  // Set environment variables
-  const env = {
-    ...process.env,
-    DISABLE_PROMPT_CACHING: '1',
-    ANTHROPIC_AUTH_TOKEN: 'test',
-    ANTHROPIC_BASE_URL: 'http://127.0.0.1:3456',
-    API_TIMEOUT_MS: '600000'
-  };
+  // Set environment variables
+  const env = {
+    ...process.env,
+    DISABLE_PROMPT_CACHING: "1",
+    ANTHROPIC_AUTH_TOKEN: "test",
+    ANTHROPIC_BASE_URL: `http://127.0.0.1:3456`,
+    API_TIMEOUT_MS: "600000",
+  };
-  // Increment reference count when command starts
-  incrementReferenceCount();
+  // Increment reference count when command starts
+  incrementReferenceCount();
-  // Execute claude command
-  const claudeProcess = spawn('claude', args, {
-    env,
-    stdio: 'inherit',
-    shell: true
-  });
+  // Execute claude command
+  const claudeProcess = spawn("claude", args, {
+    env,
+    stdio: "inherit",
+    shell: true,
+  });
-  claudeProcess.on('error', (error) => {
-    console.error('Failed to start claude command:', error.message);
-    console.log('Make sure Claude Code is installed: npm install -g @anthropic-ai/claude-code');
-    decrementReferenceCount();
-    process.exit(1);
-  });
+  claudeProcess.on("error", (error) => {
+    console.error("Failed to start claude command:", error.message);
+    console.log(
+      "Make sure Claude Code is installed: npm install -g @anthropic-ai/claude-code"
+    );
+    decrementReferenceCount();
+    process.exit(1);
+  });
-  claudeProcess.on('close', (code) => {
-    decrementReferenceCount();
-    closeService()
-    process.exit(code || 0);
-  });
+  claudeProcess.on("close", (code) => {
+    decrementReferenceCount();
+    closeService();
+    process.exit(code || 0);
+  });
}

src/utils/index.ts

@@ -13,6 +13,8 @@ export function getOpenAICommonOptions(): ClientOptions {
const options: ClientOptions = {};
if (process.env.PROXY_URL) {
options.httpAgent = new HttpsProxyAgent(process.env.PROXY_URL);
+  } else if (process.env.HTTPS_PROXY) {
+    options.httpAgent = new HttpsProxyAgent(process.env.HTTPS_PROXY);
}
return options;
}
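A quick usage sketch for the HTTPS_PROXY fallback above (both addresses are hypothetical): PROXY_URL still wins when both variables are set, because its branch is checked first:

process.env.PROXY_URL = "http://127.0.0.1:7890"; // hypothetical proxy
process.env.HTTPS_PROXY = "http://10.0.0.1:8080"; // hypothetical proxy
const opts = getOpenAICommonOptions(); // opts.httpAgent routes via 127.0.0.1:7890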
@@ -78,6 +80,7 @@ export const writeConfigFile = async (config: any) => {
export const initConfig = async () => {
const config = await readConfigFile();
Object.assign(process.env, config);
+  return config;
};
export const createClient = (options: ClientOptions) => {

src/utils/processCheck.ts

@@ -1,7 +1,5 @@
import { existsSync, readFileSync, writeFileSync } from 'fs';
-import { PID_FILE } from '../constants';
-const REFERENCE_COUNT_FILE = '/tmp/claude-code-reference-count.txt';
+import { PID_FILE, REFERENCE_COUNT_FILE } from '../constants';
export function incrementReferenceCount() {
let count = 0;