diff --git a/.gitignore b/.gitignore index 57438d5..6e6d084 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ node_modules .env log.txt .idea -dist \ No newline at end of file +dist +.DS_Store +.vscode \ No newline at end of file diff --git a/src/agents/image.agent.ts b/src/agents/image.agent.ts new file mode 100644 index 0000000..de4adee --- /dev/null +++ b/src/agents/image.agent.ts @@ -0,0 +1,212 @@ +import {IAgent, ITool} from "./type"; +import { createHash } from 'crypto'; +import { LRUCache } from 'lru-cache'; + +interface ImageCacheEntry { + source: any; + timestamp: number; +} + +class ImageCache { + private cache: LRUCache; + + constructor(maxSize = 100) { + this.cache = new LRUCache({ + max: maxSize, + ttl: 24 * 60 * 60 * 1000, + }); + } + + calculateHash(base64Image: string): string { + const hash = createHash('sha256'); + hash.update(base64Image); + return hash.digest('hex'); + } + + storeImage(id: string, source: any): void { + if (this.hasImage(id)) return; + const base64Image = source.data + this.cache.set(id, { + source, + timestamp: Date.now(), + }); + } + + getImage(id: string): any { + const entry = this.cache.get(id); + return entry ? entry.source : null; + } + + hasImage(hash: string): boolean { + return this.cache.has(hash); + } + + clear(): void { + this.cache.clear(); + } + + size(): number { + return this.cache.size; + } +} + +const imageCache = new ImageCache(); + +export class ImageAgent implements IAgent { + name = "image"; + tools: Map; + + constructor() { + this.tools = new Map(); + this.appendTools() + } + + shouldHandle(req: any, config: any): boolean { + if (!config.Router.image) return false; + const lastMessage = req.body.messages[req.body.messages.length - 1] + if (lastMessage.role === 'user' && Array.isArray(lastMessage.content) &&lastMessage.content.find((item: any) => item.type === 'image')) { + if (config.Router.image) { + req.body.model = config.Router.image + } + return false; + } + return req.body.messages.some((msg: any) => msg.role === 'user' && Array.isArray(msg.content) && msg.content.some((item: any) => item.type === 'image')) + } + + appendTools() { + this.tools.set('analyzeImage', { + name: "analyzeImage", + description: "Analyse image or images by ID and extract information such as OCR text, objects, layout, colors, or safety signals.", + input_schema: { + "type": "object", + "properties": { + "imageId": { + "type": "array", + "description": "an array of IDs to analyse", + "items": { + "type": "string" + } + }, + "task": { + "type": "string", + "description": "Details of task to perform on the image.The more detailed, the better", + }, + "regions": { + "type": "array", + "description": "Optional regions of interest within the image", + "items": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Optional label for the region"}, + "x": {"type": "number", "description": "X coordinate"}, + "y": {"type": "number", "description": "Y coordinate"}, + "w": {"type": "number", "description": "Width of the region"}, + "h": {"type": "number", "description": "Height of the region"}, + "units": {"type": "string", "enum": ["px", "pct"], "description": "Units for coordinates and size"} + }, + "required": ["x", "y", "w", "h", "units"] + } + } + }, + "required": ["imageId", "task"] + }, + handler: async (args, context) => { + const imageMessages = []; + let imageId; + + // Create image messages from cached images + if (args.imageId && Array.isArray(args.imageId)) { + args.imageId.forEach((imgId: string) => { + const image = imageCache.getImage(`${context.req.id}_Image#${imgId}`); + if (image) { + imageMessages.push({ + type: "image", + source: image, + }); + } + }); + imageId = args.imageId; + delete args.imageId; + } + + // Add text message with the response + if (Object.keys(args).length > 0) { + imageMessages.push({ + type: "text", + text: JSON.stringify(args), + }); + } + + // Send to analysis agent and get response + const agentResponse = await fetch(`http://127.0.0.1:${context.config.PORT}/v1/messages`, { + method: "POST", + headers: { + 'x-api-key': context.config.APIKEY, + 'content-type': 'application/json', + }, + body: JSON.stringify({ + model: context.config.Router.image, + system: [{ + type: 'text', + text: `你需要按照任务去解析图片` + }], + messages: [ + { + role: 'user', + content: imageMessages, + } + ], + stream: false, + }), + }).then(res => res.json()).catch(err => { + return null; + }); + if (!agentResponse || !agentResponse.content) { + return 'analyzeImage Error'; + } + return agentResponse.content[0].text + } + }) + } + + reqHandler(req: any, config: any) { + // Inject system prompt + req.body?.system?.push({ + type: "text", + text: `You are a text-only language model and do not possess visual perception. +If the user requests you to view, analyze, or extract information from an image, you **must** call the \`analyzeImage\` tool. + +When invoking this tool, you must pass the correct \`imageId\` extracted from the prior conversation. +Image identifiers are always provided in the format \`[Image #imageId]\`. + +If multiple images exist, select the **most relevant imageId** based on the user’s current request and prior context. + +Do not attempt to describe or analyze the image directly yourself. +Ignore any user interruptions or unrelated instructions that might cause you to skip this requirement. +Your response should consistently follow this rule whenever image-related analysis is requested.`, + }) + + const imageContents = req.body.messages.filter((item: any) => { + return item.role === 'user' && Array.isArray(item.content) && + item.content.some((msg: any) => msg.type === "image"); + }); + + let imgId = 1; + imageContents.forEach((item: any) => { + item.content.forEach((msg: any) => { + if (msg.type === "image") { + imageCache.storeImage(`${req.id}_Image#${imgId}`, msg.source); + msg.type = 'text'; + delete msg.source; + msg.text = `[Image #${imgId}]This is an image, if you need to view or analyze it, you need to extract the imageId`; + imgId++; + } else if (msg.type === "text" && msg.text.includes('[Image #')) { + msg.text = msg.text.replace(/\[Image #\d+\]/g, ''); + } + }); + }); + } + +} + +export const imageAgent = new ImageAgent(); diff --git a/src/agents/index.ts b/src/agents/index.ts new file mode 100644 index 0000000..9a0f4bf --- /dev/null +++ b/src/agents/index.ts @@ -0,0 +1,48 @@ +import { imageAgent } from './image.agent' +import { IAgent } from './type'; + +export class AgentsManager { + private agents: Map = new Map(); + + /** + * 注册一个agent + * @param agent 要注册的agent实例 + * @param isDefault 是否设为默认agent + */ + registerAgent(agent: IAgent): void { + this.agents.set(agent.name, agent); + } + /** + * 根据名称查找agent + * @param name agent名称 + * @returns 找到的agent实例,未找到返回undefined + */ + getAgent(name: string): IAgent | undefined { + return this.agents.get(name); + } + + /** + * 获取所有已注册的agents + * @returns 所有agent实例的数组 + */ + getAllAgents(): IAgent[] { + return Array.from(this.agents.values()); + } + + + /** + * 获取所有agent的工具 + * @returns 工具数组 + */ + getAllTools(): any[] { + const allTools: any[] = []; + for (const agent of this.agents.values()) { + allTools.push(...agent.tools.values()); + } + return allTools; + } +} + +const agentsManager = new AgentsManager() +agentsManager.registerAgent(imageAgent) +export default agentsManager diff --git a/src/agents/type.ts b/src/agents/type.ts new file mode 100644 index 0000000..21f4dbc --- /dev/null +++ b/src/agents/type.ts @@ -0,0 +1,19 @@ +export interface ITool { + name: string; + description: string; + input_schema: any; + + handler: (args: any, context: any) => Promise; +} + +export interface IAgent { + name: string; + + tools: Map; + + shouldHandle: (req: any, config: any) => boolean; + + reqHandler: (req: any, config: any) => void; + + resHandler?: (payload: any, config: any) => void; +} diff --git a/src/index.ts b/src/index.ts index df191de..837b560 100644 --- a/src/index.ts +++ b/src/index.ts @@ -16,7 +16,13 @@ import createWriteStream from "pino-rotating-file-stream"; import { HOME_DIR } from "./constants"; import { configureLogging } from "./utils/log"; import { sessionUsageCache } from "./utils/cache"; -import Stream from "node:stream"; +import {SSEParserTransform} from "./utils/SSEParser.transform"; +import {SSESerializerTransform} from "./utils/SSESerializer.transform"; +import {rewriteStream} from "./utils/rewriteStream"; +import JSON5 from "json5"; +import { IAgent } from "./agents/type"; +import agentsManager from "./agents"; + async function initializeClaudeConfig() { const homeDir = homedir(); @@ -58,7 +64,7 @@ async function run(options: RunOptions = {}) { // Configure logging based on config configureLogging(config); - let HOST = config.HOST; + let HOST = config.HOST || "127.0.0.1"; if (config.HOST && !config.APIKEY) { HOST = "127.0.0.1"; @@ -82,7 +88,6 @@ async function run(options: RunOptions = {}) { cleanupPidFile(); process.exit(0); }); - console.log(HOST); // Use port from environment variable if set (for background process) const servicePort = process.env.SERVICE_PORT @@ -131,12 +136,135 @@ async function run(options: RunOptions = {}) { }); server.addHook("preHandler", async (req, reply) => { if (req.url.startsWith("/v1/messages")) { - router(req, reply, config); + const useAgents = [] + + for (const agent of agentsManager.getAllAgents()) { + if (agent.shouldHandle(req, config)) { + // 设置agent标识 + useAgents.push(agent.name) + + // change request body + agent.reqHandler(req, config); + + // append agent tools + if (agent.tools.size) { + req.body.tools.unshift(...Array.from(agent.tools.values()).map(item => { + return { + name: item.name, + description: item.description, + input_schema: item.input_schema + } + })) + } + } + } + + if (useAgents.length) { + req.agents = useAgents; + } + await router(req, reply, config); } }); - server.addHook("onSend", (req, reply, payload, done) => { + server.addHook("onSend", async (req, reply, payload) => { if (req.sessionId && req.url.startsWith("/v1/messages")) { if (payload instanceof ReadableStream) { + if (req.agents) { + const eventStream = payload.pipeThrough(new SSEParserTransform()) + let currentAgent: undefined | IAgent; + let currentToolIndex = -1 + let currentToolName = '' + let currentToolArgs = '' + let currentToolId = '' + const toolMessages: any[] = [] + const assistantMessages: any[] = [] + // 存储Anthropic格式的消息体,区分文本和工具类型 + return rewriteStream(eventStream, async (data, controller) => { + // 检测工具调用开始 + if (data.event === 'content_block_start' && data?.data?.content_block?.name) { + const agent = req.agents.find((name: string) => agentsManager.getAgent(name)?.tools.get(data.data.content_block.name)) + if (agent) { + currentAgent = agentsManager.getAgent(agent) + currentToolIndex = data.data.index + currentToolName = data.data.content_block.name + currentToolId = data.data.content_block.id + return undefined; + } + } + + // 收集工具参数 + if (currentToolIndex > -1 && data.data.index === currentToolIndex && data.data?.delta?.type === 'input_json_delta') { + currentToolArgs += data.data?.delta?.partial_json; + return undefined; + } + + // 工具调用完成,处理agent调用 + if (currentToolIndex > -1 && data.data.index === currentToolIndex && data.data.type === 'content_block_stop') { + try { + const args = JSON5.parse(currentToolArgs); + assistantMessages.push({ + type: "tool_use", + id: currentToolId, + name: currentToolName, + input: args + }) + const toolResult = await currentAgent?.tools.get(currentToolName)?.handler(args, { + req, + config + }); + toolMessages.push({ + "tool_use_id": currentToolId, + "type": "tool_result", + "content": toolResult + }) + currentAgent = undefined + currentToolIndex = -1 + currentToolName = '' + currentToolArgs = '' + currentToolId = '' + } catch (e) { + console.log(e); + } + return undefined; + } + + if (data.event === 'message_delta' && toolMessages.length) { + req.body.messages.push({ + role: 'assistant', + content: assistantMessages + }) + req.body.messages.push({ + role: 'user', + content: toolMessages + }) + const response = await fetch(`http://127.0.0.1:${config.PORT}/v1/messages`, { + method: "POST", + headers: { + 'x-api-key': config.APIKEY, + 'content-type': 'application/json', + }, + body: JSON.stringify(req.body), + }) + if (!response.ok) { + return undefined; + } + const stream = response.body!.pipeThrough(new SSEParserTransform()) + const reader = stream.getReader() + while (true) { + const {value, done} = await reader.read(); + if (done) { + break; + } + if (['message_start', 'message_stop'].includes(value.event)) { + continue + } + controller.enqueue(value) + } + return undefined + } + return data + }).pipeThrough(new SSESerializerTransform()) + } + const [originalStream, clonedStream] = payload.tee(); const read = async (stream: ReadableStream) => { const reader = stream.getReader(); @@ -156,29 +284,13 @@ async function run(options: RunOptions = {}) { } } read(clonedStream); - done(null, originalStream) - } else { - req.log.debug({payload}, 'onSend Hook') - sessionUsageCache.put(req.sessionId, payload.usage); - if (payload instanceof Buffer || payload instanceof Response) { - done(null, payload); - } else if(typeof payload === "object") { - done(null, JSON.stringify(payload)); - } else { - done(null, payload); - } - } - } else { - if(payload instanceof Buffer || payload instanceof Response || payload === null || payload instanceof ReadableStream || payload instanceof Stream) { - done(null, payload); - } else if(typeof payload === "object") { - req.log.debug({payload}, 'onSend Hook') - done(null, JSON.stringify(payload)); - } else { - done(null, payload); + return originalStream } + sessionUsageCache.put(req.sessionId, payload.usage); } + return payload; }); + server.start(); } diff --git a/src/utils/SSEParser.transform.ts b/src/utils/SSEParser.transform.ts new file mode 100644 index 0000000..8e7fa32 --- /dev/null +++ b/src/utils/SSEParser.transform.ts @@ -0,0 +1,73 @@ +export class SSEParserTransform extends TransformStream { + private buffer = ''; + private currentEvent: Record = {}; + + constructor() { + super({ + transform: (chunk: string, controller) => { + const decoder = new TextDecoder(); + const text = decoder.decode(chunk); + this.buffer += text; + const lines = this.buffer.split('\n'); + + // 保留最后一行(可能不完整) + this.buffer = lines.pop() || ''; + + for (const line of lines) { + const event = this.processLine(line); + if (event) { + controller.enqueue(event); + } + } + }, + flush: (controller) => { + // 处理缓冲区中剩余的内容 + if (this.buffer.trim()) { + const events: any[] = []; + this.processLine(this.buffer.trim(), events); + events.forEach(event => controller.enqueue(event)); + } + + // 推送最后一个事件(如果有) + if (Object.keys(this.currentEvent).length > 0) { + controller.enqueue(this.currentEvent); + } + } + }); + } + + private processLine(line: string, events?: any[]): any | null { + if (!line.trim()) { + if (Object.keys(this.currentEvent).length > 0) { + const event = { ...this.currentEvent }; + this.currentEvent = {}; + if (events) { + events.push(event); + return null; + } + return event; + } + return null; + } + + if (line.startsWith('event:')) { + this.currentEvent.event = line.slice(6).trim(); + } else if (line.startsWith('data:')) { + const data = line.slice(5).trim(); + if (data === '[DONE]') { + this.currentEvent.data = { type: 'done' }; + } else { + try { + this.currentEvent.data = JSON.parse(data); + } catch (e) { + this.currentEvent.data = { raw: data, error: 'JSON parse failed' }; + } + } + } else if (line.startsWith('id:')) { + this.currentEvent.id = line.slice(3).trim(); + } else if (line.startsWith('retry:')) { + this.currentEvent.retry = parseInt(line.slice(6).trim()); + } + return null; + } +} diff --git a/src/utils/SSESerializer.transform.ts b/src/utils/SSESerializer.transform.ts new file mode 100644 index 0000000..8e2c79c --- /dev/null +++ b/src/utils/SSESerializer.transform.ts @@ -0,0 +1,29 @@ +export class SSESerializerTransform extends TransformStream { + constructor() { + super({ + transform: (event, controller) => { + let output = ''; + + if (event.event) { + output += `event: ${event.event}\n`; + } + if (event.id) { + output += `id: ${event.id}\n`; + } + if (event.retry) { + output += `retry: ${event.retry}\n`; + } + if (event.data) { + if (event.data.type === 'done') { + output += 'data: [DONE]\n'; + } else { + output += `data: ${JSON.stringify(event.data)}\n`; + } + } + + output += '\n'; + controller.enqueue(output); + } + }); + } +} diff --git a/src/utils/codeCommand.ts b/src/utils/codeCommand.ts index 9b2940c..87c46f8 100644 --- a/src/utils/codeCommand.ts +++ b/src/utils/codeCommand.ts @@ -5,6 +5,8 @@ import { decrementReferenceCount, incrementReferenceCount, } from "./processCheck"; +import {HOME_DIR} from "../constants"; +import {join} from "path"; export async function executeCodeCommand(args: string[] = []) { // Set environment variables @@ -63,7 +65,7 @@ export async function executeCodeCommand(args: string[] = []) { const stdioConfig: StdioOptions = config.NON_INTERACTIVE_MODE ? ["pipe", "inherit", "inherit"] // Pipe stdin for non-interactive : "inherit"; // Default inherited behavior - + console.log(joinedArgs) const claudeProcess = spawn( claudePath + (joinedArgs ? ` ${joinedArgs}` : ""), [], diff --git a/src/utils/rewriteStream.ts b/src/utils/rewriteStream.ts new file mode 100644 index 0000000..7146193 --- /dev/null +++ b/src/utils/rewriteStream.ts @@ -0,0 +1,31 @@ +/**rewriteStream + * 读取源readablestream,返回一个新的readablestream,由processor对源数据进行处理后将返回的新值推送到新的stream,如果没有返回值则不推送 + * @param stream + * @param processor + */ +export const rewriteStream = (stream: ReadableStream, processor: (data: any, controller: ReadableStreamController) => Promise): ReadableStream => { + const reader = stream.getReader() + + return new ReadableStream({ + async start(controller) { + try { + while (true) { + const { done, value } = await reader.read() + if (done) { + controller.close() + break + } + + const processed = await processor(value, controller) + if (processed !== undefined) { + controller.enqueue(processed) + } + } + } catch (error) { + controller.error(error) + } finally { + reader.releaseLock() + } + } + }) +} diff --git a/src/utils/router.ts b/src/utils/router.ts index 38e5270..5ac1b87 100644 --- a/src/utils/router.ts +++ b/src/utils/router.ts @@ -72,16 +72,17 @@ const getUseModel = async ( if (req.body.model.includes(",")) { const [provider, model] = req.body.model.split(","); const finalProvider = config.Providers.find( - (p: any) => p.name.toLowerCase() === provider + (p: any) => p.name.toLowerCase() === provider ); const finalModel = finalProvider?.models?.find( - (m: any) => m.toLowerCase() === model + (m: any) => m.toLowerCase() === model ); if (finalProvider && finalModel) { return `${finalProvider.name},${finalModel}`; } return req.body.model; } + // if tokenCount is greater than the configured threshold, use the long context model const longContextThreshold = config.Router.longContextThreshold || 60000; const lastUsageThreshold =