fix preset error

This commit is contained in:
musistudio
2025-12-30 18:23:44 +08:00
parent 7e11fca0d5
commit 7400941ae8
9 changed files with 152 additions and 66 deletions

View File

@@ -39,11 +39,14 @@
"google-auth-library": "^10.1.0", "google-auth-library": "^10.1.0",
"json5": "^2.2.3", "json5": "^2.2.3",
"jsonrepair": "^3.13.0", "jsonrepair": "^3.13.0",
"lru-cache": "^11.2.2",
"openai": "^5.6.0", "openai": "^5.6.0",
"tiktoken": "^1.0.21",
"undici": "^7.10.0", "undici": "^7.10.0",
"uuid": "^11.1.0" "uuid": "^11.1.0"
}, },
"devDependencies": { "devDependencies": {
"@CCR/shared": "workspace:*",
"@types/node": "^24.0.15", "@types/node": "^24.0.15",
"esbuild": "^0.25.1", "esbuild": "^0.25.1",
"tsx": "^4.20.3", "tsx": "^4.20.3",

View File

@@ -10,7 +10,7 @@ const baseConfig: esbuild.BuildOptions = {
platform: "node", platform: "node",
target: "node18", target: "node18",
plugins: [], plugins: [],
external: ["fastify", "dotenv", "@fastify/cors", "undici"], external: ["fastify", "dotenv", "@fastify/cors", "undici", "tiktoken", "@CCR/shared", "lru-cache"],
}; };
const cjsConfig: esbuild.BuildOptions = { const cjsConfig: esbuild.BuildOptions = {

View File

@@ -30,6 +30,8 @@ import { errorHandler } from "./api/middleware";
import { registerApiRoutes } from "./api/routes"; import { registerApiRoutes } from "./api/routes";
import { ProviderService } from "./services/provider"; import { ProviderService } from "./services/provider";
import { TransformerService } from "./services/transformer"; import { TransformerService } from "./services/transformer";
import { router, calculateTokenCount, searchProjectBySession } from "./utils/router";
import { sessionUsageCache } from "./utils/cache";
// Extend FastifyRequest to include custom properties // Extend FastifyRequest to include custom properties
declare module "fastify" { declare module "fastify" {
@@ -125,6 +127,15 @@ class Server {
fastify.decorate('configService', this.configService); fastify.decorate('configService', this.configService);
fastify.decorate('transformerService', this.transformerService); fastify.decorate('transformerService', this.transformerService);
fastify.decorate('providerService', this.providerService); fastify.decorate('providerService', this.providerService);
// Add router hook for main namespace
fastify.addHook('preHandler', async (req: any, reply: any) => {
const url = new URL(`http://127.0.0.1${req.url}`);
if (url.pathname.endsWith("/v1/messages")) {
await router(req, reply, {
configService: this.configService,
});
}
});
await registerApiRoutes(fastify); await registerApiRoutes(fastify);
}); });
return return
@@ -133,6 +144,7 @@ class Server {
const configService = new ConfigService({ const configService = new ConfigService({
initialConfig: { initialConfig: {
providers: options.Providers, providers: options.Providers,
Router: options.Router,
} }
}); });
const transformerService = new TransformerService( const transformerService = new TransformerService(
@@ -145,15 +157,19 @@ class Server {
transformerService, transformerService,
this.app.log this.app.log
); );
// await this.app.register((fastify) => {
// fastify.decorate('configService', configService);
// fastify.decorate('transformerService', transformerService);
// fastify.decorate('providerService', providerService);
// }, { prefix: name });
await this.app.register(async (fastify) => { await this.app.register(async (fastify) => {
fastify.decorate('configService', configService); fastify.decorate('configService', configService);
fastify.decorate('transformerService', transformerService); fastify.decorate('transformerService', transformerService);
fastify.decorate('providerService', providerService); fastify.decorate('providerService', providerService);
// Add router hook for namespace
fastify.addHook('preHandler', async (req: any, reply: any) => {
const url = new URL(`http://127.0.0.1${req.url}`);
if (url.pathname.endsWith("/v1/messages")) {
await router(req, reply, {
configService,
});
}
});
await registerApiRoutes(fastify); await registerApiRoutes(fastify);
}, { prefix: name }); }, { prefix: name });
} }
@@ -174,6 +190,8 @@ class Server {
done(); done();
}); });
await this.registerNamespace('/')
this.app.addHook( this.app.addHook(
"preHandler", "preHandler",
async (req: FastifyRequest, reply: FastifyReply) => { async (req: FastifyRequest, reply: FastifyReply) => {
@@ -198,7 +216,6 @@ class Server {
} }
); );
await this.registerNamespace('/')
const address = await this.app.listen({ const address = await this.app.listen({
port: parseInt(this.configService.get("PORT") || "3000", 10), port: parseInt(this.configService.get("PORT") || "3000", 10),
@@ -224,3 +241,10 @@ class Server {
// Export for external use // Export for external use
export default Server; export default Server;
export { sessionUsageCache };
export { router };
export { calculateTokenCount };
export { searchProjectBySession };
export { ConfigService } from "./services/config";
export { ProviderService } from "./services/provider";
export { TransformerService } from "./services/transformer";

View File

@@ -1,10 +1,11 @@
import { get_encoding } from "tiktoken"; import { get_encoding } from "tiktoken";
import { sessionUsageCache, Usage } from "./cache"; import { sessionUsageCache, Usage } from "./cache";
import { readFile, access } from "fs/promises"; import { readFile } from "fs/promises";
import { opendir, stat } from "fs/promises"; import { opendir, stat } from "fs/promises";
import { join } from "path"; import { join } from "path";
import { CLAUDE_PROJECTS_DIR, HOME_DIR } from "@CCR/shared"; import { CLAUDE_PROJECTS_DIR, HOME_DIR } from "@CCR/shared";
import { LRUCache } from "lru-cache"; import { LRUCache } from "lru-cache";
import { ConfigService } from "../services/config";
// Types from @anthropic-ai/sdk // Types from @anthropic-ai/sdk
interface Tool { interface Tool {
@@ -86,17 +87,10 @@ export const calculateTokenCount = (
return tokenCount; return tokenCount;
}; };
const readConfigFile = async (filePath: string) => { const getProjectSpecificRouter = async (
try { req: any,
await access(filePath); configService: ConfigService
const content = await readFile(filePath, "utf8"); ) => {
return JSON.parse(content);
} catch (error) {
return null; // 文件不存在或读取失败时返回null
}
};
const getProjectSpecificRouter = async (req: any) => {
// 检查是否有项目特定的配置 // 检查是否有项目特定的配置
if (req.sessionId) { if (req.sessionId) {
const project = await searchProjectBySession(req.sessionId); const project = await searchProjectBySession(req.sessionId);
@@ -109,14 +103,18 @@ const getProjectSpecificRouter = async (req: any) => {
); );
// 首先尝试读取sessionConfig文件 // 首先尝试读取sessionConfig文件
const sessionConfig = await readConfigFile(sessionConfigPath); try {
if (sessionConfig && sessionConfig.Router) { const sessionConfig = JSON.parse(await readFile(sessionConfigPath, "utf8"));
return sessionConfig.Router; if (sessionConfig && sessionConfig.Router) {
} return sessionConfig.Router;
const projectConfig = await readConfigFile(projectConfigPath); }
if (projectConfig && projectConfig.Router) { } catch {}
return projectConfig.Router; try {
} const projectConfig = JSON.parse(await readFile(projectConfigPath, "utf8"));
if (projectConfig && projectConfig.Router) {
return projectConfig.Router;
}
} catch {}
} }
} }
return undefined; // 返回undefined表示使用原始配置 return undefined; // 返回undefined表示使用原始配置
@@ -125,15 +123,16 @@ const getProjectSpecificRouter = async (req: any) => {
const getUseModel = async ( const getUseModel = async (
req: any, req: any,
tokenCount: number, tokenCount: number,
config: any, configService: ConfigService,
lastUsage?: Usage | undefined lastUsage?: Usage | undefined
) => { ) => {
const projectSpecificRouter = await getProjectSpecificRouter(req); const projectSpecificRouter = await getProjectSpecificRouter(req, configService);
const Router = projectSpecificRouter || config.Router; const providers = configService.get<any[]>("providers") || [];
const Router = projectSpecificRouter || configService.get("Router");
if (req.body.model.includes(",")) { if (req.body.model.includes(",")) {
const [provider, model] = req.body.model.split(","); const [provider, model] = req.body.model.split(",");
const finalProvider = config.Providers.find( const finalProvider = providers.find(
(p: any) => p.name.toLowerCase() === provider (p: any) => p.name.toLowerCase() === provider
); );
const finalModel = finalProvider?.models?.find( const finalModel = finalProvider?.models?.find(
@@ -146,13 +145,13 @@ const getUseModel = async (
} }
// if tokenCount is greater than the configured threshold, use the long context model // if tokenCount is greater than the configured threshold, use the long context model
const longContextThreshold = Router.longContextThreshold || 60000; const longContextThreshold = Router?.longContextThreshold || 60000;
const lastUsageThreshold = const lastUsageThreshold =
lastUsage && lastUsage &&
lastUsage.input_tokens > longContextThreshold && lastUsage.input_tokens > longContextThreshold &&
tokenCount > 20000; tokenCount > 20000;
const tokenCountThreshold = tokenCount > longContextThreshold; const tokenCountThreshold = tokenCount > longContextThreshold;
if ((lastUsageThreshold || tokenCountThreshold) && Router.longContext) { if ((lastUsageThreshold || tokenCountThreshold) && Router?.longContext) {
req.log.info( req.log.info(
`Using long context model due to token count: ${tokenCount}, threshold: ${longContextThreshold}` `Using long context model due to token count: ${tokenCount}, threshold: ${longContextThreshold}`
); );
@@ -174,32 +173,38 @@ const getUseModel = async (
} }
} }
// Use the background model for any Claude Haiku variant // Use the background model for any Claude Haiku variant
const globalRouter = configService.get("Router");
if ( if (
req.body.model?.includes("claude") && req.body.model?.includes("claude") &&
req.body.model?.includes("haiku") && req.body.model?.includes("haiku") &&
config.Router.background globalRouter?.background
) { ) {
req.log.info(`Using background model for ${req.body.model}`); req.log.info(`Using background model for ${req.body.model}`);
return config.Router.background; return globalRouter.background;
} }
// The priority of websearch must be higher than thinking. // The priority of websearch must be higher than thinking.
if ( if (
Array.isArray(req.body.tools) && Array.isArray(req.body.tools) &&
req.body.tools.some((tool: any) => tool.type?.startsWith("web_search")) && req.body.tools.some((tool: any) => tool.type?.startsWith("web_search")) &&
Router.webSearch Router?.webSearch
) { ) {
return Router.webSearch; return Router.webSearch;
} }
// if exits thinking, use the think model // if exits thinking, use the think model
if (req.body.thinking && Router.think) { if (req.body.thinking && Router?.think) {
req.log.info(`Using think model for ${req.body.thinking}`); req.log.info(`Using think model for ${req.body.thinking}`);
return Router.think; return Router.think;
} }
return Router!.default; return Router?.default;
}; };
export const router = async (req: any, _res: any, context: any) => { export interface RouterContext {
const { config, event } = context; configService: ConfigService;
event?: any;
}
export const router = async (req: any, _res: any, context: RouterContext) => {
const { configService, event } = context;
// Parse sessionId from metadata.user_id // Parse sessionId from metadata.user_id
if (req.body.metadata?.user_id) { if (req.body.metadata?.user_id) {
const parts = req.body.metadata.user_id.split("_session_"); const parts = req.body.metadata.user_id.split("_session_");
@@ -209,12 +214,13 @@ export const router = async (req: any, _res: any, context: any) => {
} }
const lastMessageUsage = sessionUsageCache.get(req.sessionId); const lastMessageUsage = sessionUsageCache.get(req.sessionId);
const { messages, system = [], tools }: MessageCreateParamsBase = req.body; const { messages, system = [], tools }: MessageCreateParamsBase = req.body;
const rewritePrompt = configService.get("REWRITE_SYSTEM_PROMPT");
if ( if (
config.REWRITE_SYSTEM_PROMPT && rewritePrompt &&
system.length > 1 && system.length > 1 &&
system[1]?.text?.includes("<env>") system[1]?.text?.includes("<env>")
) { ) {
const prompt = await readFile(config.REWRITE_SYSTEM_PROMPT, "utf-8"); const prompt = await readFile(rewritePrompt, "utf-8");
system[1].text = `${prompt}<env>${system[1].text.split("<env>").pop()}`; system[1].text = `${prompt}<env>${system[1].text.split("<env>").pop()}`;
} }
@@ -226,11 +232,12 @@ export const router = async (req: any, _res: any, context: any) => {
); );
let model; let model;
if (config.CUSTOM_ROUTER_PATH) { const customRouterPath = configService.get("CUSTOM_ROUTER_PATH");
if (customRouterPath) {
try { try {
const customRouter = require(config.CUSTOM_ROUTER_PATH); const customRouter = require(customRouterPath);
req.tokenCount = tokenCount; // Pass token count to custom router req.tokenCount = tokenCount; // Pass token count to custom router
model = await customRouter(req, config, { model = await customRouter(req, configService.getAll(), {
event, event,
}); });
} catch (e: any) { } catch (e: any) {
@@ -238,12 +245,13 @@ export const router = async (req: any, _res: any, context: any) => {
} }
} }
if (!model) { if (!model) {
model = await getUseModel(req, tokenCount, config, lastMessageUsage); model = await getUseModel(req, tokenCount, configService, lastMessageUsage);
} }
req.body.model = model; req.body.model = model;
} catch (error: any) { } catch (error: any) {
req.log.error(`Error in router middleware: ${error.message}`); req.log.error(`Error in router middleware: ${error.message}`);
req.body.model = config.Router!.default; const Router = configService.get("Router");
req.body.model = Router?.default;
} }
return; return;
}; };

View File

@@ -4,14 +4,13 @@ import { homedir } from "os";
import { join } from "path"; import { join } from "path";
import { initConfig, initDir } from "./utils"; import { initConfig, initDir } from "./utils";
import { createServer } from "./server"; import { createServer } from "./server";
import { router } from "./utils/router";
import { apiKeyAuth } from "./middleware/auth"; import { apiKeyAuth } from "./middleware/auth";
import {CONFIG_FILE, HOME_DIR, listPresets} from "@CCR/shared"; import { CONFIG_FILE, HOME_DIR, listPresets } from "@CCR/shared";
import { createStream } from 'rotating-file-stream'; import { createStream } from 'rotating-file-stream';
import { sessionUsageCache } from "./utils/cache"; import { sessionUsageCache } from "@musistudio/llms";
import {SSEParserTransform} from "./utils/SSEParser.transform"; import { SSEParserTransform } from "./utils/SSEParser.transform";
import {SSESerializerTransform} from "./utils/SSESerializer.transform"; import { SSESerializerTransform } from "./utils/SSESerializer.transform";
import {rewriteStream} from "./utils/rewriteStream"; import { rewriteStream } from "./utils/rewriteStream";
import JSON5 from "json5"; import JSON5 from "json5";
import { IAgent, ITool } from "./agents/type"; import { IAgent, ITool } from "./agents/type";
import agentsManager from "./agents"; import agentsManager from "./agents";
@@ -138,10 +137,9 @@ async function getServer(options: RunOptions = {}) {
logger: loggerConfig, logger: loggerConfig,
}); });
presets.forEach(preset => { await Promise.allSettled(
console.log(preset.name, preset.config); presets.map(async preset => await serverInstance.registerNamespace(preset.name, preset.config))
serverInstance.registerNamespace(preset.name, preset.config); )
})
// Add async preHandler hook for authentication // Add async preHandler hook for authentication
serverInstance.addHook("preHandler", async (req: any, reply: any) => { serverInstance.addHook("preHandler", async (req: any, reply: any) => {
@@ -155,7 +153,15 @@ async function getServer(options: RunOptions = {}) {
}); });
}); });
serverInstance.addHook("preHandler", async (req: any, reply: any) => { serverInstance.addHook("preHandler", async (req: any, reply: any) => {
if (req.url.startsWith("/v1/messages") && !req.url.startsWith("/v1/messages/count_tokens")) { const url = new URL(`http://127.0.0.1${req.url}`);
req.pathname = url.pathname;
if (req.pathname.endsWith("/v1/messages") && req.pathname !== "/v1/messages") {
req.preset = req.pathname.replace("/v1/messages", "").replace("/", "");
}
})
serverInstance.addHook("preHandler", async (req: any, reply: any) => {
if (req.pathname.endsWith("/v1/messages")) {
const useAgents = [] const useAgents = []
for (const agent of agentsManager.getAllAgents()) { for (const agent of agentsManager.getAllAgents()) {
@@ -185,17 +191,13 @@ async function getServer(options: RunOptions = {}) {
if (useAgents.length) { if (useAgents.length) {
req.agents = useAgents; req.agents = useAgents;
} }
await router(req, reply, {
config,
event
});
} }
}); });
serverInstance.addHook("onError", async (request: any, reply: any, error: any) => { serverInstance.addHook("onError", async (request: any, reply: any, error: any) => {
event.emit('onError', request, reply, error); event.emit('onError', request, reply, error);
}) })
serverInstance.addHook("onSend", (req: any, reply: any, payload: any, done: any) => { serverInstance.addHook("onSend", (req: any, reply: any, payload: any, done: any) => {
if (req.sessionId && req.url.startsWith("/v1/messages") && !req.url.startsWith("/v1/messages/count_tokens")) { if (req.sessionId && req.pathname.endsWith("/v1/messages")) {
if (payload instanceof ReadableStream) { if (payload instanceof ReadableStream) {
if (req.agents) { if (req.agents) {
const abortController = new AbortController(); const abortController = new AbortController();

View File

@@ -1,10 +1,9 @@
import Server from "@musistudio/llms"; import Server, { calculateTokenCount } from "@musistudio/llms";
import { readConfigFile, writeConfigFile, backupConfigFile } from "./utils"; import { readConfigFile, writeConfigFile, backupConfigFile } from "./utils";
import { join } from "path"; import { join } from "path";
import fastifyStatic from "@fastify/static"; import fastifyStatic from "@fastify/static";
import { readdirSync, statSync, readFileSync, writeFileSync, existsSync, mkdirSync, unlinkSync, rmSync } from "fs"; import { readdirSync, statSync, readFileSync, writeFileSync, existsSync, mkdirSync, unlinkSync, rmSync } from "fs";
import { homedir } from "os"; import { homedir } from "os";
import { calculateTokenCount } from "./utils/router";
import { import {
getPresetDir, getPresetDir,
readManifestFromDir, readManifestFromDir,

View File

@@ -1,5 +1,6 @@
declare module "@musistudio/llms" { declare module "@musistudio/llms" {
import { FastifyInstance } from "fastify"; import { FastifyInstance } from "fastify";
import { FastifyBaseLogger } from "fastify";
export interface ServerConfig { export interface ServerConfig {
jsonPath?: string; jsonPath?: string;
@@ -9,7 +10,7 @@ declare module "@musistudio/llms" {
export interface Server { export interface Server {
app: FastifyInstance; app: FastifyInstance;
logger: any; logger: FastifyBaseLogger;
start(): Promise<void>; start(): Promise<void>;
} }
@@ -18,4 +19,44 @@ declare module "@musistudio/llms" {
}; };
export default Server; export default Server;
// Export cache
export interface Usage {
input_tokens: number;
output_tokens: number;
}
export const sessionUsageCache: any;
// Export router
export interface RouterContext {
configService: any;
event?: any;
}
export const router: (req: any, res: any, context: RouterContext) => Promise<void>;
// Export utilities
export const calculateTokenCount: (messages: any[], system: any, tools: any[]) => number;
export const searchProjectBySession: (sessionId: string) => Promise<string | null>;
// Export services
export class ConfigService {
constructor(options?: any);
get<T = any>(key: string): T | undefined;
get<T = any>(key: string, defaultValue: T): T;
getAll(): any;
has(key: string): boolean;
set(key: string, value: any): void;
reload(): void;
}
export class ProviderService {
constructor(configService: any, transformerService: any, logger: any);
}
export class TransformerService {
constructor(configService: any, logger: any);
initialize(): Promise<void>;
}
} }

9
pnpm-lock.yaml generated
View File

@@ -133,9 +133,15 @@ importers:
jsonrepair: jsonrepair:
specifier: ^3.13.0 specifier: ^3.13.0
version: 3.13.1 version: 3.13.1
lru-cache:
specifier: ^11.2.2
version: 11.2.2
openai: openai:
specifier: ^5.6.0 specifier: ^5.6.0
version: 5.23.2(ws@8.18.3) version: 5.23.2(ws@8.18.3)
tiktoken:
specifier: ^1.0.21
version: 1.0.22
undici: undici:
specifier: ^7.10.0 specifier: ^7.10.0
version: 7.16.0 version: 7.16.0
@@ -143,6 +149,9 @@ importers:
specifier: ^11.1.0 specifier: ^11.1.0
version: 11.1.0 version: 11.1.0
devDependencies: devDependencies:
'@CCR/shared':
specifier: workspace:*
version: link:../shared
'@types/node': '@types/node':
specifier: ^24.0.15 specifier: ^24.0.15
version: 24.7.0 version: 24.7.0