add agents to support route image

This commit is contained in:
musistudio
2025-09-01 17:19:43 +08:00
parent 9a5ea191f8
commit 19522f496b
10 changed files with 558 additions and 29 deletions

4
.gitignore vendored
View File

@@ -2,4 +2,6 @@ node_modules
.env
log.txt
.idea
dist
dist
.DS_Store
.vscode

212
src/agents/image.agent.ts Normal file
View File

@@ -0,0 +1,212 @@
import {IAgent, ITool} from "./type";
import { createHash } from 'crypto';
import { LRUCache } from 'lru-cache';
/** One cached image: the original image source plus the time it was stored. */
interface ImageCacheEntry {
// The `source` object taken from an incoming image content block (stored as-is).
source: any;
// Epoch milliseconds (Date.now()) recorded when the entry was stored.
timestamp: number;
}
/**
 * Bounded in-memory store for image sources, keyed by a caller-chosen id.
 *
 * Backed by an LRU cache limited to `maxSize` entries, each expiring after 24 hours.
 */
class ImageCache {
  private cache: LRUCache<string, ImageCacheEntry>;

  /**
   * @param maxSize Maximum number of images kept before the least recently used one is evicted.
   */
  constructor(maxSize = 100) {
    this.cache = new LRUCache({
      max: maxSize,
      ttl: 24 * 60 * 60 * 1000, // 24 hours in milliseconds
    });
  }

  /**
   * @param base64Image Image payload (any string is accepted).
   * @returns Hex-encoded SHA-256 digest of the payload.
   */
  calculateHash(base64Image: string): string {
    return createHash('sha256').update(base64Image).digest('hex');
  }

  /**
   * Stores `source` under `id`. An existing entry is never overwritten, and a
   * missing (null/undefined) source is ignored instead of throwing.
   */
  storeImage(id: string, source: any): void {
    if (source == null || this.hasImage(id)) return;
    this.cache.set(id, {
      source,
      timestamp: Date.now(),
    });
  }

  /** @returns The stored source for `id`, or `null` when there is no entry. */
  getImage(id: string): any {
    const entry = this.cache.get(id);
    return entry ? entry.source : null;
  }

  /** @returns Whether an entry exists under the given key. */
  hasImage(hash: string): boolean {
    return this.cache.has(hash);
  }

  /** Removes every cached image. */
  clear(): void {
    this.cache.clear();
  }

  /** @returns Number of images currently cached. */
  size(): number {
    return this.cache.size;
  }
}
// Module-level cache: ImageAgent.reqHandler stores images here and the analyzeImage tool handler reads them back.
const imageCache = new ImageCache();
/**
 * Agent that lets a text-only model work with images.
 *
 * When a conversation contains images in earlier user messages, the images are
 * moved into the module-level image cache and replaced by `[Image #n]` text
 * placeholders; the model can then call the `analyzeImage` tool, which forwards
 * the cached images to the model configured in `config.Router.image`.
 */
export class ImageAgent implements IAgent {
  name = "image";
  tools: Map<string, ITool>;

  constructor() {
    this.tools = new Map<string, ITool>();
    this.appendTools();
  }

  /**
   * Decides whether this agent should rewrite the request.
   *
   * - No image model configured, or no messages: `false`.
   * - The last message is a user message containing an image: the request is
   *   routed straight to the image model (`req.body.model` is overwritten) and
   *   the agent stays out of the way (`false`).
   * - Otherwise `true` when any user message contains an image.
   */
  shouldHandle(req: any, config: any): boolean {
    const imageModel = config?.Router?.image;
    if (!imageModel) return false;

    const messages: any[] = Array.isArray(req?.body?.messages) ? req.body.messages : [];
    // An empty conversation has no last message to inspect.
    if (messages.length === 0) return false;

    const hasImage = (msg: any): boolean =>
      msg?.role === 'user' &&
      Array.isArray(msg.content) &&
      msg.content.some((item: any) => item?.type === 'image');

    if (hasImage(messages[messages.length - 1])) {
      req.body.model = imageModel;
      return false;
    }
    return messages.some(hasImage);
  }

