chore: refactor initialize (#812)
This commit is contained in:
@@ -43,9 +43,16 @@ export class BrowserServerBackend implements ServerBackend {
|
|||||||
this._tools = filteredTools(config);
|
this._tools = filteredTools(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
async initialize() {
|
async initialize(server: mcpServer.Server): Promise<void> {
|
||||||
this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config) : undefined;
|
this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config) : undefined;
|
||||||
this._context = new Context(this._tools, this._config, this._browserContextFactory, this._sessionLog);
|
this._context = new Context({
|
||||||
|
tools: this._tools,
|
||||||
|
config: this._config,
|
||||||
|
browserContextFactory: this._browserContextFactory,
|
||||||
|
sessionLog: this._sessionLog,
|
||||||
|
clientVersion: server.getClientVersion(),
|
||||||
|
capabilities: server.getClientCapabilities() as mcpServer.ClientCapabilities,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
tools(): mcpServer.ToolSchema<any>[] {
|
tools(): mcpServer.ToolSchema<any>[] {
|
||||||
@@ -69,10 +76,6 @@ export class BrowserServerBackend implements ServerBackend {
|
|||||||
return response.serialize();
|
return response.serialize();
|
||||||
}
|
}
|
||||||
|
|
||||||
serverInitialized(version: mcpServer.ClientVersion | undefined) {
|
|
||||||
this._context!.clientVersion = version;
|
|
||||||
}
|
|
||||||
|
|
||||||
serverClosed() {
|
serverClosed() {
|
||||||
void this._context!.dispose().catch(logUnhandledError);
|
void this._context!.dispose().catch(logUnhandledError);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import * as playwright from 'playwright';
|
|||||||
import { logUnhandledError } from './log.js';
|
import { logUnhandledError } from './log.js';
|
||||||
import { Tab } from './tab.js';
|
import { Tab } from './tab.js';
|
||||||
|
|
||||||
|
import type * as mcpServer from './mcp/server.js';
|
||||||
import type { Tool } from './tools/tool.js';
|
import type { Tool } from './tools/tool.js';
|
||||||
import type { FullConfig } from './config.js';
|
import type { FullConfig } from './config.js';
|
||||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||||
@@ -28,6 +29,14 @@ import type { SessionLog } from './sessionLog.js';
|
|||||||
|
|
||||||
const testDebug = debug('pw:mcp:test');
|
const testDebug = debug('pw:mcp:test');
|
||||||
|
|
||||||
|
type ContextOptions = {
|
||||||
|
tools: Tool[];
|
||||||
|
config: FullConfig;
|
||||||
|
browserContextFactory: BrowserContextFactory;
|
||||||
|
sessionLog: SessionLog | undefined;
|
||||||
|
clientVersion: { name: string; version: string; } | undefined;
|
||||||
|
capabilities: mcpServer.ClientCapabilities | undefined;
|
||||||
|
};
|
||||||
export class Context {
|
export class Context {
|
||||||
readonly tools: Tool[];
|
readonly tools: Tool[];
|
||||||
readonly config: FullConfig;
|
readonly config: FullConfig;
|
||||||
@@ -36,19 +45,21 @@ export class Context {
|
|||||||
private _browserContextFactory: BrowserContextFactory;
|
private _browserContextFactory: BrowserContextFactory;
|
||||||
private _tabs: Tab[] = [];
|
private _tabs: Tab[] = [];
|
||||||
private _currentTab: Tab | undefined;
|
private _currentTab: Tab | undefined;
|
||||||
|
private _clientVersion: { name: string; version: string; } | undefined;
|
||||||
clientVersion: { name: string; version: string; } | undefined;
|
private _clientCapabilities: mcpServer.ClientCapabilities;
|
||||||
|
|
||||||
private static _allContexts: Set<Context> = new Set();
|
private static _allContexts: Set<Context> = new Set();
|
||||||
private _closeBrowserContextPromise: Promise<void> | undefined;
|
private _closeBrowserContextPromise: Promise<void> | undefined;
|
||||||
private _isRunningTool: boolean = false;
|
private _isRunningTool: boolean = false;
|
||||||
private _abortController = new AbortController();
|
private _abortController = new AbortController();
|
||||||
|
|
||||||
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory, sessionLog: SessionLog | undefined) {
|
constructor(options: ContextOptions) {
|
||||||
this.tools = tools;
|
this.tools = options.tools;
|
||||||
this.config = config;
|
this.config = options.config;
|
||||||
this._browserContextFactory = browserContextFactory;
|
this.sessionLog = options.sessionLog;
|
||||||
this.sessionLog = sessionLog;
|
this._browserContextFactory = options.browserContextFactory;
|
||||||
|
this._clientVersion = options.clientVersion;
|
||||||
|
this._clientCapabilities = options.capabilities || {};
|
||||||
testDebug('create context');
|
testDebug('create context');
|
||||||
Context._allContexts.add(this);
|
Context._allContexts.add(this);
|
||||||
}
|
}
|
||||||
@@ -188,7 +199,7 @@ export class Context {
|
|||||||
if (this._closeBrowserContextPromise)
|
if (this._closeBrowserContextPromise)
|
||||||
throw new Error('Another browser context is being closed.');
|
throw new Error('Another browser context is being closed.');
|
||||||
// TODO: move to the browser context factory to make it based on isolation mode.
|
// TODO: move to the browser context factory to make it based on isolation mode.
|
||||||
const result = await this._browserContextFactory.createContext(this.clientVersion!, this._abortController.signal);
|
const result = await this._browserContextFactory.createContext(this._clientVersion!, this._abortController.signal);
|
||||||
const { browserContext } = result;
|
const { browserContext } = result;
|
||||||
await this._setupRequestInterception(browserContext);
|
await this._setupRequestInterception(browserContext);
|
||||||
if (this.sessionLog)
|
if (this.sessionLog)
|
||||||
|
|||||||
@@ -18,11 +18,18 @@ import { z } from 'zod';
|
|||||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||||
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
|
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||||
|
import { ManualPromise } from '../manualPromise.js';
|
||||||
|
import { logUnhandledError } from '../log.js';
|
||||||
|
|
||||||
import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js';
|
import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
|
||||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||||
|
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||||
|
|
||||||
export type ClientVersion = Implementation;
|
export type ClientCapabilities = {
|
||||||
|
roots?: {
|
||||||
|
listRoots?: boolean
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export type ToolResponse = {
|
export type ToolResponse = {
|
||||||
content: (TextContent | ImageContent)[];
|
content: (TextContent | ImageContent)[];
|
||||||
@@ -42,10 +49,9 @@ export type ToolHandler = (toolName: string, params: any) => Promise<ToolRespons
|
|||||||
export interface ServerBackend {
|
export interface ServerBackend {
|
||||||
name: string;
|
name: string;
|
||||||
version: string;
|
version: string;
|
||||||
initialize?(): Promise<void>;
|
initialize?(server: Server): Promise<void>;
|
||||||
tools(): ToolSchema<any>[];
|
tools(): ToolSchema<any>[];
|
||||||
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
|
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
|
||||||
serverInitialized?(version: ClientVersion | undefined): void;
|
|
||||||
serverClosed?(): void;
|
serverClosed?(): void;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,12 +59,12 @@ export type ServerBackendFactory = () => ServerBackend;
|
|||||||
|
|
||||||
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
|
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
|
||||||
const backend = serverBackendFactory();
|
const backend = serverBackendFactory();
|
||||||
await backend.initialize?.();
|
|
||||||
const server = createServer(backend, runHeartbeat);
|
const server = createServer(backend, runHeartbeat);
|
||||||
await server.connect(transport);
|
await server.connect(transport);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
|
export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
|
||||||
|
const initializedPromise = new ManualPromise<void>();
|
||||||
const server = new Server({ name: backend.name, version: backend.version }, {
|
const server = new Server({ name: backend.name, version: backend.version }, {
|
||||||
capabilities: {
|
capabilities: {
|
||||||
tools: {},
|
tools: {},
|
||||||
@@ -82,6 +88,8 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
|
|||||||
|
|
||||||
let heartbeatRunning = false;
|
let heartbeatRunning = false;
|
||||||
server.setRequestHandler(CallToolRequestSchema, async request => {
|
server.setRequestHandler(CallToolRequestSchema, async request => {
|
||||||
|
await initializedPromise;
|
||||||
|
|
||||||
if (runHeartbeat && !heartbeatRunning) {
|
if (runHeartbeat && !heartbeatRunning) {
|
||||||
heartbeatRunning = true;
|
heartbeatRunning = true;
|
||||||
startHeartbeat(server);
|
startHeartbeat(server);
|
||||||
@@ -101,8 +109,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
|
|||||||
return errorResult(String(error));
|
return errorResult(String(error));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
addServerListener(server, 'initialized', () => {
|
||||||
addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion()));
|
backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError);
|
||||||
|
});
|
||||||
addServerListener(server, 'close', () => backend.serverClosed?.());
|
addServerListener(server, 'close', () => backend.serverClosed?.());
|
||||||
return server;
|
return server;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user