158 lines
5.4 KiB
TypeScript
158 lines
5.4 KiB
TypeScript
/**
 * Copyright (c) Microsoft Corporation.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
import debug from 'debug';
|
|
|
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
|
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
|
|
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
|
import { httpAddressToString, installHttpTransport, startHttpServer } from './http.js';
|
|
import { InProcessTransport } from './inProcessTransport.js';
|
|
|
|
import type { Tool, CallToolResult, CallToolRequest, Root } from '@modelcontextprotocol/sdk/types.js';
|
|
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
|
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
|
export type { Tool, CallToolResult, CallToolRequest, Root } from '@modelcontextprotocol/sdk/types.js';
|
|
|
|
// Debug logger for the 'pw:mcp:server' namespace; used in this file to trace incoming listTools / callTool requests.
const serverDebug = debug('pw:mcp:server');
|
|
// Debug logger for the 'pw:mcp:errors' namespace; used in this file to report failures while handling the client's `initialized` notification.
const errorsDebug = debug('pw:mcp:errors');
|
|
|
|
/** Name and version of the connected client, as handed to `ServerBackend.initialize`. */
export type ClientVersion = {
  name: string;
  version: string;
};
|
|
/**
 * Contract implemented by a tool provider that `createServer` exposes over MCP.
 */
export interface ServerBackend {
  /**
   * Optional hook invoked once the client's `initialized` notification has
   * been received. Tool requests are not dispatched to the backend until the
   * returned promise resolves.
   *
   * @param clientVersion - Client name/version, or `'unknown'` placeholders when the client did not report one.
   * @param roots - Roots listed by the client; empty when the client has no roots capability or the listing failed.
   */
  initialize?(clientVersion: ClientVersion, roots: Root[]): Promise<void>;

  /** Returns the tools to advertise in response to a list-tools request. */
  listTools(): Promise<Tool[]>;

  /**
   * Executes the named tool.
   *
   * @param name - Tool name taken from the request.
   * @param args - Tool arguments taken from the request (`createServer` substitutes `{}` when the request has none).
   * @returns The tool result. A thrown error is converted by `createServer` into an `isError` result.
   */
  callTool(name: string, args: CallToolRequest['params']['arguments']): Promise<CallToolResult>;

  /** Optional hook invoked when the server's `close` event fires. */
  serverClosed?(): void;
}
|
|
|
|
/**
 * Describes a server flavor and how to create a backend instance for it.
 */
export type ServerBackendFactory = {
  /** Server name passed to `createServer` when connecting. */
  name: string;
  /** Key under `mcpServers` in the client config snippet printed by `start`. */
  nameInConfig: string;
  /** Server version passed to `createServer` when connecting. */
  version: string;
  /** Produces the backend used for a connection; called once per `connect`. */
  create: () => ServerBackend;
};
|
|
|
|
/**
 * Creates a server backed by a new backend from `factory` and connects it to
 * `transport`.
 *
 * @param factory - Supplies the server name, version and backend instance.
 * @param transport - Transport the server is connected to.
 * @param runHeartbeat - Forwarded to `createServer`; enables the client ping heartbeat.
 * @returns A promise that settles once `server.connect(transport)` has completed.
 */
export async function connect(factory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean): Promise<void> {
  const server = createServer(factory.name, factory.version, factory.create(), runHeartbeat);
  await server.connect(transport);
}
|
|
|
|
/**
 * Wraps `backend` in a server named 'Internal' (version '0.0.0', heartbeat
 * disabled) and returns an `InProcessTransport` constructed around that server.
 *
 * @param backend - Backend to expose through the in-process transport.
 * @returns The transport bound to the newly created server.
 */
export async function wrapInProcess(backend: ServerBackend): Promise<Transport> {
  const server = createServer('Internal', '0.0.0', backend, false);
  return new InProcessTransport(server);
}
|
|
|
|
/**
 * Builds an MCP `Server` that serves list-tools and call-tool requests from
 * `backend`.
 *
 * Both request handlers wait until the client's `initialized` notification has
 * been processed and `backend.initialize` (when implemented) has completed.
 * If that initialization fails, the error is logged and waiting (and later)
 * requests are rejected with it rather than being left pending forever.
 *
 * @param name - Server name passed to the SDK `Server` constructor.
 * @param version - Server version passed to the SDK `Server` constructor.
 * @param backend - Tool provider that requests are delegated to.
 * @param runHeartbeat - When true, the heartbeat is started on the first tool call.
 * @returns The configured server; the caller is responsible for connecting it to a transport.
 */