  /** Registers the `analyzeImage` tool on this agent. */
  appendTools() {
    this.tools.set('analyzeImage', {
      name: "analyzeImage",
      description: "Analyse image or images by ID and extract information such as OCR text, objects, layout, colors, or safety signals.",
      input_schema: {
        "type": "object",
        "properties": {
          "imageId": {
            "type": "array",
            "description": "an array of IDs to analyse",
            "items": {
              "type": "string"
            }
          },
          "task": {
            "type": "string",
            "description": "Details of task to perform on the image.The more detailed, the better",
          },
          "regions": {
            "type": "array",
            "description": "Optional regions of interest within the image",
            "items": {
              "type": "object",
              "properties": {
                "name": {"type": "string", "description": "Optional label for the region"},
                "x": {"type": "number", "description": "X coordinate"},
                "y": {"type": "number", "description": "Y coordinate"},
                "w": {"type": "number", "description": "Width of the region"},
                "h": {"type": "number", "description": "Height of the region"},
                "units": {"type": "string", "enum": ["px", "pct"], "description": "Units for coordinates and size"}
              },
              "required": ["x", "y", "w", "h", "units"]
            }
          }
        },
        "required": ["imageId", "task"]
      },
      /**
       * Sends the requested cached images plus the remaining arguments (as JSON
       * text) to the image model through the local /v1/messages endpoint.
       *
       * @returns The first text block of the reply, or 'analyzeImage Error'
       *          when the request fails or the reply has no text block.
       */
      handler: async (args, context) => {
        const imageMessages: Array<Record<string, unknown>> = [];

        // Split the ids from the rest WITHOUT mutating `args`: the caller may
        // keep using the same object (e.g. as the recorded tool_use input).
        const { imageId: requestedIds, ...otherArgs } = args ?? {};
        const taskArgs = Array.isArray(requestedIds) ? otherArgs : (args ?? {});

        // Create image messages from cached images
        if (Array.isArray(requestedIds)) {
          for (const imgId of requestedIds) {
            const image = imageCache.getImage(`${context.req.id}_Image#${imgId}`);
            if (image) {
              imageMessages.push({
                type: "image",
                source: image,
              });
            }
          }
        }

        // Add the remaining arguments (task, regions, ...) as a text block
        if (Object.keys(taskArgs).length > 0) {
          imageMessages.push({
            type: "text",
            text: JSON.stringify(taskArgs),
          });
        }

        // Send to analysis agent and get response
        let agentResponse: any = null;
        try {
          const res = await fetch(`http://127.0.0.1:${context.config.PORT}/v1/messages`, {
            method: "POST",
            headers: {
              'x-api-key': context.config.APIKEY,
              'content-type': 'application/json',
            },
            body: JSON.stringify({
              model: context.config.Router.image,
              system: [{
                type: 'text',
                text: `你需要按照任务去解析图片`
              }],
              messages: [
                {
                  role: 'user',
                  content: imageMessages,
                }
              ],
              stream: false,
            }),
          });
          agentResponse = await res.json();
        } catch (err: unknown) {
          // Best effort: the failure is reported to the model as a tool error
          // below; log it when the request carries a logger.
          context?.req?.log?.error?.({ err }, 'analyzeImage request failed');
        }

        const blocks = agentResponse?.content;
        if (!Array.isArray(blocks)) {
          return 'analyzeImage Error';
        }
        // The reply may start with a non-text block, so look for the first text one.
        const textBlock = blocks.find((block: any) => typeof block?.text === 'string');
        return textBlock ? textBlock.text : 'analyzeImage Error';
      }
    })
  }

