diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts index a8cd7c3..0a27e2a 100644 --- a/src/browserServerBackend.ts +++ b/src/browserServerBackend.ts @@ -43,9 +43,16 @@ export class BrowserServerBackend implements ServerBackend { this._tools = filteredTools(config); } - async initialize() { + async initialize(server: mcpServer.Server): Promise { this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config) : undefined; - this._context = new Context(this._tools, this._config, this._browserContextFactory, this._sessionLog); + this._context = new Context({ + tools: this._tools, + config: this._config, + browserContextFactory: this._browserContextFactory, + sessionLog: this._sessionLog, + clientVersion: server.getClientVersion(), + capabilities: server.getClientCapabilities() as mcpServer.ClientCapabilities, + }); } tools(): mcpServer.ToolSchema[] { @@ -69,10 +76,6 @@ export class BrowserServerBackend implements ServerBackend { return response.serialize(); } - serverInitialized(version: mcpServer.ClientVersion | undefined) { - this._context!.clientVersion = version; - } - serverClosed() { void this._context!.dispose().catch(logUnhandledError); } diff --git a/src/context.ts b/src/context.ts index 5dfc5a1..4c1018e 100644 --- a/src/context.ts +++ b/src/context.ts @@ -20,6 +20,7 @@ import * as playwright from 'playwright'; import { logUnhandledError } from './log.js'; import { Tab } from './tab.js'; +import type * as mcpServer from './mcp/server.js'; import type { Tool } from './tools/tool.js'; import type { FullConfig } from './config.js'; import type { BrowserContextFactory } from './browserContextFactory.js'; @@ -28,6 +29,14 @@ import type { SessionLog } from './sessionLog.js'; const testDebug = debug('pw:mcp:test'); +type ContextOptions = { + tools: Tool[]; + config: FullConfig; + browserContextFactory: BrowserContextFactory; + sessionLog: SessionLog | undefined; + clientVersion: { name: string; version: string; } | undefined; + capabilities: mcpServer.ClientCapabilities | undefined; +}; export class Context { readonly tools: Tool[]; readonly config: FullConfig; @@ -36,19 +45,21 @@ export class Context { private _browserContextFactory: BrowserContextFactory; private _tabs: Tab[] = []; private _currentTab: Tab | undefined; - - clientVersion: { name: string; version: string; } | undefined; + private _clientVersion: { name: string; version: string; } | undefined; + private _clientCapabilities: mcpServer.ClientCapabilities; private static _allContexts: Set = new Set(); private _closeBrowserContextPromise: Promise | undefined; private _isRunningTool: boolean = false; private _abortController = new AbortController(); - constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory, sessionLog: SessionLog | undefined) { - this.tools = tools; - this.config = config; - this._browserContextFactory = browserContextFactory; - this.sessionLog = sessionLog; + constructor(options: ContextOptions) { + this.tools = options.tools; + this.config = options.config; + this.sessionLog = options.sessionLog; + this._browserContextFactory = options.browserContextFactory; + this._clientVersion = options.clientVersion; + this._clientCapabilities = options.capabilities || {}; testDebug('create context'); Context._allContexts.add(this); } @@ -188,7 +199,7 @@ export class Context { if (this._closeBrowserContextPromise) throw new Error('Another browser context is being closed.'); // TODO: move to the browser context factory to make it based on isolation mode. - const result = await this._browserContextFactory.createContext(this.clientVersion!, this._abortController.signal); + const result = await this._browserContextFactory.createContext(this._clientVersion!, this._abortController.signal); const { browserContext } = result; await this._setupRequestInterception(browserContext); if (this.sessionLog) diff --git a/src/mcp/server.ts b/src/mcp/server.ts index a627022..18c3144 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -18,11 +18,18 @@ import { z } from 'zod'; import { Server } from '@modelcontextprotocol/sdk/server/index.js'; import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; import { zodToJsonSchema } from 'zod-to-json-schema'; +import { ManualPromise } from '../manualPromise.js'; +import { logUnhandledError } from '../log.js'; -import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js'; +import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; +export type { Server } from '@modelcontextprotocol/sdk/server/index.js'; -export type ClientVersion = Implementation; +export type ClientCapabilities = { + roots?: { + listRoots?: boolean + }; +}; export type ToolResponse = { content: (TextContent | ImageContent)[]; @@ -42,10 +49,9 @@ export type ToolHandler = (toolName: string, params: any) => Promise; + initialize?(server: Server): Promise; tools(): ToolSchema[]; callTool(schema: ToolSchema, parsedArguments: any): Promise; - serverInitialized?(version: ClientVersion | undefined): void; serverClosed?(): void; } @@ -53,12 +59,12 @@ export type ServerBackendFactory = () => ServerBackend; export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) { const backend = serverBackendFactory(); - await backend.initialize?.(); const server = createServer(backend, runHeartbeat); await server.connect(transport); } export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server { + const initializedPromise = new ManualPromise(); const server = new Server({ name: backend.name, version: backend.version }, { capabilities: { tools: {}, @@ -82,6 +88,8 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser let heartbeatRunning = false; server.setRequestHandler(CallToolRequestSchema, async request => { + await initializedPromise; + if (runHeartbeat && !heartbeatRunning) { heartbeatRunning = true; startHeartbeat(server); @@ -101,8 +109,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser return errorResult(String(error)); } }); - - addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion())); + addServerListener(server, 'initialized', () => { + backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError); + }); addServerListener(server, 'close', () => backend.serverClosed?.()); return server; }