chore: refactor initialize (#812)

This commit is contained in:
Pavel Feldman
2025-08-01 13:06:36 -07:00
committed by GitHub
parent 7c07cc86eb
commit ffe0117456
3 changed files with 44 additions and 21 deletions

View File

@@ -43,9 +43,16 @@ export class BrowserServerBackend implements ServerBackend {
this._tools = filteredTools(config);
}
async initialize() {
async initialize(server: mcpServer.Server): Promise<void> {
this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config) : undefined;
this._context = new Context(this._tools, this._config, this._browserContextFactory, this._sessionLog);
this._context = new Context({
tools: this._tools,
config: this._config,
browserContextFactory: this._browserContextFactory,
sessionLog: this._sessionLog,
clientVersion: server.getClientVersion(),
capabilities: server.getClientCapabilities() as mcpServer.ClientCapabilities,
});
}
tools(): mcpServer.ToolSchema<any>[] {
@@ -69,10 +76,6 @@ export class BrowserServerBackend implements ServerBackend {
return response.serialize();
}
serverInitialized(version: mcpServer.ClientVersion | undefined) {
this._context!.clientVersion = version;
}
serverClosed() {
void this._context!.dispose().catch(logUnhandledError);
}

View File

@@ -20,6 +20,7 @@ import * as playwright from 'playwright';
import { logUnhandledError } from './log.js';
import { Tab } from './tab.js';
import type * as mcpServer from './mcp/server.js';
import type { Tool } from './tools/tool.js';
import type { FullConfig } from './config.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
@@ -28,6 +29,14 @@ import type { SessionLog } from './sessionLog.js';
const testDebug = debug('pw:mcp:test');
type ContextOptions = {
tools: Tool[];
config: FullConfig;
browserContextFactory: BrowserContextFactory;
sessionLog: SessionLog | undefined;
clientVersion: { name: string; version: string; } | undefined;
capabilities: mcpServer.ClientCapabilities | undefined;
};
export class Context {
readonly tools: Tool[];
readonly config: FullConfig;
@@ -36,19 +45,21 @@ export class Context {
private _browserContextFactory: BrowserContextFactory;
private _tabs: Tab[] = [];
private _currentTab: Tab | undefined;
clientVersion: { name: string; version: string; } | undefined;
private _clientVersion: { name: string; version: string; } | undefined;
private _clientCapabilities: mcpServer.ClientCapabilities;
private static _allContexts: Set<Context> = new Set();
private _closeBrowserContextPromise: Promise<void> | undefined;
private _isRunningTool: boolean = false;
private _abortController = new AbortController();
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory, sessionLog: SessionLog | undefined) {
this.tools = tools;
this.config = config;
this._browserContextFactory = browserContextFactory;
this.sessionLog = sessionLog;
constructor(options: ContextOptions) {
this.tools = options.tools;
this.config = options.config;
this.sessionLog = options.sessionLog;
this._browserContextFactory = options.browserContextFactory;
this._clientVersion = options.clientVersion;
this._clientCapabilities = options.capabilities || {};
testDebug('create context');
Context._allContexts.add(this);
}
@@ -188,7 +199,7 @@ export class Context {
if (this._closeBrowserContextPromise)
throw new Error('Another browser context is being closed.');
// TODO: move to the browser context factory to make it based on isolation mode.
const result = await this._browserContextFactory.createContext(this.clientVersion!, this._abortController.signal);
const result = await this._browserContextFactory.createContext(this._clientVersion!, this._abortController.signal);
const { browserContext } = result;
await this._setupRequestInterception(browserContext);
if (this.sessionLog)

View File

@@ -18,11 +18,18 @@ import { z } from 'zod';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { ManualPromise } from '../manualPromise.js';
import { logUnhandledError } from '../log.js';
import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export type ClientVersion = Implementation;
export type ClientCapabilities = {
roots?: {
listRoots?: boolean
};
};
export type ToolResponse = {
content: (TextContent | ImageContent)[];
@@ -42,10 +49,9 @@ export type ToolHandler = (toolName: string, params: any) => Promise<ToolRespons
export interface ServerBackend {
name: string;
version: string;
initialize?(): Promise<void>;
initialize?(server: Server): Promise<void>;
tools(): ToolSchema<any>[];
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
serverInitialized?(version: ClientVersion | undefined): void;
serverClosed?(): void;
}
@@ -53,12 +59,12 @@ export type ServerBackendFactory = () => ServerBackend;
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
const backend = serverBackendFactory();
await backend.initialize?.();
const server = createServer(backend, runHeartbeat);
await server.connect(transport);
}
export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
const initializedPromise = new ManualPromise<void>();
const server = new Server({ name: backend.name, version: backend.version }, {
capabilities: {
tools: {},
@@ -82,6 +88,8 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
let heartbeatRunning = false;
server.setRequestHandler(CallToolRequestSchema, async request => {
await initializedPromise;
if (runHeartbeat && !heartbeatRunning) {
heartbeatRunning = true;
startHeartbeat(server);
@@ -101,8 +109,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
return errorResult(String(error));
}
});
addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion()));
addServerListener(server, 'initialized', () => {
backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError);
});
addServerListener(server, 'close', () => backend.serverClosed?.());
return server;
}