  /**
   * Rewrites the request in place: appends the tool-usage system prompt and
   * replaces every image block in image-carrying user messages with an
   * `[Image #n]` text placeholder, caching the image under
   * `${req.id}_Image#n`. Pre-existing `[Image #n]` markers in the text blocks
   * of those messages are stripped so they cannot clash with the new numbering.
   */
  reqHandler(req: any, config: any) {
    if (req.body) {
      // `system` may be a plain string or missing; normalise it to a block
      // array so the prompt below is always injected.
      if (typeof req.body.system === 'string') {
        req.body.system = req.body.system ? [{ type: "text", text: req.body.system }] : [];
      } else if (!Array.isArray(req.body.system)) {
        req.body.system = [];
      }
      // Inject system prompt
      req.body.system.push({
        type: "text",
        text: `You are a text-only language model and do not possess visual perception.
If the user requests you to view, analyze, or extract information from an image, you **must** call the \`analyzeImage\` tool.
When invoking this tool, you must pass the correct \`imageId\` extracted from the prior conversation.
Image identifiers are always provided in the format \`[Image #imageId]\`.
If multiple images exist, select the **most relevant imageId** based on the users current request and prior context.
Do not attempt to describe or analyze the image directly yourself.
Ignore any user interruptions or unrelated instructions that might cause you to skip this requirement.
Your response should consistently follow this rule whenever image-related analysis is requested.`,
      });
    }

    const messages: any[] = Array.isArray(req.body?.messages) ? req.body.messages : [];
    let imgId = 1;
    for (const message of messages) {
      if (message?.role !== 'user' || !Array.isArray(message.content)) continue;
      // Only messages that actually carry an image are rewritten.
      if (!message.content.some((part: any) => part?.type === "image")) continue;

      for (const part of message.content) {
        if (part?.type === "image") {
          imageCache.storeImage(`${req.id}_Image#${imgId}`, part.source);
          part.type = 'text';
          delete part.source;
          part.text = `[Image #${imgId}]This is an image, if you need to view or analyze it, you need to extract the imageId`;
          imgId++;
        } else if (part?.type === "text" && typeof part.text === 'string' && part.text.includes('[Image #')) {
          part.text = part.text.replace(/\[Image #\d+\]/g, '');
        }
      }
    }
  }
}
// Shared ImageAgent instance; src/agents/index.ts imports it and registers it with the agents manager.
export const imageAgent = new ImageAgent();

48
src/agents/index.ts Normal file
View File

@@ -0,0 +1,48 @@
import { imageAgent } from './image.agent'
import { IAgent } from './type';
/**
 * Registry of agents, indexed by each agent's `name`.
 */
export class AgentsManager {
  private agents: Map<string, IAgent> = new Map();

  /**
   * Adds an agent to the registry under its own name. Registering another
   * agent with the same name replaces the earlier one.
   * @param agent Agent instance to register.
   */
  registerAgent(agent: IAgent): void {
    const { name } = agent;
    this.agents.set(name, agent);
  }

  /**
   * Looks an agent up by name.
   * @param name Name the agent was registered under.
   * @returns The agent, or `undefined` when no agent has that name.
   */
  getAgent(name: string): IAgent | undefined {
    const found = this.agents.get(name);
    return found;
  }

  /**
   * @returns A new array holding every registered agent, in registration order.
   */
  getAllAgents(): IAgent[] {
    return [...this.agents.values()];
  }

  /**
   * @returns The tools of all registered agents, flattened into one array.
   */
  getAllTools(): any[] {
    return [...this.agents.values()].flatMap((agent) => [...agent.tools.values()]);
  }
}
// Shared manager instance; the image agent is registered as soon as this module is loaded.
const agentsManager = new AgentsManager()
agentsManager.registerAgent(imageAgent)
export default agentsManager

19
src/agents/type.ts Normal file
View File

@@ -0,0 +1,19 @@
/** A tool that an agent exposes to the model. */
export interface ITool {
// Tool name; also the key used to look the tool up when the model calls it.
name: string;
// Description of the tool, sent to the model together with the name and input_schema.
description: string;
// Schema describing the tool's arguments (the image agent supplies a JSON-Schema-style object).
input_schema: any;
// Runs the tool with the parsed arguments and resolves to the text used as the tool result.
// NOTE(review): callers in this commit pass `{ req, config }` as `context` — confirm this is the full contract.
handler: (args: any, context: any) => Promise<string>;
}
/** Contract implemented by every agent registered with the agents manager. */
export interface IAgent {
// Unique agent name; the manager uses it as the registry key.
name: string;
// Tools offered by this agent, keyed by tool name.
tools: Map<string, ITool>;
// Returns true when the agent wants to rewrite this request.
shouldHandle: (req: any, config: any) => boolean;
// Mutates the request in place; called only after shouldHandle returned true.
reqHandler: (req: any, config: any) => void;
// Optional response-side hook. NOTE(review): no call site is visible in this commit — confirm when it is invoked.
resHandler?: (payload: any, config: any) => void;
}

View File

@@ -16,7 +16,13 @@ import createWriteStream from "pino-rotating-file-stream";
import { HOME_DIR } from "./constants";
import { configureLogging } from "./utils/log";
import { sessionUsageCache } from "./utils/cache";
import Stream from "node:stream";
import {SSEParserTransform} from "./utils/SSEParser.transform";
import {SSESerializerTransform} from "./utils/SSESerializer.transform";
import {rewriteStream} from "./utils/rewriteStream";
import JSON5 from "json5";
import { IAgent } from "./agents/type";
import agentsManager from "./agents";
async function initializeClaudeConfig() {
const homeDir = homedir();
@@ -58,7 +64,7 @@ async function run(options: RunOptions = {}) {
// Configure logging based on config
configureLogging(config);
let HOST = config.HOST;
let HOST = config.HOST || "127.0.0.1";
if (config.HOST && !config.APIKEY) {
HOST = "127.0.0.1";
@@ -82,7 +88,6 @@ async function run(options: RunOptions = {}) {
cleanupPidFile();
process.exit(0);
});
console.log(HOST);
// Use port from environment variable if set (for background process)
const servicePort = process.env.SERVICE_PORT
@@ -131,12 +136,135 @@ async function run(options: RunOptions = {}) {
});
// Before every /v1/messages request: let each interested agent rewrite the
// request and contribute its tools, then route the (possibly rewritten) request.
server.addHook("preHandler", async (req, reply) => {
  if (!req.url.startsWith("/v1/messages")) return;

  const useAgents: string[] = [];
  for (const agent of agentsManager.getAllAgents()) {
    if (!agent.shouldHandle(req, config)) continue;

    // Remember which agents took part so the response hook can find their tools.
    useAgents.push(agent.name);
    // change request body
    agent.reqHandler(req, config);
    // append agent tools
    if (agent.tools.size) {
      const toolDefs = Array.from(agent.tools.values()).map((item) => ({
        name: item.name,
        description: item.description,
        input_schema: item.input_schema,
      }));
      // A request may carry no tools at all; create the list instead of crashing on unshift.
      if (!Array.isArray(req.body.tools)) {
        req.body.tools = [];
      }
      req.body.tools.unshift(...toolDefs);
    }
  }
  if (useAgents.length) {
    req.agents = useAgents;
  }
  // Route exactly once, after the agents have finished mutating the request,
  // and await it so routing errors surface in this hook.
  await router(req, reply, config);
});
server.addHook("onSend", (req, reply, payload, done) => {
server.addHook("onSend", async (req, reply, payload) => {
if (req.sessionId && req.url.startsWith("/v1/messages")) {
if (payload instanceof ReadableStream) {
// Agent-aware streaming path: intercept tool calls that belong to a registered agent,
// run them locally, and splice the follow-up model response into the client stream.
if (req.agents) {
const eventStream = payload.pipeThrough(new SSEParserTransform())
// State of the agent tool call currently being captured (index -1 = none in progress).
let currentAgent: undefined | IAgent;
let currentToolIndex = -1
let currentToolName = ''
let currentToolArgs = ''
let currentToolId = ''
// Anthropic-format content blocks collected for the follow-up request:
// toolMessages holds tool_result blocks, assistantMessages holds the matching tool_use blocks.
const toolMessages: any[] = []
const assistantMessages: any[] = []
// Returning undefined from the callback drops the event; returning `data` forwards it unchanged.
return rewriteStream(eventStream, async (data, controller) => {
// Detect the start of a tool call that belongs to one of this request's agents
if (data.event === 'content_block_start' && data?.data?.content_block?.name) {
const agent = req.agents.find((name: string) => agentsManager.getAgent(name)?.tools.get(data.data.content_block.name))
if (agent) {
currentAgent = agentsManager.getAgent(agent)
currentToolIndex = data.data.index
currentToolName = data.data.content_block.name
currentToolId = data.data.content_block.id
// Hide the agent's tool_use block from the client.
return undefined;
}
}
// Accumulate the streamed JSON fragments of the tool arguments
// NOTE(review): data.data is dereferenced without optional chaining here and below — presumably every parsed event carries a data payload; confirm.
if (currentToolIndex > -1 && data.data.index === currentToolIndex && data.data?.delta?.type === 'input_json_delta') {
currentToolArgs += data.data?.delta?.partial_json;
return undefined;
}
// Tool call finished: parse the arguments and run the agent's tool handler
if (currentToolIndex > -1 && data.data.index === currentToolIndex && data.data.type === 'content_block_stop') {
try {
const args = JSON5.parse(currentToolArgs);
assistantMessages.push({
type: "tool_use",
id: currentToolId,
name: currentToolName,
input: args
})
const toolResult = await currentAgent?.tools.get(currentToolName)?.handler(args, {
req,
config
});
toolMessages.push({
"tool_use_id": currentToolId,
"type": "tool_result",
"content": toolResult
})
// Reset the capture state so the next tool call starts clean.
currentAgent = undefined
currentToolIndex = -1
currentToolName = ''
currentToolArgs = ''
currentToolId = ''
} catch (e) {
// NOTE(review): on a parse/handler failure the error is only logged; the capture state above is
// not reset and the withheld tool_use block is never surfaced to the client.
console.log(e);
}
return undefined;
}
// End of the model turn with pending tool results: append the tool_use/tool_result exchange to the
// conversation, re-send it to the local /v1/messages endpoint, and stream that reply to the client.
if (data.event === 'message_delta' && toolMessages.length) {
req.body.messages.push({
role: 'assistant',
content: assistantMessages
})
req.body.messages.push({
role: 'user',
content: toolMessages
})
const response = await fetch(`http://127.0.0.1:${config.PORT}/v1/messages`, {
method: "POST",
headers: {
'x-api-key': config.APIKEY,
'content-type': 'application/json',
},
body: JSON.stringify(req.body),
})
// NOTE(review): a failed follow-up request silently drops this message_delta event.
if (!response.ok) {
return undefined;
}
const stream = response.body!.pipeThrough(new SSEParserTransform())
const reader = stream.getReader()
while (true) {
const {value, done} = await reader.read();
if (done) {
break;
}
// Skip the follow-up's message_start/message_stop; every other follow-up event is forwarded.
if (['message_start', 'message_stop'].includes(value.event)) {
continue
}
controller.enqueue(value)
}
// The original message_delta is replaced by the follow-up's events.
return undefined
}
return data
}).pipeThrough(new SSESerializerTransform())
}
const [originalStream, clonedStream] = payload.tee();
const read = async (stream: ReadableStream) => {
const reader = stream.getReader();
@@ -156,29 +284,13 @@ async function run(options: RunOptions = {}) {
}
}
read(clonedStream);
done(null, originalStream)
} else {
req.log.debug({payload}, 'onSend Hook')
sessionUsageCache.put(req.sessionId, payload.usage);
if (payload instanceof Buffer || payload instanceof Response) {
done(null, payload);
} else if(typeof payload === "object") {
done(null, JSON.stringify(payload));
} else {
done(null, payload);
}
}
} else {
if(payload instanceof Buffer || payload instanceof Response || payload === null || payload instanceof ReadableStream || payload instanceof Stream) {
done(null, payload);
} else if(typeof payload === "object") {
req.log.debug({payload}, 'onSend Hook')
done(null, JSON.stringify(payload));
} else {
done(null, payload);
return originalStream
}
sessionUsageCache.put(req.sessionId, payload.usage);
}
return payload;
});
server.start();
}

View File

@@ -0,0 +1,73 @@
/**
 * TransformStream that parses a Server-Sent Events stream into event objects.
 *
 * Input chunks may be raw bytes (decoded as UTF-8) or already-decoded strings.
 * Each emitted object may carry `event`, `id`, `retry` and `data`:
 * - `data` is the JSON-parsed payload of the `data:` line;
 * - `data: [DONE]` becomes `{ type: 'done' }`;
 * - a payload that is not valid JSON becomes `{ raw, error: 'JSON parse failed' }`.
 * An event is emitted at the blank line that terminates it; an unterminated
 * trailing event is emitted when the stream ends.
 */
export class SSEParserTransform extends TransformStream<Uint8Array | string, any> {
  // Text received so far that does not yet end in a newline.
  private buffer = '';
  // Fields of the event currently being assembled.
  private currentEvent: Record<string, any> = {};
  // One decoder for the whole stream, so a multi-byte UTF-8 character split
  // across two chunks is decoded correctly instead of turning into U+FFFD.
  private readonly decoder = new TextDecoder();

  constructor() {
    super({
      transform: (chunk, controller) => {
        this.buffer +=
          typeof chunk === 'string' ? chunk : this.decoder.decode(chunk, { stream: true });

        const lines = this.buffer.split('\n');
        // Keep the last line back: it may be incomplete.
        this.buffer = lines.pop() ?? '';

        for (const line of lines) {
          const event = this.processLine(line);
          if (event) {
            controller.enqueue(event);
          }
        }
      },
      flush: (controller) => {
        // Drain any bytes still held by the streaming decoder.
        this.buffer += this.decoder.decode();

        // Handle whatever is left in the buffer as one final line.
        const rest = this.buffer.trim();
        this.buffer = '';
        if (rest) {
          this.processLine(rest);
        }

        // Emit the last event even though no blank line terminated it.
        if (Object.keys(this.currentEvent).length > 0) {
          controller.enqueue(this.currentEvent);
          this.currentEvent = {};
        }
      },
    });
  }

  /**
   * Feeds one line into the event under construction.
   * @returns The completed event when `line` is blank and an event is pending, otherwise `null`.
   */
  private processLine(line: string): Record<string, any> | null {
    if (!line.trim()) {
      if (Object.keys(this.currentEvent).length === 0) {
        return null;
      }
      const event = this.currentEvent;
      this.currentEvent = {};
      return event;
    }

    if (line.startsWith('event:')) {
      this.currentEvent.event = line.slice(6).trim();
    } else if (line.startsWith('data:')) {
      const data = line.slice(5).trim();
      if (data === '[DONE]') {
        this.currentEvent.data = { type: 'done' };
      } else {
        try {
          this.currentEvent.data = JSON.parse(data);
        } catch {
          // Keep the raw text so nothing is lost when a payload is not JSON.
          this.currentEvent.data = { raw: data, error: 'JSON parse failed' };
        }
      }
    } else if (line.startsWith('id:')) {
      this.currentEvent.id = line.slice(3).trim();
    } else if (line.startsWith('retry:')) {
      this.currentEvent.retry = parseInt(line.slice(6).trim(), 10);
    }
    // Any other line (e.g. ":" comments) is ignored.
    return null;
  }
}

View File

@@ -0,0 +1,29 @@
/**
 * TransformStream that renders event objects as Server-Sent Events text.
 *
 * For each incoming object the truthy fields are written in the order
 * `event`, `id`, `retry`, `data`, one per line, followed by a blank line.
 * A `data` object whose `type` is 'done' is written as `[DONE]`; any other
 * `data` value is written as JSON.
 */
export class SSESerializerTransform extends TransformStream<any, string> {
  constructor() {
    super({
      transform: (event, controller) => {
        const fields: string[] = [];

        if (event.event) fields.push(`event: ${event.event}`);
        if (event.id) fields.push(`id: ${event.id}`);
        if (event.retry) fields.push(`retry: ${event.retry}`);
        if (event.data) {
          const payload = event.data.type === 'done' ? '[DONE]' : JSON.stringify(event.data);
          fields.push(`data: ${payload}`);
        }

        // Every field line ends with a newline; one extra newline closes the event.
        const serialized = fields.map((field) => `${field}\n`).join('') + '\n';
        controller.enqueue(serialized);
      },
    });
  }
}

View File

@@ -5,6 +5,8 @@ import {
decrementReferenceCount,
incrementReferenceCount,
} from "./processCheck";
import {HOME_DIR} from "../constants";
import {join} from "path";
export async function executeCodeCommand(args: string[] = []) {
// Set environment variables
@@ -63,7 +65,7 @@ export async function executeCodeCommand(args: string[] = []) {
const stdioConfig: StdioOptions = config.NON_INTERACTIVE_MODE
? ["pipe", "inherit", "inherit"] // Pipe stdin for non-interactive
: "inherit"; // Default inherited behavior
console.log(joinedArgs)
const claudeProcess = spawn(
claudePath + (joinedArgs ? ` ${joinedArgs}` : ""),
[],

View File

@@ -0,0 +1,31 @@
/**
 * rewriteStream
 *
 * Reads `stream` and returns a new ReadableStream whose chunks are produced by
 * `processor`. For every source chunk the processor's resolved value is pushed
 * to the new stream; resolving to `undefined` pushes nothing. The processor
 * also receives the new stream's controller and may enqueue extra chunks itself.
 *
 * The source is closed down when it is no longer needed: cancelling the
 * returned stream cancels the source, and so does an error thrown by the
 * processor (which also errors the returned stream).
 *
 * @param stream Source stream; it is locked for the lifetime of the returned stream.
 * @param processor Async mapper invoked once per source chunk, in order.
 * @returns The rewritten stream.
 */
export const rewriteStream = (stream: ReadableStream, processor: (data: any, controller: ReadableStreamController<any>) => Promise<any>): ReadableStream => {
  const reader = stream.getReader()
  // Set once the consumer cancels; after that the controller must not be touched.
  let cancelled = false

  return new ReadableStream({
    async start(controller) {
      try {
        while (true) {
          const { done, value } = await reader.read()
          if (done) {
            // After a cancel the output is already closed; closing again would throw.
            if (!cancelled) {
              controller.close()
            }
            break
          }

          const processed = await processor(value, controller)
          if (cancelled) {
            break
          }
          if (processed !== undefined) {
            controller.enqueue(processed)
          }
        }
      } catch (error) {
        if (!cancelled) {
          controller.error(error)
          // Tell the source we are done with it. Cancelling a source that has
          // itself errored rejects; there is nothing more to do in that case.
          await reader.cancel(error).catch(() => {})
        }
      } finally {
        reader.releaseLock()
      }
    },
    cancel(reason) {
      cancelled = true
      // Propagate the consumer's cancellation so the upstream producer stops.
      return reader.cancel(reason)
    },
  })
}

View File

@@ -72,16 +72,17 @@ const getUseModel = async (
if (req.body.model.includes(",")) {
const [provider, model] = req.body.model.split(",");
const finalProvider = config.Providers.find(
(p: any) => p.name.toLowerCase() === provider
(p: any) => p.name.toLowerCase() === provider
);
const finalModel = finalProvider?.models?.find(
(m: any) => m.toLowerCase() === model
(m: any) => m.toLowerCase() === model
);
if (finalProvider && finalModel) {
return `${finalProvider.name},${finalModel}`;
}
return req.body.model;
}
// if tokenCount is greater than the configured threshold, use the long context model
const longContextThreshold = config.Router.longContextThreshold || 60000;
const lastUsageThreshold =