chore: refactor initialize (#812)

This commit is contained in:
Pavel Feldman
2025-08-01 13:06:36 -07:00
committed by GitHub
parent 7c07cc86eb
commit ffe0117456
3 changed files with 44 additions and 21 deletions

View File

@@ -43,9 +43,16 @@ export class BrowserServerBackend implements ServerBackend {
this._tools = filteredTools(config); this._tools = filteredTools(config);
} }
async initialize() { async initialize(server: mcpServer.Server): Promise<void> {
this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config) : undefined; this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config) : undefined;
this._context = new Context(this._tools, this._config, this._browserContextFactory, this._sessionLog); this._context = new Context({
tools: this._tools,
config: this._config,
browserContextFactory: this._browserContextFactory,
sessionLog: this._sessionLog,
clientVersion: server.getClientVersion(),
capabilities: server.getClientCapabilities() as mcpServer.ClientCapabilities,
});
} }
tools(): mcpServer.ToolSchema<any>[] { tools(): mcpServer.ToolSchema<any>[] {
@@ -69,10 +76,6 @@ export class BrowserServerBackend implements ServerBackend {
return response.serialize(); return response.serialize();
} }
serverInitialized(version: mcpServer.ClientVersion | undefined) {
this._context!.clientVersion = version;
}
serverClosed() { serverClosed() {
void this._context!.dispose().catch(logUnhandledError); void this._context!.dispose().catch(logUnhandledError);
} }

View File

@@ -20,6 +20,7 @@ import * as playwright from 'playwright';
import { logUnhandledError } from './log.js'; import { logUnhandledError } from './log.js';
import { Tab } from './tab.js'; import { Tab } from './tab.js';
import type * as mcpServer from './mcp/server.js';
import type { Tool } from './tools/tool.js'; import type { Tool } from './tools/tool.js';
import type { FullConfig } from './config.js'; import type { FullConfig } from './config.js';
import type { BrowserContextFactory } from './browserContextFactory.js'; import type { BrowserContextFactory } from './browserContextFactory.js';
@@ -28,6 +29,14 @@ import type { SessionLog } from './sessionLog.js';
const testDebug = debug('pw:mcp:test'); const testDebug = debug('pw:mcp:test');
type ContextOptions = {
tools: Tool[];
config: FullConfig;
browserContextFactory: BrowserContextFactory;
sessionLog: SessionLog | undefined;
clientVersion: { name: string; version: string; } | undefined;
capabilities: mcpServer.ClientCapabilities | undefined;
};
export class Context { export class Context {
readonly tools: Tool[]; readonly tools: Tool[];
readonly config: FullConfig; readonly config: FullConfig;
@@ -36,19 +45,21 @@ export class Context {
private _browserContextFactory: BrowserContextFactory; private _browserContextFactory: BrowserContextFactory;
private _tabs: Tab[] = []; private _tabs: Tab[] = [];
private _currentTab: Tab | undefined; private _currentTab: Tab | undefined;
private _clientVersion: { name: string; version: string; } | undefined;
clientVersion: { name: string; version: string; } | undefined; private _clientCapabilities: mcpServer.ClientCapabilities;
private static _allContexts: Set<Context> = new Set(); private static _allContexts: Set<Context> = new Set();
private _closeBrowserContextPromise: Promise<void> | undefined; private _closeBrowserContextPromise: Promise<void> | undefined;
private _isRunningTool: boolean = false; private _isRunningTool: boolean = false;
private _abortController = new AbortController(); private _abortController = new AbortController();
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory, sessionLog: SessionLog | undefined) { constructor(options: ContextOptions) {
this.tools = tools; this.tools = options.tools;
this.config = config; this.config = options.config;
this._browserContextFactory = browserContextFactory; this.sessionLog = options.sessionLog;
this.sessionLog = sessionLog; this._browserContextFactory = options.browserContextFactory;
this._clientVersion = options.clientVersion;
this._clientCapabilities = options.capabilities || {};
testDebug('create context'); testDebug('create context');
Context._allContexts.add(this); Context._allContexts.add(this);
} }
@@ -188,7 +199,7 @@ export class Context {
if (this._closeBrowserContextPromise) if (this._closeBrowserContextPromise)
throw new Error('Another browser context is being closed.'); throw new Error('Another browser context is being closed.');
// TODO: move to the browser context factory to make it based on isolation mode. // TODO: move to the browser context factory to make it based on isolation mode.
const result = await this._browserContextFactory.createContext(this.clientVersion!, this._abortController.signal); const result = await this._browserContextFactory.createContext(this._clientVersion!, this._abortController.signal);
const { browserContext } = result; const { browserContext } = result;
await this._setupRequestInterception(browserContext); await this._setupRequestInterception(browserContext);
if (this.sessionLog) if (this.sessionLog)

View File

@@ -18,11 +18,18 @@ import { z } from 'zod';
import { Server } from '@modelcontextprotocol/sdk/server/index.js'; import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema'; import { zodToJsonSchema } from 'zod-to-json-schema';
import { ManualPromise } from '../manualPromise.js';
import { logUnhandledError } from '../log.js';
import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js'; import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export type ClientVersion = Implementation; export type ClientCapabilities = {
roots?: {
listRoots?: boolean
};
};
export type ToolResponse = { export type ToolResponse = {
content: (TextContent | ImageContent)[]; content: (TextContent | ImageContent)[];
@@ -42,10 +49,9 @@ export type ToolHandler = (toolName: string, params: any) => Promise<ToolRespons
export interface ServerBackend { export interface ServerBackend {
name: string; name: string;
version: string; version: string;
initialize?(): Promise<void>; initialize?(server: Server): Promise<void>;
tools(): ToolSchema<any>[]; tools(): ToolSchema<any>[];
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>; callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
serverInitialized?(version: ClientVersion | undefined): void;
serverClosed?(): void; serverClosed?(): void;
} }
@@ -53,12 +59,12 @@ export type ServerBackendFactory = () => ServerBackend;
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) { export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
const backend = serverBackendFactory(); const backend = serverBackendFactory();
await backend.initialize?.();
const server = createServer(backend, runHeartbeat); const server = createServer(backend, runHeartbeat);
await server.connect(transport); await server.connect(transport);
} }
export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server { export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
const initializedPromise = new ManualPromise<void>();
const server = new Server({ name: backend.name, version: backend.version }, { const server = new Server({ name: backend.name, version: backend.version }, {
capabilities: { capabilities: {
tools: {}, tools: {},
@@ -82,6 +88,8 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
let heartbeatRunning = false; let heartbeatRunning = false;
server.setRequestHandler(CallToolRequestSchema, async request => { server.setRequestHandler(CallToolRequestSchema, async request => {
await initializedPromise;
if (runHeartbeat && !heartbeatRunning) { if (runHeartbeat && !heartbeatRunning) {
heartbeatRunning = true; heartbeatRunning = true;
startHeartbeat(server); startHeartbeat(server);
@@ -101,8 +109,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
return errorResult(String(error)); return errorResult(String(error));
} }
}); });
addServerListener(server, 'initialized', () => {
addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion())); backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError);
});
addServerListener(server, 'close', () => backend.serverClosed?.()); addServerListener(server, 'close', () => backend.serverClosed?.());
return server; return server;
} }