chore: refactor initialize (#812)
This commit is contained in:
@@ -43,9 +43,16 @@ export class BrowserServerBackend implements ServerBackend {
|
||||
this._tools = filteredTools(config);
|
||||
}
|
||||
|
||||
async initialize() {
|
||||
async initialize(server: mcpServer.Server): Promise<void> {
|
||||
this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config) : undefined;
|
||||
this._context = new Context(this._tools, this._config, this._browserContextFactory, this._sessionLog);
|
||||
this._context = new Context({
|
||||
tools: this._tools,
|
||||
config: this._config,
|
||||
browserContextFactory: this._browserContextFactory,
|
||||
sessionLog: this._sessionLog,
|
||||
clientVersion: server.getClientVersion(),
|
||||
capabilities: server.getClientCapabilities() as mcpServer.ClientCapabilities,
|
||||
});
|
||||
}
|
||||
|
||||
tools(): mcpServer.ToolSchema<any>[] {
|
||||
@@ -69,10 +76,6 @@ export class BrowserServerBackend implements ServerBackend {
|
||||
return response.serialize();
|
||||
}
|
||||
|
||||
serverInitialized(version: mcpServer.ClientVersion | undefined) {
|
||||
this._context!.clientVersion = version;
|
||||
}
|
||||
|
||||
serverClosed() {
|
||||
void this._context!.dispose().catch(logUnhandledError);
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import * as playwright from 'playwright';
|
||||
import { logUnhandledError } from './log.js';
|
||||
import { Tab } from './tab.js';
|
||||
|
||||
import type * as mcpServer from './mcp/server.js';
|
||||
import type { Tool } from './tools/tool.js';
|
||||
import type { FullConfig } from './config.js';
|
||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||
@@ -28,6 +29,14 @@ import type { SessionLog } from './sessionLog.js';
|
||||
|
||||
const testDebug = debug('pw:mcp:test');
|
||||
|
||||
type ContextOptions = {
|
||||
tools: Tool[];
|
||||
config: FullConfig;
|
||||
browserContextFactory: BrowserContextFactory;
|
||||
sessionLog: SessionLog | undefined;
|
||||
clientVersion: { name: string; version: string; } | undefined;
|
||||
capabilities: mcpServer.ClientCapabilities | undefined;
|
||||
};
|
||||
export class Context {
|
||||
readonly tools: Tool[];
|
||||
readonly config: FullConfig;
|
||||
@@ -36,19 +45,21 @@ export class Context {
|
||||
private _browserContextFactory: BrowserContextFactory;
|
||||
private _tabs: Tab[] = [];
|
||||
private _currentTab: Tab | undefined;
|
||||
|
||||
clientVersion: { name: string; version: string; } | undefined;
|
||||
private _clientVersion: { name: string; version: string; } | undefined;
|
||||
private _clientCapabilities: mcpServer.ClientCapabilities;
|
||||
|
||||
private static _allContexts: Set<Context> = new Set();
|
||||
private _closeBrowserContextPromise: Promise<void> | undefined;
|
||||
private _isRunningTool: boolean = false;
|
||||
private _abortController = new AbortController();
|
||||
|
||||
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory, sessionLog: SessionLog | undefined) {
|
||||
this.tools = tools;
|
||||
this.config = config;
|
||||
this._browserContextFactory = browserContextFactory;
|
||||
this.sessionLog = sessionLog;
|
||||
constructor(options: ContextOptions) {
|
||||
this.tools = options.tools;
|
||||
this.config = options.config;
|
||||
this.sessionLog = options.sessionLog;
|
||||
this._browserContextFactory = options.browserContextFactory;
|
||||
this._clientVersion = options.clientVersion;
|
||||
this._clientCapabilities = options.capabilities || {};
|
||||
testDebug('create context');
|
||||
Context._allContexts.add(this);
|
||||
}
|
||||
@@ -188,7 +199,7 @@ export class Context {
|
||||
if (this._closeBrowserContextPromise)
|
||||
throw new Error('Another browser context is being closed.');
|
||||
// TODO: move to the browser context factory to make it based on isolation mode.
|
||||
const result = await this._browserContextFactory.createContext(this.clientVersion!, this._abortController.signal);
|
||||
const result = await this._browserContextFactory.createContext(this._clientVersion!, this._abortController.signal);
|
||||
const { browserContext } = result;
|
||||
await this._setupRequestInterception(browserContext);
|
||||
if (this.sessionLog)
|
||||
|
||||
@@ -18,11 +18,18 @@ import { z } from 'zod';
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { ManualPromise } from '../manualPromise.js';
|
||||
import { logUnhandledError } from '../log.js';
|
||||
|
||||
import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
|
||||
export type ClientVersion = Implementation;
|
||||
export type ClientCapabilities = {
|
||||
roots?: {
|
||||
listRoots?: boolean
|
||||
};
|
||||
};
|
||||
|
||||
export type ToolResponse = {
|
||||
content: (TextContent | ImageContent)[];
|
||||
@@ -42,10 +49,9 @@ export type ToolHandler = (toolName: string, params: any) => Promise<ToolRespons
|
||||
export interface ServerBackend {
|
||||
name: string;
|
||||
version: string;
|
||||
initialize?(): Promise<void>;
|
||||
initialize?(server: Server): Promise<void>;
|
||||
tools(): ToolSchema<any>[];
|
||||
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
|
||||
serverInitialized?(version: ClientVersion | undefined): void;
|
||||
serverClosed?(): void;
|
||||
}
|
||||
|
||||
@@ -53,12 +59,12 @@ export type ServerBackendFactory = () => ServerBackend;
|
||||
|
||||
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
|
||||
const backend = serverBackendFactory();
|
||||
await backend.initialize?.();
|
||||
const server = createServer(backend, runHeartbeat);
|
||||
await server.connect(transport);
|
||||
}
|
||||
|
||||
export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
|
||||
const initializedPromise = new ManualPromise<void>();
|
||||
const server = new Server({ name: backend.name, version: backend.version }, {
|
||||
capabilities: {
|
||||
tools: {},
|
||||
@@ -82,6 +88,8 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
|
||||
|
||||
let heartbeatRunning = false;
|
||||
server.setRequestHandler(CallToolRequestSchema, async request => {
|
||||
await initializedPromise;
|
||||
|
||||
if (runHeartbeat && !heartbeatRunning) {
|
||||
heartbeatRunning = true;
|
||||
startHeartbeat(server);
|
||||
@@ -101,8 +109,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
|
||||
return errorResult(String(error));
|
||||
}
|
||||
});
|
||||
|
||||
addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion()));
|
||||
addServerListener(server, 'initialized', () => {
|
||||
backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError);
|
||||
});
|
||||
addServerListener(server, 'close', () => backend.serverClosed?.());
|
||||
return server;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user