release v1.0.24 to support custom router
This commit is contained in:
@@ -1,9 +1,69 @@
|
||||
import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages";
|
||||
import {
|
||||
MessageCreateParamsBase,
|
||||
MessageParam,
|
||||
Tool,
|
||||
} from "@anthropic-ai/sdk/resources/messages";
|
||||
import { get_encoding } from "tiktoken";
|
||||
import { log } from "./log";
|
||||
|
||||
const enc = get_encoding("cl100k_base");
|
||||
|
||||
const calculateTokenCount = (
|
||||
messages: MessageParam[],
|
||||
system: any,
|
||||
tools: Tool[]
|
||||
) => {
|
||||
let tokenCount = 0;
|
||||
if (Array.isArray(messages)) {
|
||||
messages.forEach((message) => {
|
||||
if (typeof message.content === "string") {
|
||||
tokenCount += enc.encode(message.content).length;
|
||||
} else if (Array.isArray(message.content)) {
|
||||
message.content.forEach((contentPart: any) => {
|
||||
if (contentPart.type === "text") {
|
||||
tokenCount += enc.encode(contentPart.text).length;
|
||||
} else if (contentPart.type === "tool_use") {
|
||||
tokenCount += enc.encode(
|
||||
JSON.stringify(contentPart.input)
|
||||
).length;
|
||||
} else if (contentPart.type === "tool_result") {
|
||||
tokenCount += enc.encode(
|
||||
typeof contentPart.content === "string"
|
||||
? contentPart.content
|
||||
: JSON.stringify(contentPart.content)
|
||||
).length;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
if (typeof system === "string") {
|
||||
tokenCount += enc.encode(system).length;
|
||||
} else if (Array.isArray(system)) {
|
||||
system.forEach((item: any) => {
|
||||
if (item.type !== "text") return;
|
||||
if (typeof item.text === "string") {
|
||||
tokenCount += enc.encode(item.text).length;
|
||||
} else if (Array.isArray(item.text)) {
|
||||
item.text.forEach((textPart: any) => {
|
||||
tokenCount += enc.encode(textPart || "").length;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
if (tools) {
|
||||
tools.forEach((tool: Tool) => {
|
||||
if (tool.description) {
|
||||
tokenCount += enc.encode(tool.name + tool.description).length;
|
||||
}
|
||||
if (tool.input_schema) {
|
||||
tokenCount += enc.encode(JSON.stringify(tool.input_schema)).length;
|
||||
}
|
||||
});
|
||||
}
|
||||
return tokenCount;
|
||||
};
|
||||
|
||||
const getUseModel = (req: any, tokenCount: number, config: any) => {
|
||||
if (req.body.model.includes(",")) {
|
||||
return req.body.model;
|
||||
@@ -26,64 +86,37 @@ const getUseModel = (req: any, tokenCount: number, config: any) => {
|
||||
log("Using think model for ", req.body.thinking);
|
||||
return config.Router.think;
|
||||
}
|
||||
if (Array.isArray(req.body.tools) && req.body.tools.some(tool => tool.type?.startsWith('web_search')) && config.Router.webSearch) {
|
||||
if (
|
||||
Array.isArray(req.body.tools) &&
|
||||
req.body.tools.some((tool: any) => tool.type?.startsWith("web_search")) &&
|
||||
config.Router.webSearch
|
||||
) {
|
||||
return config.Router.webSearch;
|
||||
}
|
||||
return config.Router!.default;
|
||||
};
|
||||
|
||||
export const router = async (req: any, res: any, config: any) => {
|
||||
export const router = async (req: any, _res: any, config: any) => {
|
||||
const { messages, system = [], tools }: MessageCreateParamsBase = req.body;
|
||||
try {
|
||||
let tokenCount = 0;
|
||||
if (Array.isArray(messages)) {
|
||||
messages.forEach((message) => {
|
||||
if (typeof message.content === "string") {
|
||||
tokenCount += enc.encode(message.content).length;
|
||||
} else if (Array.isArray(message.content)) {
|
||||
message.content.forEach((contentPart) => {
|
||||
if (contentPart.type === "text") {
|
||||
tokenCount += enc.encode(contentPart.text).length;
|
||||
} else if (contentPart.type === "tool_use") {
|
||||
tokenCount += enc.encode(
|
||||
JSON.stringify(contentPart.input)
|
||||
).length;
|
||||
} else if (contentPart.type === "tool_result") {
|
||||
tokenCount += enc.encode(
|
||||
typeof contentPart.content === "string"
|
||||
? contentPart.content
|
||||
: JSON.stringify(contentPart.content)
|
||||
).length;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
const tokenCount = calculateTokenCount(
|
||||
messages as MessageParam[],
|
||||
system,
|
||||
tools as Tool[]
|
||||
);
|
||||
|
||||
let model;
|
||||
if (config.CUSTOM_ROUTER_PATH) {
|
||||
try {
|
||||
const customRouter = require(config.CUSTOM_ROUTER_PATH);
|
||||
model = await customRouter(req, config);
|
||||
} catch (e: any) {
|
||||
log("failed to load custom router", e.message);
|
||||
}
|
||||
}
|
||||
if (typeof system === "string") {
|
||||
tokenCount += enc.encode(system).length;
|
||||
} else if (Array.isArray(system)) {
|
||||
system.forEach((item) => {
|
||||
if (item.type !== "text") return;
|
||||
if (typeof item.text === "string") {
|
||||
tokenCount += enc.encode(item.text).length;
|
||||
} else if (Array.isArray(item.text)) {
|
||||
item.text.forEach((textPart) => {
|
||||
tokenCount += enc.encode(textPart || "").length;
|
||||
});
|
||||
}
|
||||
});
|
||||
if (!model) {
|
||||
model = getUseModel(req, tokenCount, config);
|
||||
}
|
||||
if (tools) {
|
||||
tools.forEach((tool) => {
|
||||
if (tool.description) {
|
||||
tokenCount += enc.encode(tool.name + tool.description).length;
|
||||
}
|
||||
if (tool.input_schema) {
|
||||
tokenCount += enc.encode(JSON.stringify(tool.input_schema)).length;
|
||||
}
|
||||
});
|
||||
}
|
||||
const model = getUseModel(req, tokenCount, config);
|
||||
req.body.model = model;
|
||||
} catch (error: any) {
|
||||
log("Error in router middleware:", error.message);
|
||||
|
||||
Reference in New Issue
Block a user