chore: refactor initialize (#812)

This commit is contained in:
Pavel Feldman
2025-08-01 13:06:36 -07:00
committed by GitHub
parent 7c07cc86eb
commit ffe0117456
3 changed files with 44 additions and 21 deletions

View File

@@ -18,11 +18,18 @@ import { z } from 'zod';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { ManualPromise } from '../manualPromise.js';
import { logUnhandledError } from '../log.js';
import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export type ClientVersion = Implementation;
export type ClientCapabilities = {
roots?: {
listRoots?: boolean
};
};
export type ToolResponse = {
content: (TextContent | ImageContent)[];
@@ -42,10 +49,9 @@ export type ToolHandler = (toolName: string, params: any) => Promise<ToolRespons
export interface ServerBackend {
name: string;
version: string;
initialize?(): Promise<void>;
initialize?(server: Server): Promise<void>;
tools(): ToolSchema<any>[];
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
serverInitialized?(version: ClientVersion | undefined): void;
serverClosed?(): void;
}
@@ -53,12 +59,12 @@ export type ServerBackendFactory = () => ServerBackend;
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
const backend = serverBackendFactory();
await backend.initialize?.();
const server = createServer(backend, runHeartbeat);
await server.connect(transport);
}
export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
const initializedPromise = new ManualPromise<void>();
const server = new Server({ name: backend.name, version: backend.version }, {
capabilities: {
tools: {},
@@ -82,6 +88,8 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
let heartbeatRunning = false;
server.setRequestHandler(CallToolRequestSchema, async request => {
await initializedPromise;
if (runHeartbeat && !heartbeatRunning) {
heartbeatRunning = true;
startHeartbeat(server);
@@ -101,8 +109,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
return errorResult(String(error));
}
});
addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion()));
addServerListener(server, 'initialized', () => {
backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError);
});
addServerListener(server, 'close', () => backend.serverClosed?.());
return server;
}