export function createServer(name: string, version: string, backend: ServerBackend, runHeartbeat: boolean): Server {
  let initializedPromiseResolve = () => {};
  let initializedPromiseReject = (_reason: unknown) => {};
  const initializedPromise = new Promise<void>((resolve, reject) => {
    initializedPromiseResolve = resolve;
    initializedPromiseReject = reject;
  });
  // No request may be waiting when initialization fails; mark the rejection as
  // handled so it does not surface as an unhandled promise rejection.
  initializedPromise.catch(() => {});

  const server = new Server({ name, version }, {
    capabilities: {
      tools: {},
    }
  });

  server.setRequestHandler(ListToolsRequestSchema, async () => {
    serverDebug('listTools');
    await initializedPromise;
    const tools = await backend.listTools();
    return { tools };
  });

  let heartbeatRunning = false;
  server.setRequestHandler(CallToolRequestSchema, async request => {
    serverDebug('callTool', request);
    await initializedPromise;

    // The heartbeat is started lazily, once, on the first tool call.
    if (runHeartbeat && !heartbeatRunning) {
      heartbeatRunning = true;
      startHeartbeat(server);
    }

    try {
      return await backend.callTool(request.params.name, request.params.arguments || {});
    } catch (error) {
      // Backend failures are reported as a tool result flagged with isError
      // instead of failing the request itself.
      return {
        content: [{ type: 'text', text: '### Result\n' + String(error) }],
        isError: true,
      };
    }
  });

  addServerListener(server, 'initialized', async () => {
    try {
      const capabilities = server.getClientCapabilities();
      let clientRoots: Root[] = [];
      if (capabilities?.roots) {
        // Root listing is best effort: a failure or a 2 second timeout yields no roots.
        const { roots } = await server.listRoots(undefined, { timeout: 2_000 }).catch(() => ({ roots: [] }));
        clientRoots = roots;
      }
      const clientVersion = server.getClientVersion() ?? { name: 'unknown', version: 'unknown' };
      await backend.initialize?.(clientVersion, clientRoots);
      initializedPromiseResolve();
    } catch (e) {
      errorsDebug(e);
      // Release requests blocked on initialization so they fail instead of hanging.
      initializedPromiseReject(e);
    }
  });

  addServerListener(server, 'close', () => backend.serverClosed?.());
  return server;
}
|
|
|
|
/**
 * Starts a ping loop against the connected client.
 *
 * Each beat pings the client and waits at most 5 seconds for the answer. After
 * a successful ping the next beat is scheduled 3 seconds later; when the ping
 * fails or times out the server is closed and the loop stops.
 *
 * @param server - Server whose client is pinged and which is closed on failure.
 */
const startHeartbeat = (server: Server) => {
  const beat = () => {
    let timeoutTimer: ReturnType<typeof setTimeout> | undefined;
    const ping = server.ping();
    const timeout = new Promise<never>((_, reject) => {
      timeoutTimer = setTimeout(() => reject(new Error('ping timeout')), 5000);
    });

    Promise.race([ping, timeout])
        // Always release the timeout timer so it does not outlive the beat.
        .finally(() => clearTimeout(timeoutTimer))
        .then(() => {
          setTimeout(beat, 3000);
        })
        .catch(() => {
          void server.close();
        });
  };

  beat();
};
|
|
|
|
/**
 * Adds `listener` to the server's `onclose` / `oninitialized` callback without
 * discarding a callback that is already installed: the previous callback (if
 * any) runs first, then `listener`.
 *
 * @param server - Server whose callback property is replaced.
 * @param event - Which callback to chain onto.
 * @param listener - Function to invoke after the previously installed callback.
 */
function addServerListener(server: Server, event: 'close' | 'initialized', listener: () => void) {
  const oldListener = server[`on${event}`];
  server[`on${event}`] = () => {
    oldListener?.();
    listener();
  };
}
|
|
|
|
/**
 * Entry point that serves `serverBackendFactory` over stdio or HTTP.
 *
 * When `options.port` is undefined the server is connected to a
 * `StdioServerTransport` (heartbeat disabled) and nothing is printed.
 * Otherwise an HTTP server is started with `options`, the HTTP transport is
 * installed on it, and a client configuration snippet pointing at the `/mcp`
 * endpoint is written to stderr.
 *
 * @param serverBackendFactory - Factory describing the server to run.
 * @param options - Optional `host` / `port`; the presence of `port` selects HTTP mode.
 */
export async function start(serverBackendFactory: ServerBackendFactory, options: { host?: string; port?: number }): Promise<void> {
  if (options.port === undefined) {
    await connect(serverBackendFactory, new StdioServerTransport(), false);
    return;
  }

  const httpServer = await startHttpServer(options);
  await installHttpTransport(httpServer, serverBackendFactory);
  const url = httpAddressToString(httpServer.address());

  const mcpConfig: { mcpServers: Record<string, { url: string }> } = {
    mcpServers: {
      [serverBackendFactory.nameInConfig]: { url: `${url}/mcp` },
    },
  };
  const message = [
    `Listening on ${url}`,
    'Put this in your client config:',
    JSON.stringify(mcpConfig, undefined, 2),
    'For legacy SSE transport support, you can use the /sse endpoint instead.',
  ].join('\n');
  // Written to stderr rather than stdout.
  // eslint-disable-next-line no-console
  console.error(message);
}
|