diff --git a/src/inProcessClient.ts b/src/inProcessClient.ts index 448e1ac..390954c 100644 --- a/src/inProcessClient.ts +++ b/src/inProcessClient.ts @@ -16,12 +16,13 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { ListRootsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; import { BrowserContextFactory } from './browserContextFactory.js'; import { BrowserServerBackend } from './browserServerBackend.js'; import { InProcessTransport } from './mcp/inProcessTransport.js'; import * as mcpServer from './mcp/server.js'; -import { packageJSON } from './package.js'; +import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; import type { FullConfig } from './config.js'; import type { ClientFactory } from './mcp/proxyBackend.js'; @@ -39,13 +40,20 @@ export class InProcessClientFactory implements ClientFactory { this._config = config; } - async create(): Promise { - const client = new Client({ - name: this.name, - version: packageJSON.version - }); - const server = mcpServer.createServer(new BrowserServerBackend(this._config, this._contextFactory), false); - await client.connect(new InProcessTransport(server)); + async create(server: Server): Promise { + const client = new Client(server.getClientVersion() ?? { name: 'unknown', version: 'unknown' }); + const clientCapabilities = server.getClientCapabilities(); + if (clientCapabilities) + client.registerCapabilities(clientCapabilities); + + if (clientCapabilities?.roots) { + client.setRequestHandler(ListRootsRequestSchema, async () => { + return await server.listRoots(); + }); + } + + const delegate = mcpServer.createServer(new BrowserServerBackend(this._config, this._contextFactory), false); + await client.connect(new InProcessTransport(delegate)); await client.ping(); return client; } diff --git a/src/mcp/proxyBackend.ts b/src/mcp/proxyBackend.ts index c2e6eef..d90d1e2 100644 --- a/src/mcp/proxyBackend.ts +++ b/src/mcp/proxyBackend.ts @@ -29,7 +29,7 @@ type NonEmptyArray = [T, ...T[]]; export type ClientFactory = { name: string; description: string; - create(): Promise; + create(server: Server): Promise; }; export type ClientFactoryList = NonEmptyArray; @@ -42,6 +42,7 @@ export class ProxyBackend implements ServerBackend { private _currentClient: Client | undefined; private _contextSwitchTool: Tool; private _tools: ToolSchema[] = []; + private _server: Server | undefined; constructor(clientFactories: ClientFactoryList) { this._clientFactories = clientFactories; @@ -49,6 +50,7 @@ export class ProxyBackend implements ServerBackend { } async initialize(server: Server): Promise { + this._server = server; await this._setCurrentClient(this._clientFactories[0]); } @@ -118,7 +120,7 @@ export class ProxyBackend implements ServerBackend { private async _setCurrentClient(factory: ClientFactory) { await this._currentClient?.close(); - this._currentClient = await factory.create(); + this._currentClient = await factory.create(this._server!); const tools = await this._currentClient.listTools(); this._tools = tools.tools.map(tool => ({ name: tool.name, diff --git a/src/mcp/server.ts b/src/mcp/server.ts index 5da54be..f726f58 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -76,7 +76,7 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser return { tools: tools.map(tool => ({ name: tool.name, description: tool.description, - inputSchema: zodToJsonSchema(tool.inputSchema), + inputSchema: tool.inputSchema instanceof z.ZodType ? zodToJsonSchema(tool.inputSchema) : tool.inputSchema, annotations: { title: tool.title, readOnlyHint: tool.type === 'readOnly', diff --git a/tests/capabilities.spec.ts b/tests/capabilities.spec.ts index 5f33035..61f9f39 100644 --- a/tests/capabilities.spec.ts +++ b/tests/capabilities.spec.ts @@ -46,6 +46,40 @@ test('test snapshot tool list', async ({ client }) => { ])); }); +test('test tool list proxy mode', async ({ startClient }) => { + const { client } = await startClient({ + args: ['--connect-tool'], + }); + const { tools } = await client.listTools(); + expect(new Set(tools.map(t => t.name))).toEqual(new Set([ + 'browser_click', + 'browser_connect', // the extra tool + 'browser_console_messages', + 'browser_drag', + 'browser_evaluate', + 'browser_file_upload', + 'browser_handle_dialog', + 'browser_hover', + 'browser_select_option', + 'browser_type', + 'browser_close', + 'browser_install', + 'browser_navigate_back', + 'browser_navigate_forward', + 'browser_navigate', + 'browser_network_requests', + 'browser_press_key', + 'browser_resize', + 'browser_snapshot', + 'browser_tab_close', + 'browser_tab_list', + 'browser_tab_new', + 'browser_tab_select', + 'browser_take_screenshot', + 'browser_wait_for', + ])); +}); + test('test capabilities (pdf)', async ({ startClient }) => { const { client } = await startClient({ args: ['--caps=pdf'], diff --git a/tests/roots.spec.ts b/tests/roots.spec.ts index 9b3ee20..1529aff 100644 --- a/tests/roots.spec.ts +++ b/tests/roots.spec.ts @@ -23,48 +23,55 @@ import { createHash } from '../src/utils.js'; const p = process.platform === 'win32' ? 'c:\\non\\existent\\folder' : '/non/existent/folder'; -test('should use separate user data by root path', async ({ startClient, server }, testInfo) => { - const { client } = await startClient({ - clientName: 'Visual Studio Code', // Simulate VS Code client, roots only work with it - roots: [ - { - name: 'test', - uri: 'file://' + p.replace(/\\/g, '/'), - } - ], +for (const mode of ['default', 'proxy']) { + const extraArgs = mode === 'proxy' ? ['--connect-tool'] : []; + + test.describe(`${mode} mode`, () => { + test('should use separate user data by root path', async ({ startClient, server }, testInfo) => { + const { client } = await startClient({ + args: extraArgs, + clientName: 'Visual Studio Code', // Simulate VS Code client, roots only work with it + roots: [ + { + name: 'test', + uri: 'file://' + p.replace(/\\/g, '/'), + } + ], + }); + + await client.callTool({ + name: 'browser_navigate', + arguments: { url: server.HELLO_WORLD }, + }); + + const hash = createHash(p); + const [file] = await fs.promises.readdir(testInfo.outputPath('ms-playwright')); + expect(file).toContain(hash); + }); + + + test('check that trace is saved in workspace', async ({ startClient, server, mcpMode }, testInfo) => { + const rootPath = testInfo.outputPath('workspace'); + const { client } = await startClient({ + args: ['--save-trace', ...extraArgs], + clientName: 'Visual Studio Code - Insiders', // Simulate VS Code client, roots only work with it + roots: [ + { + name: 'workspace', + uri: pathToFileURL(rootPath).toString(), + }, + ], + }); + + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { url: server.HELLO_WORLD }, + })).toHaveResponse({ + code: expect.stringContaining(`page.goto('http://localhost`), + }); + + const [file] = await fs.promises.readdir(path.join(rootPath, '.playwright-mcp')); + expect(file).toContain('traces'); + }); }); - - await client.callTool({ - name: 'browser_navigate', - arguments: { url: server.HELLO_WORLD }, - }); - - const hash = createHash(p); - const [file] = await fs.promises.readdir(testInfo.outputPath('ms-playwright')); - expect(file).toContain(hash); -}); - - -test('check that trace is saved in workspace', async ({ startClient, server, mcpMode }, testInfo) => { - const rootPath = testInfo.outputPath('workspace'); - const { client } = await startClient({ - args: ['--save-trace'], - clientName: 'Visual Studio Code - Insiders', // Simulate VS Code client, roots only work with it - roots: [ - { - name: 'workspace', - uri: pathToFileURL(rootPath).toString(), - }, - ], - }); - - expect(await client.callTool({ - name: 'browser_navigate', - arguments: { url: server.HELLO_WORLD }, - })).toHaveResponse({ - code: expect.stringContaining(`page.goto('http://localhost`), - }); - - const [file] = await fs.promises.readdir(path.join(rootPath, '.playwright-mcp')); - expect(file).toContain('traces'); -}); +}