chore: extract pure mcp server helpers (#751)

This commit is contained in:
Pavel Feldman
2025-07-24 12:57:01 -07:00
committed by GitHub
parent bd34e9d7e9
commit c63b7823e1
12 changed files with 300 additions and 469 deletions

1
src/mcp/README.md Normal file
View File

@@ -0,0 +1 @@
- Generic MCP utils, no dependencies on Playwright here.

105
src/mcp/server.ts Normal file
View File

@@ -0,0 +1,105 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { z } from 'zod';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export type ClientVersion = Implementation;
export type ToolResponse = {
content: (TextContent | ImageContent)[];
isError?: boolean;
};
export type ToolSchema<Input extends z.Schema> = {
name: string;
title: string;
description: string;
inputSchema: Input;
type: 'readOnly' | 'destructive';
};
export type ToolHandler = (toolName: string, params: any) => Promise<ToolResponse>;
export interface ServerBackend {
name: string;
version: string;
initialize?(): Promise<void>;
tools(): ToolSchema<any>[];
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
serverInitialized?(version: ClientVersion | undefined): void;
serverClosed?(): void;
}
export type ServerBackendFactory = () => ServerBackend;
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport) {
const backend = serverBackendFactory();
await backend.initialize?.();
const server = createServer(backend);
await server.connect(transport);
}
export function createServer(backend: ServerBackend): Server {
const server = new Server({ name: backend.name, version: backend.version }, {
capabilities: {
tools: {},
}
});
const tools = backend.tools();
server.setRequestHandler(ListToolsRequestSchema, async () => {
return { tools: tools.map(tool => ({
name: tool.name,
description: tool.description,
inputSchema: zodToJsonSchema(tool.inputSchema),
annotations: {
title: tool.title,
readOnlyHint: tool.type === 'readOnly',
destructiveHint: tool.type === 'destructive',
openWorldHint: true,
},
})) };
});
server.setRequestHandler(CallToolRequestSchema, async request => {
const errorResult = (...messages: string[]) => ({
content: [{ type: 'text', text: messages.join('\n') }],
isError: true,
});
const tool = tools.find(tool => tool.name === request.params.name) as ToolSchema<any>;
if (!tool)
return errorResult(`Tool "${request.params.name}" not found`);
try {
return await backend.callTool(tool, tool.inputSchema.parse(request.params.arguments || {}));
} catch (error) {
return errorResult(String(error));
}
});
if (backend.serverInitialized)
server.oninitialized = () => backend.serverInitialized!(server.getClientVersion());
if (backend.serverClosed)
server.onclose = () => backend.serverClosed!();
return server;
}

137
src/mcp/transport.ts Normal file
View File

@@ -0,0 +1,137 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import http from 'http';
import crypto from 'crypto';
import debug from 'debug';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { httpAddressToString, startHttpServer } from '../httpServer.js';
import * as mcpServer from './server.js';
import type { ServerBackendFactory } from './server.js';
export async function start(serverBackendFactory: ServerBackendFactory, options: { host?: string; port?: number }) {
if (options.port !== undefined) {
const httpServer = await startHttpServer(options);
startHttpTransport(httpServer, serverBackendFactory);
} else {
await startStdioTransport(serverBackendFactory);
}
}
async function startStdioTransport(serverBackendFactory: ServerBackendFactory) {
await mcpServer.connect(serverBackendFactory, new StdioServerTransport());
}
const testDebug = debug('pw:mcp:test');
async function handleSSE(serverBackendFactory: ServerBackendFactory, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>) {
if (req.method === 'POST') {
const sessionId = url.searchParams.get('sessionId');
if (!sessionId) {
res.statusCode = 400;
return res.end('Missing sessionId');
}
const transport = sessions.get(sessionId);
if (!transport) {
res.statusCode = 404;
return res.end('Session not found');
}
return await transport.handlePostMessage(req, res);
} else if (req.method === 'GET') {
const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport);
testDebug(`create SSE session: ${transport.sessionId}`);
await mcpServer.connect(serverBackendFactory, transport);
res.on('close', () => {
testDebug(`delete SSE session: ${transport.sessionId}`);
sessions.delete(transport.sessionId);
});
return;
}
res.statusCode = 405;
res.end('Method not allowed');
}
async function handleStreamable(serverBackendFactory: ServerBackendFactory, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) {
const transport = sessions.get(sessionId);
if (!transport) {
res.statusCode = 404;
res.end('Session not found');
return;
}
return await transport.handleRequest(req, res);
}
if (req.method === 'POST') {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => crypto.randomUUID(),
onsessioninitialized: async sessionId => {
testDebug(`create http session: ${transport.sessionId}`);
await mcpServer.connect(serverBackendFactory, transport);
sessions.set(sessionId, transport);
}
});
transport.onclose = () => {
if (!transport.sessionId)
return;
sessions.delete(transport.sessionId);
testDebug(`delete http session: ${transport.sessionId}`);
};
await transport.handleRequest(req, res);
return;
}
res.statusCode = 400;
res.end('Invalid request');
}
function startHttpTransport(httpServer: http.Server, serverBackendFactory: ServerBackendFactory) {
const sseSessions = new Map();
const streamableSessions = new Map();
httpServer.on('request', async (req, res) => {
const url = new URL(`http://localhost${req.url}`);
if (url.pathname.startsWith('/sse'))
await handleSSE(serverBackendFactory, req, res, url, sseSessions);
else
await handleStreamable(serverBackendFactory, req, res, streamableSessions);
});
const url = httpAddressToString(httpServer.address());
const message = [
`Listening on ${url}`,
'Put this in your client config:',
JSON.stringify({
'mcpServers': {
'playwright': {
'url': `${url}/mcp`
}
}
}, undefined, 2),
'For legacy SSE transport support, you can use the /sse endpoint instead.',
].join('\n');
// eslint-disable-next-line no-console
console.error(message);
}