diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts index 4a6b1f6..5dcfd79 100644 --- a/src/browserServerBackend.ts +++ b/src/browserServerBackend.ts @@ -46,11 +46,9 @@ export class BrowserServerBackend implements ServerBackend { } async initialize(server: mcpServer.Server): Promise { - const capabilities = server.getClientCapabilities() as mcpServer.ClientCapabilities; + const capabilities = server.getClientCapabilities(); let rootPath: string | undefined; - if (capabilities.roots && ( - server.getClientVersion()?.name === 'Visual Studio Code' || - server.getClientVersion()?.name === 'Visual Studio Code - Insiders')) { + if (capabilities?.roots) { const { roots } = await server.listRoots(); const firstRootUri = roots[0]?.uri; const url = firstRootUri ? new URL(firstRootUri) : undefined; @@ -89,6 +87,6 @@ export class BrowserServerBackend implements ServerBackend { } serverClosed() { - void this._context!.dispose().catch(logUnhandledError); + void this._context?.dispose().catch(logUnhandledError); } } diff --git a/src/inProcessClient.ts b/src/inProcessMcpFactrory.ts similarity index 60% rename from src/inProcessClient.ts rename to src/inProcessMcpFactrory.ts index 390954c..c29ceaf 100644 --- a/src/inProcessClient.ts +++ b/src/inProcessMcpFactrory.ts @@ -15,18 +15,16 @@ */ -import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { ListRootsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; import { BrowserContextFactory } from './browserContextFactory.js'; import { BrowserServerBackend } from './browserServerBackend.js'; import { InProcessTransport } from './mcp/inProcessTransport.js'; import * as mcpServer from './mcp/server.js'; -import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; import type { FullConfig } from './config.js'; -import type { ClientFactory } from './mcp/proxyBackend.js'; +import type { MCPFactory } from './mcp/proxyBackend.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; -export class InProcessClientFactory implements ClientFactory { +export class InProcessMCPFactory implements MCPFactory { name: string; description: string; @@ -40,21 +38,8 @@ export class InProcessClientFactory implements ClientFactory { this._config = config; } - async create(server: Server): Promise { - const client = new Client(server.getClientVersion() ?? { name: 'unknown', version: 'unknown' }); - const clientCapabilities = server.getClientCapabilities(); - if (clientCapabilities) - client.registerCapabilities(clientCapabilities); - - if (clientCapabilities?.roots) { - client.setRequestHandler(ListRootsRequestSchema, async () => { - return await server.listRoots(); - }); - } - + async create(): Promise { const delegate = mcpServer.createServer(new BrowserServerBackend(this._config, this._contextFactory), false); - await client.connect(new InProcessTransport(delegate)); - await client.ping(); - return client; + return new InProcessTransport(delegate); } } diff --git a/src/mcp/proxyBackend.ts b/src/mcp/proxyBackend.ts index 1128818..1372ca4 100644 --- a/src/mcp/proxyBackend.ts +++ b/src/mcp/proxyBackend.ts @@ -14,48 +14,51 @@ * limitations under the License. */ -import { Server } from '@modelcontextprotocol/sdk/server/index.js'; import { z } from 'zod'; import { zodToJsonSchema } from 'zod-to-json-schema'; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { ListRootsRequestSchema, PingRequestSchema } from '@modelcontextprotocol/sdk/types.js'; import { logUnhandledError } from '../utils/log.js'; import { packageJSON } from '../utils/package.js'; -import { ToolDefinition, ServerBackend, ToolResponse } from './server.js'; -import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; + +import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import type { ToolDefinition, ServerBackend, ToolResponse } from './server.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; type NonEmptyArray = [T, ...T[]]; -export type ClientFactory = { +export type MCPFactory = { name: string; description: string; - create(server: Server): Promise; + create(): Promise; }; -export type ClientFactoryList = NonEmptyArray; +export type MCPFactoryList = NonEmptyArray; export class ProxyBackend implements ServerBackend { name = 'Playwright MCP Client Switcher'; version = packageJSON.version; - private _clientFactories: ClientFactoryList; + private _mcpFactories: MCPFactoryList; private _currentClient: Client | undefined; private _contextSwitchTool: ToolDefinition; private _tools: ToolDefinition[] = []; private _server: Server | undefined; - constructor(clientFactories: ClientFactoryList) { - this._clientFactories = clientFactories; + constructor(clientFactories: MCPFactoryList) { + this._mcpFactories = clientFactories; this._contextSwitchTool = this._defineContextSwitchTool(); } async initialize(server: Server): Promise { this._server = server; - await this._setCurrentClient(this._clientFactories[0]); + await this._setCurrentClient(this._mcpFactories[0]); } tools(): ToolDefinition[] { - if (this._clientFactories.length === 1) + if (this._mcpFactories.length === 1) return this._tools; return [ ...this._tools, @@ -79,7 +82,7 @@ export class ProxyBackend implements ServerBackend { private async _callContextSwitchTool(params: any): Promise { try { - const factory = this._clientFactories.find(factory => factory.name === params.name); + const factory = this._mcpFactories.find(factory => factory.name === params.name); if (!factory) throw new Error('Unknown connection method: ' + params.name); @@ -100,10 +103,10 @@ export class ProxyBackend implements ServerBackend { name: 'browser_connect', description: [ 'Connect to a browser using one of the available methods:', - ...this._clientFactories.map(factory => `- "${factory.name}": ${factory.description}`), + ...this._mcpFactories.map(factory => `- "${factory.name}": ${factory.description}`), ].join('\n'), inputSchema: zodToJsonSchema(z.object({ - name: z.enum(this._clientFactories.map(factory => factory.name) as [string, ...string[]]).default(this._clientFactories[0].name).describe('The method to use to connect to the browser'), + name: z.enum(this._mcpFactories.map(factory => factory.name) as [string, ...string[]]).default(this._mcpFactories[0].name).describe('The method to use to connect to the browser'), }), { strictUnions: true }) as ToolDefinition['inputSchema'], annotations: { title: 'Connect to a browser context', @@ -113,9 +116,32 @@ export class ProxyBackend implements ServerBackend { }; } - private async _setCurrentClient(factory: ClientFactory) { + private async _setCurrentClient(factory: MCPFactory) { await this._currentClient?.close(); - this._currentClient = await factory.create(this._server!); + this._currentClient = undefined; + + const client = new Client({ name: 'Playwright MCP Proxy', version: packageJSON.version }); + client.registerCapabilities({ + roots: { + listRoots: true, + }, + }); + client.setRequestHandler(ListRootsRequestSchema, async () => { + const clientName = this._server!.getClientVersion()?.name; + if (this._server!.getClientCapabilities()?.roots && ( + clientName === 'Visual Studio Code' || + clientName === 'Visual Studio Code - Insiders')) { + const { roots } = await this._server!.listRoots(); + return { roots }; + } + return { roots: [] }; + }); + client.setRequestHandler(PingRequestSchema, () => ({})); + + const transport = await factory.create(); + await client.connect(transport); + + this._currentClient = client; const tools = await this._currentClient.listTools(); this._tools = tools.tools; } diff --git a/src/program.ts b/src/program.ts index 2214c47..ad8fa33 100644 --- a/src/program.ts +++ b/src/program.ts @@ -23,12 +23,11 @@ import { Context } from './context.js'; import { contextFactory } from './browserContextFactory.js'; import { runLoopTools } from './loopTools/main.js'; import { ProxyBackend } from './mcp/proxyBackend.js'; -import { InProcessClientFactory } from './inProcessClient.js'; +import { InProcessMCPFactory } from './inProcessMcpFactrory.js'; import { BrowserServerBackend } from './browserServerBackend.js'; import { ExtensionContextFactory } from './extension/extensionContextFactory.js'; -import type { ClientFactoryList } from './mcp/proxyBackend.js'; -import type { ServerBackendFactory } from './mcp/server.js'; +import type { MCPFactoryList } from './mcp/proxyBackend.js'; import type { FullConfig } from './config.js'; program @@ -84,18 +83,13 @@ program return; } - let serverBackendFactory: ServerBackendFactory; const browserContextFactory = contextFactory(config); - if (options.connectTool) { - const factories: ClientFactoryList = [ - new InProcessClientFactory(browserContextFactory, config), - new InProcessClientFactory(createExtensionContextFactory(config), config), - ]; - serverBackendFactory = () => new ProxyBackend(factories); - } else { - serverBackendFactory = () => new BrowserServerBackend(config, browserContextFactory); - } - await mcpTransport.start(serverBackendFactory, config.server); + const factories: MCPFactoryList = [ + new InProcessMCPFactory(browserContextFactory, config), + ]; + if (options.connectTool) + factories.push(new InProcessMCPFactory(createExtensionContextFactory(config), config)); + await mcpTransport.start(() => new ProxyBackend(factories), config.server); }); function setupExitWatchdog() {