From f010164bf120129ab702af588b78bce99d39e221 Mon Sep 17 00:00:00 2001 From: Yury Semikhatsky Date: Mon, 11 Aug 2025 14:16:43 -0700 Subject: [PATCH] chore: mcp backend switcher (#854) --- extension/tests/extension.spec.ts | 4 +- src/browserServerBackend.ts | 57 +------------ src/extension/main.ts | 13 ++- src/inProcessClient.ts | 52 ++++++++++++ src/index.ts | 2 +- src/loopTools/context.ts | 2 +- src/loopTools/main.ts | 3 +- src/mcp/proxyBackend.ts | 131 ++++++++++++++++++++++++++++++ src/mcp/server.ts | 7 +- src/program.ts | 23 ++++-- 10 files changed, 224 insertions(+), 70 deletions(-) create mode 100644 src/inProcessClient.ts create mode 100644 src/mcp/proxyBackend.ts diff --git a/extension/tests/extension.spec.ts b/extension/tests/extension.spec.ts index 36b75c2..5de2c19 100644 --- a/extension/tests/extension.spec.ts +++ b/extension/tests/extension.spec.ts @@ -78,7 +78,7 @@ test('navigate with extension', async ({ browserWithExtension, startClient, serv expect(await client.callTool({ name: 'browser_connect', arguments: { - method: 'extension' + name: 'extension' } })).toHaveResponse({ result: 'Successfully changed connection method.', @@ -123,7 +123,7 @@ test('snapshot of an existing page', async ({ browserWithExtension, startClient, expect(await client.callTool({ name: 'browser_connect', arguments: { - method: 'extension' + name: 'extension' } })).toHaveResponse({ result: 'Successfully changed connection method.', diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts index 3a5ada4..0c20a37 100644 --- a/src/browserServerBackend.ts +++ b/src/browserServerBackend.ts @@ -15,7 +15,6 @@ */ import { fileURLToPath } from 'url'; -import { z } from 'zod'; import { FullConfig } from './config.js'; import { Context } from './context.js'; import { logUnhandledError } from './log.js'; @@ -23,17 +22,12 @@ import { Response } from './response.js'; import { SessionLog } from './sessionLog.js'; import { filteredTools } from './tools.js'; import { packageJSON } from './package.js'; -import { defineTool } from './tools/tool.js'; import type { Tool } from './tools/tool.js'; import type { BrowserContextFactory } from './browserContextFactory.js'; import type * as mcpServer from './mcp/server.js'; import type { ServerBackend } from './mcp/server.js'; -type NonEmptyArray = [T, ...T[]]; - -export type FactoryList = NonEmptyArray; - export class BrowserServerBackend implements ServerBackend { name = 'Playwright'; version = packageJSON.version; @@ -44,12 +38,10 @@ export class BrowserServerBackend implements ServerBackend { private _config: FullConfig; private _browserContextFactory: BrowserContextFactory; - constructor(config: FullConfig, factories: FactoryList) { + constructor(config: FullConfig, factory: BrowserContextFactory) { this._config = config; - this._browserContextFactory = factories[0]; + this._browserContextFactory = factory; this._tools = filteredTools(config); - if (factories.length > 1) - this._tools.push(this._defineContextSwitchTool(factories)); } async initialize(server: mcpServer.Server): Promise { @@ -77,8 +69,9 @@ export class BrowserServerBackend implements ServerBackend { return this._tools.map(tool => tool.schema); } - async callTool(schema: mcpServer.ToolSchema, parsedArguments: any) { + async callTool(schema: mcpServer.ToolSchema, rawArguments: any) { const context = this._context!; + const parsedArguments = schema.inputSchema.parse(rawArguments || {}); const response = new Response(context, schema.name, parsedArguments); const tool = this._tools.find(tool => tool.schema.name === schema.name)!; context.setRunningTool(true); @@ -97,46 +90,4 @@ export class BrowserServerBackend implements ServerBackend { serverClosed() { void this._context!.dispose().catch(logUnhandledError); } - - private _defineContextSwitchTool(factories: FactoryList): Tool { - const self = this; - return defineTool({ - capability: 'core', - - schema: { - name: 'browser_connect', - title: 'Connect to a browser context', - description: [ - 'Connect to a browser using one of the available methods:', - ...factories.map(factory => `- "${factory.name}": ${factory.description}`), - ].join('\n'), - inputSchema: z.object({ - method: z.enum(factories.map(factory => factory.name) as [string, ...string[]]).default(factories[0].name).describe('The method to use to connect to the browser'), - }), - type: 'readOnly', - }, - - async handle(context, params, response) { - const factory = factories.find(factory => factory.name === params.method); - if (!factory) { - response.addError('Unknown connection method: ' + params.method); - return; - } - await self._setContextFactory(factory); - response.addResult('Successfully changed connection method.'); - } - }); - } - - private async _setContextFactory(newFactory: BrowserContextFactory) { - if (this._context) { - const options = { - ...this._context.options, - browserContextFactory: newFactory, - }; - await this._context.dispose(); - this._context = new Context(options); - } - this._browserContextFactory = newFactory; - } } diff --git a/src/extension/main.ts b/src/extension/main.ts index 4a209a6..f9fa177 100644 --- a/src/extension/main.ts +++ b/src/extension/main.ts @@ -16,16 +16,23 @@ import { ExtensionContextFactory } from './extensionContextFactory.js'; import { BrowserServerBackend } from '../browserServerBackend.js'; +import { InProcessClientFactory } from '../inProcessClient.js'; import * as mcpTransport from '../mcp/transport.js'; import type { FullConfig } from '../config.js'; +import type { ClientFactory } from '../mcp/proxyBackend.js'; export async function runWithExtension(config: FullConfig) { - const contextFactory = new ExtensionContextFactory(config.browser.launchOptions.channel || 'chrome', config.browser.userDataDir); - const serverBackendFactory = () => new BrowserServerBackend(config, [contextFactory]); + const contextFactory = createExtensionContextFactory(config); + const serverBackendFactory = () => new BrowserServerBackend(config, contextFactory); await mcpTransport.start(serverBackendFactory, config.server); } -export function createExtensionContextFactory(config: FullConfig) { +export function createExtensionClientFactory(config: FullConfig): ClientFactory { + return new InProcessClientFactory(createExtensionContextFactory(config), config); +} + + +function createExtensionContextFactory(config: FullConfig) { return new ExtensionContextFactory(config.browser.launchOptions.channel || 'chrome', config.browser.userDataDir); } diff --git a/src/inProcessClient.ts b/src/inProcessClient.ts new file mode 100644 index 0000000..448e1ac --- /dev/null +++ b/src/inProcessClient.ts @@ -0,0 +1,52 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { BrowserContextFactory } from './browserContextFactory.js'; +import { BrowserServerBackend } from './browserServerBackend.js'; +import { InProcessTransport } from './mcp/inProcessTransport.js'; +import * as mcpServer from './mcp/server.js'; +import { packageJSON } from './package.js'; + +import type { FullConfig } from './config.js'; +import type { ClientFactory } from './mcp/proxyBackend.js'; + +export class InProcessClientFactory implements ClientFactory { + name: string; + description: string; + + private _contextFactory: BrowserContextFactory; + private _config: FullConfig; + + constructor(contextFactory: BrowserContextFactory, config: FullConfig) { + this.name = contextFactory.name; + this.description = contextFactory.description; + this._contextFactory = contextFactory; + this._config = config; + } + + async create(): Promise { + const client = new Client({ + name: this.name, + version: packageJSON.version + }); + const server = mcpServer.createServer(new BrowserServerBackend(this._config, this._contextFactory), false); + await client.connect(new InProcessTransport(server)); + await client.ping(); + return client; + } +} diff --git a/src/index.ts b/src/index.ts index 2d181b1..86a9280 100644 --- a/src/index.ts +++ b/src/index.ts @@ -27,7 +27,7 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise): Promise { const config = await resolveConfig(userConfig); const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config); - return mcpServer.createServer(new BrowserServerBackend(config, [factory]), false); + return mcpServer.createServer(new BrowserServerBackend(config, factory), false); } class SimpleBrowserContextFactory implements BrowserContextFactory { diff --git a/src/loopTools/context.ts b/src/loopTools/context.ts index 9e52577..732af07 100644 --- a/src/loopTools/context.ts +++ b/src/loopTools/context.ts @@ -46,7 +46,7 @@ export class Context { static async create(config: FullConfig) { const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' }); const browserContextFactory = contextFactory(config); - const server = mcpServer.createServer(new BrowserServerBackend(config, [browserContextFactory]), false); + const server = mcpServer.createServer(new BrowserServerBackend(config, browserContextFactory), false); await client.connect(new InProcessTransport(server)); await client.ping(); return new Context(config, client); diff --git a/src/loopTools/main.ts b/src/loopTools/main.ts index ded5b88..fc788aa 100644 --- a/src/loopTools/main.ts +++ b/src/loopTools/main.ts @@ -52,8 +52,9 @@ class LoopToolsServerBackend implements ServerBackend { return this._tools.map(tool => tool.schema); } - async callTool(schema: mcpServer.ToolSchema, parsedArguments: any): Promise { + async callTool(schema: mcpServer.ToolSchema, rawArguments: any): Promise { const tool = this._tools.find(tool => tool.schema.name === schema.name)!; + const parsedArguments = schema.inputSchema.parse(rawArguments || {}); return await tool.handle(this._context!, parsedArguments); } diff --git a/src/mcp/proxyBackend.ts b/src/mcp/proxyBackend.ts new file mode 100644 index 0000000..c2e6eef --- /dev/null +++ b/src/mcp/proxyBackend.ts @@ -0,0 +1,131 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; + +import { z } from 'zod'; +import { ServerBackend, ToolResponse, ToolSchema } from './server.js'; +import { defineTool, Tool } from '../tools/tool.js'; +import { packageJSON } from '../package.js'; +import { logUnhandledError } from '../log.js'; + +import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; + +type NonEmptyArray = [T, ...T[]]; + +export type ClientFactory = { + name: string; + description: string; + create(): Promise; +}; + +export type ClientFactoryList = NonEmptyArray; + +export class ProxyBackend implements ServerBackend { + name = 'Playwright MCP Client Switcher'; + version = packageJSON.version; + + private _clientFactories: ClientFactoryList; + private _currentClient: Client | undefined; + private _contextSwitchTool: Tool; + private _tools: ToolSchema[] = []; + + constructor(clientFactories: ClientFactoryList) { + this._clientFactories = clientFactories; + this._contextSwitchTool = this._defineContextSwitchTool(); + } + + async initialize(server: Server): Promise { + await this._setCurrentClient(this._clientFactories[0]); + } + + tools(): ToolSchema[] { + if (this._clientFactories.length === 1) + return this._tools; + return [ + ...this._tools, + this._contextSwitchTool.schema, + ]; + } + + async callTool(schema: ToolSchema, rawArguments: any): Promise { + if (schema.name === this._contextSwitchTool.schema.name) + return this._callContextSwitchTool(rawArguments); + const result = await this._currentClient!.callTool({ + name: schema.name, + arguments: rawArguments, + }); + return result as unknown as ToolResponse; + } + + serverClosed?(): void { + void this._currentClient?.close().catch(logUnhandledError); + } + + private async _callContextSwitchTool(params: any): Promise { + try { + const factory = this._clientFactories.find(factory => factory.name === params.name); + if (!factory) + throw new Error('Unknown connection method: ' + params.name); + + await this._setCurrentClient(factory); + return { + content: [{ type: 'text', text: '### Result\nSuccessfully changed connection method.\n' }], + }; + } catch (error) { + return { + content: [{ type: 'text', text: `### Result\nError: ${error}\n` }], + isError: true, + }; + } + } + + private _defineContextSwitchTool(): Tool { + return defineTool({ + capability: 'core', + + schema: { + name: 'browser_connect', + title: 'Connect to a browser context', + description: [ + 'Connect to a browser using one of the available methods:', + ...this._clientFactories.map(factory => `- "${factory.name}": ${factory.description}`), + ].join('\n'), + inputSchema: z.object({ + name: z.enum(this._clientFactories.map(factory => factory.name) as [string, ...string[]]).default(this._clientFactories[0].name).describe('The method to use to connect to the browser'), + }), + type: 'readOnly', + }, + + async handle() { + throw new Error('Unreachable'); + } + }); + } + + private async _setCurrentClient(factory: ClientFactory) { + await this._currentClient?.close(); + this._currentClient = await factory.create(); + const tools = await this._currentClient.listTools(); + this._tools = tools.tools.map(tool => ({ + name: tool.name, + title: tool.title ?? '', + description: tool.description ?? '', + inputSchema: tool.inputSchema ?? z.object({}), + type: tool.annotations?.readOnlyHint ? 'readOnly' as const : 'destructive' as const, + })); + } +} diff --git a/src/mcp/server.ts b/src/mcp/server.ts index 18c3144..5da54be 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -51,7 +51,7 @@ export interface ServerBackend { version: string; initialize?(server: Server): Promise; tools(): ToolSchema[]; - callTool(schema: ToolSchema, parsedArguments: any): Promise; + callTool(schema: ToolSchema, rawArguments: any): Promise; serverClosed?(): void; } @@ -71,8 +71,8 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser } }); - const tools = backend.tools(); server.setRequestHandler(ListToolsRequestSchema, async () => { + const tools = backend.tools(); return { tools: tools.map(tool => ({ name: tool.name, description: tool.description, @@ -99,12 +99,13 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser content: [{ type: 'text', text: '### Result\n' + messages.join('\n') }], isError: true, }); + const tools = backend.tools(); const tool = tools.find(tool => tool.name === request.params.name) as ToolSchema; if (!tool) return errorResult(`Error: Tool "${request.params.name}" not found`); try { - return await backend.callTool(tool, tool.inputSchema.parse(request.params.arguments || {})); + return await backend.callTool(tool, request.params.arguments || {}); } catch (error) { return errorResult(String(error)); } diff --git a/src/program.ts b/src/program.ts index ae1f1d9..24ebbae 100644 --- a/src/program.ts +++ b/src/program.ts @@ -21,11 +21,16 @@ import { startTraceViewerServer } from 'playwright-core/lib/server'; import * as mcpTransport from './mcp/transport.js'; import { commaSeparatedList, resolveCLIConfig, semicolonSeparatedList } from './config.js'; import { packageJSON } from './package.js'; -import { createExtensionContextFactory, runWithExtension } from './extension/main.js'; -import { BrowserServerBackend, FactoryList } from './browserServerBackend.js'; +import { createExtensionClientFactory, runWithExtension } from './extension/main.js'; import { Context } from './context.js'; import { contextFactory } from './browserContextFactory.js'; import { runLoopTools } from './loopTools/main.js'; +import { ProxyBackend } from './mcp/proxyBackend.js'; +import { InProcessClientFactory } from './inProcessClient.js'; +import { BrowserServerBackend } from './browserServerBackend.js'; + +import type { ClientFactoryList } from './mcp/proxyBackend.js'; +import type { ServerBackendFactory } from './mcp/server.js'; program .version('Version ' + packageJSON.version) @@ -78,11 +83,17 @@ program return; } + let serverBackendFactory: ServerBackendFactory; const browserContextFactory = contextFactory(config); - const factories: FactoryList = [browserContextFactory]; - if (options.connectTool) - factories.push(createExtensionContextFactory(config)); - const serverBackendFactory = () => new BrowserServerBackend(config, factories); + if (options.connectTool) { + const factories: ClientFactoryList = [ + new InProcessClientFactory(browserContextFactory, config), + createExtensionClientFactory(config) + ]; + serverBackendFactory = () => new ProxyBackend(factories); + } else { + serverBackendFactory = () => new BrowserServerBackend(config, browserContextFactory); + } await mcpTransport.start(serverBackendFactory, config.server); if (config.saveTrace) {