From fb28e99fa44124bd1780e5f0bf1c8a471d65cdac Mon Sep 17 00:00:00 2001 From: Pavel Feldman Date: Fri, 22 Aug 2025 13:08:29 -0700 Subject: [PATCH] chore: mdb stub (#912) --- src/browserServerBackend.ts | 2 +- src/extension/cdpRelay.ts | 2 +- src/mcp/http.ts | 17 ++ src/{utils => mcp}/manualPromise.ts | 0 src/mcp/mdb.ts | 239 ++++++++++++++++++++++++++++ src/mcp/proxyBackend.ts | 4 +- src/mcp/server.ts | 9 +- src/mcp/tool.ts | 4 + src/tab.ts | 2 +- src/vscode/host.ts | 4 +- tests/mdb.spec.ts | 217 +++++++++++++++++++++++++ 11 files changed, 489 insertions(+), 11 deletions(-) rename src/{utils => mcp}/manualPromise.ts (100%) create mode 100644 src/mcp/mdb.ts create mode 100644 tests/mdb.spec.ts diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts index 4d18fce..10f5ff4 100644 --- a/src/browserServerBackend.ts +++ b/src/browserServerBackend.ts @@ -41,7 +41,7 @@ export class BrowserServerBackend implements ServerBackend { this._tools = filteredTools(config); } - async initialize(clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise { + async initialize(server: mcpServer.Server, clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise { let rootPath: string | undefined; if (roots.length > 0) { const firstRootUri = roots[0]?.uri; diff --git a/src/extension/cdpRelay.ts b/src/extension/cdpRelay.ts index 4a85656..3fd7b9c 100644 --- a/src/extension/cdpRelay.ts +++ b/src/extension/cdpRelay.ts @@ -28,7 +28,7 @@ import debug from 'debug'; import { WebSocket, WebSocketServer } from 'ws'; import { httpAddressToString } from '../mcp/http.js'; import { logUnhandledError } from '../utils/log.js'; -import { ManualPromise } from '../utils/manualPromise.js'; +import { ManualPromise } from '../mcp/manualPromise.js'; import { packageJSON } from '../utils/package.js'; import type websocket from 'ws'; diff --git a/src/mcp/http.ts b/src/mcp/http.ts index 6890ddc..7cebc07 100644 --- a/src/mcp/http.ts +++ b/src/mcp/http.ts @@ -32,6 +32,7 @@ const testDebug = debug('pw:mcp:test'); export async function startHttpServer(config: { host?: string, port?: number }, abortSignal?: AbortSignal): Promise { const { host, port } = config; const httpServer = http.createServer(); + decorateServer(httpServer); await new Promise((resolve, reject) => { httpServer.on('error', reject); abortSignal?.addEventListener('abort', () => { @@ -136,3 +137,19 @@ async function handleStreamable(serverBackendFactory: ServerBackendFactory, req: res.statusCode = 400; res.end('Invalid request'); } + +function decorateServer(server: net.Server) { + const sockets = new Set(); + server.on('connection', socket => { + sockets.add(socket); + socket.once('close', () => sockets.delete(socket)); + }); + + const close = server.close; + server.close = (callback?: (err?: Error) => void) => { + for (const socket of sockets) + socket.destroy(); + sockets.clear(); + return close.call(server, callback); + }; +} diff --git a/src/utils/manualPromise.ts b/src/mcp/manualPromise.ts similarity index 100% rename from src/utils/manualPromise.ts rename to src/mcp/manualPromise.ts diff --git a/src/mcp/mdb.ts b/src/mcp/mdb.ts new file mode 100644 index 0000000..f5082a3 --- /dev/null +++ b/src/mcp/mdb.ts @@ -0,0 +1,239 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import debug from 'debug'; +import { z } from 'zod'; + +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { PingRequestSchema } from '@modelcontextprotocol/sdk/types.js'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; + +import { defineToolSchema } from './tool.js'; +import * as mcpServer from './server.js'; +import * as mcpHttp from './http.js'; +import { wrapInProcess } from './server.js'; +import { ManualPromise } from './manualPromise.js'; + +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +const mdbDebug = debug('pw:mcp:mdb'); +const errorsDebug = debug('pw:mcp:errors'); + +export class MDBBackend implements mcpServer.ServerBackend { + private _stack: { client: Client, toolNames: string[], resultPromise: ManualPromise | undefined }[] = []; + private _interruptPromise: ManualPromise | undefined; + private _topLevelBackend: mcpServer.ServerBackend; + private _initialized = false; + + constructor(topLevelBackend: mcpServer.ServerBackend) { + this._topLevelBackend = topLevelBackend; + } + + async initialize(server: mcpServer.Server): Promise { + if (this._initialized) + return; + this._initialized = true; + const transport = await wrapInProcess(this._topLevelBackend); + await this._pushClient(transport); + } + + async listTools(): Promise { + const response = await this._client().listTools(); + return response.tools; + } + + async callTool(name: string, args: mcpServer.CallToolRequest['params']['arguments']): Promise { + if (name === pushToolsSchema.name) + return await this._pushTools(pushToolsSchema.inputSchema.parse(args || {})); + + const interruptPromise = new ManualPromise(); + this._interruptPromise = interruptPromise; + let [entry] = this._stack; + + // Pop the client while the tool is not found. + while (entry && !entry.toolNames.includes(name)) { + mdbDebug('popping client from stack for ', name); + this._stack.shift(); + await entry.client.close(); + entry = this._stack[0]; + } + if (!entry) + throw new Error(`Tool ${name} not found in the tool stack`); + + const resultPromise = new ManualPromise(); + entry.resultPromise = resultPromise; + + this._client().callTool({ + name, + arguments: args, + }).then(result => { + resultPromise.resolve(result as mcpServer.CallToolResult); + }).catch(e => { + mdbDebug('error in client call', e); + if (this._stack.length < 2) + throw e; + this._stack.shift(); + const prevEntry = this._stack[0]; + void prevEntry.resultPromise!.then(result => resultPromise.resolve(result)); + }); + const result = await Promise.race([interruptPromise, resultPromise]); + if (interruptPromise.isDone()) + mdbDebug('client call intercepted', result); + else + mdbDebug('client call result', result); + return result; + } + + private _client(): Client { + const [entry] = this._stack; + if (!entry) + throw new Error('No debugging backend available'); + return entry.client; + } + + private async _pushTools(params: { mcpUrl: string, introMessage?: string }): Promise { + mdbDebug('pushing tools to the stack', params.mcpUrl); + const transport = new StreamableHTTPClientTransport(new URL(params.mcpUrl)); + await this._pushClient(transport, params.introMessage); + return { content: [{ type: 'text', text: 'Tools pushed' }] }; + } + + private async _pushClient(transport: Transport, introMessage?: string): Promise { + mdbDebug('pushing client to the stack'); + const client = new Client({ name: 'Internal client', version: '0.0.0' }); + client.setRequestHandler(PingRequestSchema, () => ({})); + await client.connect(transport); + mdbDebug('connected to the new client'); + const { tools } = await client.listTools(); + this._stack.unshift({ client, toolNames: tools.map(tool => tool.name), resultPromise: undefined }); + mdbDebug('new tools added to the stack:', tools.map(tool => tool.name)); + mdbDebug('interrupting current call:', !!this._interruptPromise); + this._interruptPromise?.resolve({ + content: [{ + type: 'text', + text: introMessage || '', + }], + }); + this._interruptPromise = undefined; + return { content: [{ type: 'text', text: 'Tools pushed' }] }; + } +} + +const pushToolsSchema = defineToolSchema({ + name: 'mdb_push_tools', + title: 'Push MCP tools to the tools stack', + description: 'Push MCP tools to the tools stack', + inputSchema: z.object({ + mcpUrl: z.string(), + introMessage: z.string().optional(), + }), + type: 'readOnly', +}); + +export type ServerBackendOnPause = mcpServer.ServerBackend & { + requestSelfDestruct?: () => void; +}; + +export async function runMainBackend(backendFactory: mcpServer.ServerBackendFactory, options?: { port?: number }): Promise { + const mdbBackend = new MDBBackend(backendFactory.create()); + // Start HTTP unconditionally. + const factory: mcpServer.ServerBackendFactory = { + ...backendFactory, + create: () => mdbBackend + }; + const url = await startAsHttp(factory, { port: options?.port || 0 }); + process.env.PLAYWRIGHT_DEBUGGER_MCP = url; + + if (options?.port !== undefined) + return url; + + // Start stdio conditionally. + await mcpServer.connect(factory, new StdioServerTransport(), false); +} + +export async function runOnPauseBackendLoop(mdbUrl: string, backend: ServerBackendOnPause, introMessage: string) { + const wrappedBackend = new OnceTimeServerBackendWrapper(backend); + + const factory = { + name: 'on-pause-backend', + nameInConfig: 'on-pause-backend', + version: '0.0.0', + create: () => wrappedBackend, + }; + + const httpServer = await mcpHttp.startHttpServer({ port: 0 }); + await mcpHttp.installHttpTransport(httpServer, factory); + const url = mcpHttp.httpAddressToString(httpServer.address()); + + const client = new Client({ name: 'Internal client', version: '0.0.0' }); + client.setRequestHandler(PingRequestSchema, () => ({})); + const transport = new StreamableHTTPClientTransport(new URL(mdbUrl)); + await client.connect(transport); + + const pushToolsResult = await client.callTool({ + name: pushToolsSchema.name, + arguments: { + mcpUrl: url, + introMessage, + }, + }); + if (pushToolsResult.isError) + errorsDebug('Failed to push tools', pushToolsResult.content); + await transport.terminateSession(); + await client.close(); + + await wrappedBackend.waitForClosed(); + httpServer.close(); +} + +async function startAsHttp(backendFactory: mcpServer.ServerBackendFactory, options: { port: number }) { + const httpServer = await mcpHttp.startHttpServer(options); + await mcpHttp.installHttpTransport(httpServer, backendFactory); + return mcpHttp.httpAddressToString(httpServer.address()); +} + + +class OnceTimeServerBackendWrapper implements mcpServer.ServerBackend { + private _backend: ServerBackendOnPause; + private _selfDestructPromise = new ManualPromise(); + + constructor(backend: ServerBackendOnPause) { + this._backend = backend; + this._backend.requestSelfDestruct = () => this._selfDestructPromise.resolve(); + } + + async initialize(server: mcpServer.Server, clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise { + await this._backend.initialize?.(server, clientVersion, roots); + } + + async listTools(): Promise { + return this._backend.listTools(); + } + + async callTool(name: string, args: mcpServer.CallToolRequest['params']['arguments']): Promise { + return this._backend.callTool(name, args); + } + + serverClosed(server: mcpServer.Server) { + this._backend.serverClosed?.(server); + this._selfDestructPromise.resolve(); + } + + async waitForClosed() { + await this._selfDestructPromise; + } +} diff --git a/src/mcp/proxyBackend.ts b/src/mcp/proxyBackend.ts index da186c4..89e6868 100644 --- a/src/mcp/proxyBackend.ts +++ b/src/mcp/proxyBackend.ts @@ -21,7 +21,7 @@ import { zodToJsonSchema } from 'zod-to-json-schema'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { ListRootsRequestSchema, PingRequestSchema } from '@modelcontextprotocol/sdk/types.js'; -import type { ServerBackend, ClientVersion, Root } from './server.js'; +import type { ServerBackend, ClientVersion, Root, Server } from './server.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js'; @@ -44,7 +44,7 @@ export class ProxyBackend implements ServerBackend { this._contextSwitchTool = this._defineContextSwitchTool(); } - async initialize(clientVersion: ClientVersion, roots: Root[]): Promise { + async initialize(server: Server, clientVersion: ClientVersion, roots: Root[]): Promise { this._roots = roots; await this._setCurrentClient(this._mcpProviders[0]); } diff --git a/src/mcp/server.ts b/src/mcp/server.ts index e9b4944..5d60edb 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -31,11 +31,12 @@ const serverDebug = debug('pw:mcp:server'); const errorsDebug = debug('pw:mcp:errors'); export type ClientVersion = { name: string, version: string }; + export interface ServerBackend { - initialize?(clientVersion: ClientVersion, roots: Root[]): Promise; + initialize?(server: Server, clientVersion: ClientVersion, roots: Root[]): Promise; listTools(): Promise; callTool(name: string, args: CallToolRequest['params']['arguments']): Promise; - serverClosed?(): void; + serverClosed?(server: Server): void; } export type ServerBackendFactory = { @@ -99,13 +100,13 @@ export function createServer(name: string, version: string, backend: ServerBacke clientRoots = roots; } const clientVersion = server.getClientVersion() ?? { name: 'unknown', version: 'unknown' }; - await backend.initialize?.(clientVersion, clientRoots); + await backend.initialize?.(server, clientVersion, clientRoots); initializedPromiseResolve(); } catch (e) { errorsDebug(e); } }); - addServerListener(server, 'close', () => backend.serverClosed?.()); + addServerListener(server, 'close', () => backend.serverClosed?.(server)); return server; } diff --git a/src/mcp/tool.ts b/src/mcp/tool.ts index aff266e..3d86745 100644 --- a/src/mcp/tool.ts +++ b/src/mcp/tool.ts @@ -40,3 +40,7 @@ export function toMcpTool(tool: ToolSchema): mcpServer.Tool { }, }; } + +export function defineToolSchema(tool: ToolSchema): ToolSchema { + return tool; +} diff --git a/src/tab.ts b/src/tab.ts index 1b3fff0..5ef34af 100644 --- a/src/tab.ts +++ b/src/tab.ts @@ -18,7 +18,7 @@ import { EventEmitter } from 'events'; import * as playwright from 'playwright'; import { callOnPageNoTrace, waitForCompletion } from './tools/utils.js'; import { logUnhandledError } from './utils/log.js'; -import { ManualPromise } from './utils/manualPromise.js'; +import { ManualPromise } from './mcp/manualPromise.js'; import { ModalState } from './tools/tool.js'; import type { Context } from './context.js'; diff --git a/src/vscode/host.ts b/src/vscode/host.ts index ce78cc4..f871c1f 100644 --- a/src/vscode/host.ts +++ b/src/vscode/host.ts @@ -52,7 +52,7 @@ class VSCodeProxyBackend implements ServerBackend { this._contextSwitchTool = this._defineContextSwitchTool(); } - async initialize(clientVersion: ClientVersion, roots: Root[]): Promise { + async initialize(server: mcpServer.Server, clientVersion: ClientVersion, roots: Root[]): Promise { this._clientVersion = clientVersion; this._roots = roots; const transport = await this._defaultTransportFactory(); @@ -76,7 +76,7 @@ class VSCodeProxyBackend implements ServerBackend { }) as CallToolResult; } - serverClosed?(): void { + serverClosed?(server: mcpServer.Server): void { void this._currentClient?.close().catch(logUnhandledError); } diff --git a/tests/mdb.spec.ts b/tests/mdb.spec.ts new file mode 100644 index 0000000..d694b58 --- /dev/null +++ b/tests/mdb.spec.ts @@ -0,0 +1,217 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from 'zod'; +import zodToJsonSchema from 'zod-to-json-schema'; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; + +import { runMainBackend, runOnPauseBackendLoop } from '../src/mcp/mdb.js'; + +import { test, expect } from './fixtures.js'; + +import type * as mcpServer from '../src/mcp/server.js'; +import type { ServerBackendOnPause } from '../src/mcp/mdb.js'; + +test('call top level tool', async () => { + const { mdbUrl } = await startMDBAndCLI(); + const mdbClient = await createMDBClient(mdbUrl); + + const { tools } = await mdbClient.client.listTools(); + expect(tools).toEqual([{ + name: 'cli_echo', + description: 'Echo a message', + inputSchema: expect.any(Object), + }, { + name: 'cli_pause_in_gdb', + description: 'Pause in gdb', + inputSchema: expect.any(Object), + }, { + name: 'cli_pause_in_gdb_twice', + description: 'Pause in gdb twice', + inputSchema: expect.any(Object), + } + ]); + + const echoResult = await mdbClient.client.callTool({ + name: 'cli_echo', + arguments: { + message: 'Hello, world!', + }, + }); + expect(echoResult.content).toEqual([{ type: 'text', text: 'Echo: Hello, world!' }]); + + await mdbClient.close(); +}); + +test('pause on error', async () => { + const { mdbUrl } = await startMDBAndCLI(); + const mdbClient = await createMDBClient(mdbUrl); + + // Make a call that results in a recoverable error. + const interruptResult = await mdbClient.client.callTool({ + name: 'cli_pause_in_gdb', + arguments: {}, + }); + expect(interruptResult.content).toEqual([{ type: 'text', text: 'Paused on exception' }]); + + // List new inner tools. + const { tools } = await mdbClient.client.listTools(); + expect(tools).toEqual([ + expect.objectContaining({ + name: 'gdb_bt', + }), + expect.objectContaining({ + name: 'gdb_continue', + }), + ]); + + // Call the new inner tool. + const btResult = await mdbClient.client.callTool({ + name: 'gdb_bt', + arguments: {}, + }); + expect(btResult.content).toEqual([{ type: 'text', text: 'Backtrace' }]); + + // Continue execution. + const continueResult = await mdbClient.client.callTool({ + name: 'gdb_continue', + arguments: {}, + }); + expect(continueResult.content).toEqual([{ type: 'text', text: 'Done' }]); + + await mdbClient.close(); +}); + +test('pause on error twice', async () => { + const { mdbUrl } = await startMDBAndCLI(); + const mdbClient = await createMDBClient(mdbUrl); + + // Make a call that results in a recoverable error. + const result = await mdbClient.client.callTool({ + name: 'cli_pause_in_gdb_twice', + arguments: {}, + }); + expect(result.content).toEqual([{ type: 'text', text: 'Paused on exception 1' }]); + + // Continue execution. + const continueResult1 = await mdbClient.client.callTool({ + name: 'gdb_continue', + arguments: {}, + }); + expect(continueResult1.content).toEqual([{ type: 'text', text: 'Paused on exception 2' }]); + + const continueResult2 = await mdbClient.client.callTool({ + name: 'gdb_continue', + arguments: {}, + }); + expect(continueResult2.content).toEqual([{ type: 'text', text: 'Done' }]); + + await mdbClient.close(); +}); + +async function startMDBAndCLI(): Promise<{ mdbUrl: string }> { + const mdbUrlBox = { mdbUrl: undefined as string | undefined }; + const cliBackendFactory = { + name: 'CLI', + nameInConfig: 'cli', + version: '0.0.0', + create: () => new CLIBackend(mdbUrlBox) + }; + + const mdbUrl = (await runMainBackend(cliBackendFactory, { port: 0 }))!; + mdbUrlBox.mdbUrl = mdbUrl; + return { mdbUrl }; +} + +async function createMDBClient(mdbUrl: string): Promise<{ client: Client, close: () => Promise }> { + const client = new Client({ name: 'Internal client', version: '0.0.0' }); + const transport = new StreamableHTTPClientTransport(new URL(mdbUrl)); + await client.connect(transport); + return { + client, + close: async () => { + await transport.terminateSession(); + await client.close(); + } + }; +} + +class CLIBackend implements mcpServer.ServerBackend { + constructor(private readonly mdbUrlBox: { mdbUrl: string | undefined }) {} + + async listTools(): Promise { + return [{ + name: 'cli_echo', + description: 'Echo a message', + inputSchema: zodToJsonSchema(z.object({ message: z.string() })) as any, + }, { + name: 'cli_pause_in_gdb', + description: 'Pause in gdb', + inputSchema: zodToJsonSchema(z.object({})) as any, + }, { + name: 'cli_pause_in_gdb_twice', + description: 'Pause in gdb twice', + inputSchema: zodToJsonSchema(z.object({})) as any, + }]; + } + + async callTool(name: string, args: mcpServer.CallToolRequest['params']['arguments']): Promise { + if (name === 'cli_echo') + return { content: [{ type: 'text', text: 'Echo: ' + (args?.message as string) }] }; + if (name === 'cli_pause_in_gdb') { + await runOnPauseBackendLoop(this.mdbUrlBox.mdbUrl!, new GDBBackend(), 'Paused on exception'); + return { content: [{ type: 'text', text: 'Done' }] }; + } + if (name === 'cli_pause_in_gdb_twice') { + await runOnPauseBackendLoop(this.mdbUrlBox.mdbUrl!, new GDBBackend(), 'Paused on exception 1'); + await runOnPauseBackendLoop(this.mdbUrlBox.mdbUrl!, new GDBBackend(), 'Paused on exception 2'); + return { content: [{ type: 'text', text: 'Done' }] }; + } + throw new Error(`Unknown tool: ${name}`); + } +} + +class GDBBackend implements ServerBackendOnPause { + private _server!: mcpServer.Server; + + async initialize(server: mcpServer.Server): Promise { + this._server = server; + } + + async listTools(): Promise { + return [{ + name: 'gdb_bt', + description: 'Print backtrace', + inputSchema: zodToJsonSchema(z.object({})) as any, + }, { + name: 'gdb_continue', + description: 'Continue execution', + inputSchema: zodToJsonSchema(z.object({})) as any, + }]; + } + + async callTool(name: string, args: mcpServer.CallToolRequest['params']['arguments']): Promise { + if (name === 'gdb_bt') + return { content: [{ type: 'text', text: 'Backtrace' }] }; + if (name === 'gdb_continue') { + (this as ServerBackendOnPause).requestSelfDestruct?.(); + // Stall + await new Promise(f => setTimeout(f, 1000)); + } + throw new Error(`Unknown tool: ${name}`); + } +}