From 91d5d24cab55669e4e03a1415f426b689777cfed Mon Sep 17 00:00:00 2001 From: Yury Semikhatsky Date: Fri, 15 Aug 2025 11:23:59 -0700 Subject: [PATCH] chore: handle list roots in the server, with timeout (#898) --- src/browserServerBackend.ts | 8 +-- src/mcp/proxyBackend.ts | 17 ++---- src/mcp/server.ts | 23 ++++++-- tests/roots.spec.ts | 114 +++++++++++++++++------------------- 4 files changed, 78 insertions(+), 84 deletions(-) diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts index e71a15d..1170b31 100644 --- a/src/browserServerBackend.ts +++ b/src/browserServerBackend.ts @@ -45,11 +45,9 @@ export class BrowserServerBackend implements ServerBackend { this._tools = filteredTools(config); } - async initialize(server: mcpServer.Server): Promise { - const capabilities = server.getClientCapabilities(); + async initialize(clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise { let rootPath: string | undefined; - if (capabilities?.roots) { - const { roots } = await server.listRoots(); + if (roots.length > 0) { const firstRootUri = roots[0]?.uri; const url = firstRootUri ? new URL(firstRootUri) : undefined; rootPath = url ? fileURLToPath(url) : undefined; @@ -60,7 +58,7 @@ export class BrowserServerBackend implements ServerBackend { config: this._config, browserContextFactory: this._browserContextFactory, sessionLog: this._sessionLog, - clientInfo: { ...server.getClientVersion(), rootPath }, + clientInfo: { ...clientVersion, rootPath }, }); } diff --git a/src/mcp/proxyBackend.ts b/src/mcp/proxyBackend.ts index e4083b5..c639fd5 100644 --- a/src/mcp/proxyBackend.ts +++ b/src/mcp/proxyBackend.ts @@ -23,10 +23,9 @@ import { logUnhandledError } from '../utils/log.js'; import { packageJSON } from '../utils/package.js'; -import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; -import type { ServerBackend } from './server.js'; +import type { ServerBackend, ClientVersion, Root } from './server.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; -import type { Root, Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js'; +import type { Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js'; export type MCPProvider = { name: string; @@ -48,14 +47,8 @@ export class ProxyBackend implements ServerBackend { this._contextSwitchTool = this._defineContextSwitchTool(); } - async initialize(server: Server): Promise { - const version = server.getClientVersion(); - const capabilities = server.getClientCapabilities(); - if (capabilities?.roots && version && clientsWithRoots.includes(version.name)) { - const { roots } = await server.listRoots(); - this._roots = roots; - } - + async initialize(clientVersion: ClientVersion, roots: Root[]): Promise { + this._roots = roots; await this._setCurrentClient(this._mcpProviders[0]); } @@ -136,5 +129,3 @@ export class ProxyBackend implements ServerBackend { this._currentClient = client; } } - -const clientsWithRoots = ['Visual Studio Code', 'Visual Studio Code - Insiders']; diff --git a/src/mcp/server.ts b/src/mcp/server.ts index 3ac389e..80c1461 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -20,17 +20,18 @@ import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprot import { ManualPromise } from '../utils/manualPromise.js'; import { logUnhandledError } from '../utils/log.js'; -import type { Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js'; +import type { Tool, CallToolResult, CallToolRequest, Root } from '@modelcontextprotocol/sdk/types.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; export type { Server } from '@modelcontextprotocol/sdk/server/index.js'; -export type { Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js'; +export type { Tool, CallToolResult, CallToolRequest, Root } from '@modelcontextprotocol/sdk/types.js'; const serverDebug = debug('pw:mcp:server'); +export type ClientVersion = { name: string, version: string }; export interface ServerBackend { name: string; version: string; - initialize?(server: Server): Promise; + initialize?(clientVersion: ClientVersion, roots: Root[]): Promise; listTools(): Promise; callTool(name: string, args: CallToolRequest['params']['arguments']): Promise; serverClosed?(): void; @@ -78,8 +79,20 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser }; } }); - addServerListener(server, 'initialized', () => { - backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError); + addServerListener(server, 'initialized', async () => { + try { + const capabilities = server.getClientCapabilities(); + let clientRoots: Root[] = []; + if (capabilities?.roots) { + const { roots } = await server.listRoots(undefined, { timeout: 2_000 }).catch(() => ({ roots: [] })); + clientRoots = roots; + } + const clientVersion = server.getClientVersion() ?? { name: 'unknown', version: 'unknown' }; + await backend.initialize?.(clientVersion, clientRoots); + initializedPromise.resolve(); + } catch (e) { + logUnhandledError(e); + } }); addServerListener(server, 'close', () => backend.serverClosed?.()); return server; diff --git a/tests/roots.spec.ts b/tests/roots.spec.ts index a94191e..ffcd8b5 100644 --- a/tests/roots.spec.ts +++ b/tests/roots.spec.ts @@ -23,65 +23,57 @@ import { createHash } from '../src/utils/guid.js'; const p = process.platform === 'win32' ? 'c:\\non\\existent\\folder' : '/non/existent/folder'; -for (const mode of ['default', 'proxy']) { - const extraArgs = mode === 'proxy' ? ['--connect-tool'] : []; - - test.describe(`${mode} mode`, () => { - test('should use separate user data by root path', async ({ startClient, server }, testInfo) => { - const { client } = await startClient({ - args: extraArgs, - clientName: 'Visual Studio Code', // Simulate VS Code client, roots only work with it - roots: [ - { - name: 'test', - uri: 'file://' + p.replace(/\\/g, '/'), - } - ], - }); - - await client.callTool({ - name: 'browser_navigate', - arguments: { url: server.HELLO_WORLD }, - }); - - const hash = createHash(p); - const [file] = await fs.promises.readdir(testInfo.outputPath('ms-playwright')); - expect(file).toContain(hash); - }); - - - test('check that trace is saved in workspace', async ({ startClient, server }, testInfo) => { - const rootPath = testInfo.outputPath('workspace'); - const { client } = await startClient({ - args: ['--save-trace', ...extraArgs], - clientName: 'Visual Studio Code - Insiders', // Simulate VS Code client, roots only work with it - roots: [ - { - name: 'workspace', - uri: pathToFileURL(rootPath).toString(), - }, - ], - }); - - expect(await client.callTool({ - name: 'browser_navigate', - arguments: { url: server.HELLO_WORLD }, - })).toHaveResponse({ - code: expect.stringContaining(`page.goto('http://localhost`), - }); - - const [file] = await fs.promises.readdir(path.join(rootPath, '.playwright-mcp')); - expect(file).toContain('traces'); - }); - - test('should list all tools when listRoots is slow', async ({ startClient, server }, testInfo) => { - const { client } = await startClient({ - clientName: 'Visual Studio Code', // Simulate VS Code client, roots only work with it - roots: [], - rootsResponseDelay: 1000, - }); - const tools = await client.listTools(); - expect(tools.tools.length).toBeGreaterThan(20); - }); +test('should use separate user data by root path', async ({ startClient, server }, testInfo) => { + const { client } = await startClient({ + clientName: 'Visual Studio Code', + roots: [ + { + name: 'test', + uri: 'file://' + p.replace(/\\/g, '/'), + } + ], }); -} + + await client.callTool({ + name: 'browser_navigate', + arguments: { url: server.HELLO_WORLD }, + }); + + const hash = createHash(p); + const [file] = await fs.promises.readdir(testInfo.outputPath('ms-playwright')); + expect(file).toContain(hash); +}); + +test('check that trace is saved in workspace', async ({ startClient, server }, testInfo) => { + const rootPath = testInfo.outputPath('workspace'); + const { client } = await startClient({ + args: ['--save-trace'], + clientName: 'My client', + roots: [ + { + name: 'workspace', + uri: pathToFileURL(rootPath).toString(), + }, + ], + }); + + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { url: server.HELLO_WORLD }, + })).toHaveResponse({ + code: expect.stringContaining(`page.goto('http://localhost`), + }); + + const [file] = await fs.promises.readdir(path.join(rootPath, '.playwright-mcp')); + expect(file).toContain('traces'); +}); + +test('should list all tools when listRoots is slow', async ({ startClient, server }, testInfo) => { + const { client } = await startClient({ + clientName: 'Another custom client', + roots: [], + rootsResponseDelay: 1000, + }); + const tools = await client.listTools(); + expect(tools.tools.length).toBeGreaterThan(20); +});