feat(dialog): handle dialogs (#212)

This commit is contained in:
Pavel Feldman
2025-04-17 14:03:13 -07:00
committed by GitHub
parent 4b261286bf
commit 6481100bdf
10 changed files with 461 additions and 15 deletions

View File

@@ -18,9 +18,10 @@ import * as playwright from 'playwright';
import yaml from 'yaml';
import { waitForCompletion } from './tools/utils';
import { ManualPromise } from './manualPromise';
import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types';
import type { ModalState, Tool } from './tools/tool';
import type { ModalState, Tool, ToolActionResult } from './tools/tool';
export type ContextOptions = {
browserName?: 'chromium' | 'firefox' | 'webkit';
@@ -32,6 +33,10 @@ export type ContextOptions = {
type PageOrFrameLocator = playwright.Page | playwright.FrameLocator;
type PendingAction = {
dialogShown: ManualPromise<void>;
};
export class Context {
readonly tools: Tool[];
readonly options: ContextOptions;
@@ -40,6 +45,7 @@ export class Context {
private _tabs: Tab[] = [];
private _currentTab: Tab | undefined;
private _modalStates: (ModalState & { tab: Tab })[] = [];
private _pendingAction: PendingAction | undefined;
constructor(tools: Tool[], options: ContextOptions) {
this.tools = tools;
@@ -120,6 +126,7 @@ export class Context {
// Tab management is done outside of the action() call.
const toolResult = await tool.handle(this, params);
const { code, action, waitForNetwork, captureSnapshot, resultOverride } = toolResult;
const racingAction = action ? () => this._raceAgainstModalDialogs(action) : undefined;
if (resultOverride)
return resultOverride;
@@ -138,11 +145,11 @@ export class Context {
let actionResult: { content?: (ImageContent | TextContent)[] } | undefined;
try {
if (waitForNetwork)
actionResult = await waitForCompletion(tab.page, async () => action?.()) ?? undefined;
actionResult = await waitForCompletion(this, tab.page, async () => racingAction?.()) ?? undefined;
else
actionResult = await action?.() ?? undefined;
actionResult = await racingAction?.() ?? undefined;
} finally {
if (captureSnapshot)
if (captureSnapshot && !this._javaScriptBlocked())
await tab.captureSnapshot();
}
@@ -190,6 +197,43 @@ ${code.join('\n')}
};
}
async waitForTimeout(time: number) {
if (this._currentTab && !this._javaScriptBlocked())
await this._currentTab.page.waitForTimeout(time);
else
await new Promise(f => setTimeout(f, time));
}
private async _raceAgainstModalDialogs(action: () => Promise<ToolActionResult>): Promise<ToolActionResult> {
this._pendingAction = {
dialogShown: new ManualPromise(),
};
let result: ToolActionResult | undefined;
try {
await Promise.race([
action().then(r => result = r),
this._pendingAction.dialogShown,
]);
} finally {
this._pendingAction = undefined;
}
return result;
}
private _javaScriptBlocked(): boolean {
return this._modalStates.some(state => state.type === 'dialog');
}
dialogShown(tab: Tab, dialog: playwright.Dialog) {
this.setModalState({
type: 'dialog',
description: `"${dialog.type()}" dialog with message "${dialog.message()}"`,
dialog,
}, tab);
this._pendingAction?.dialogShown.resolve();
}
private _onPageCreated(page: playwright.Page) {
const tab = new Tab(this, page, tab => this._onPageClosed(tab));
this._tabs.push(tab);
@@ -293,6 +337,7 @@ export class Tab {
fileChooser: chooser,
}, this);
});
page.on('dialog', dialog => this.context.dialogShown(this, dialog));
page.setDefaultNavigationTimeout(60000);
page.setDefaultTimeout(5000);
}

View File

@@ -21,6 +21,7 @@ import fs from 'fs';
import { createServerWithTools } from './server';
import common from './tools/common';
import console from './tools/console';
import dialogs from './tools/dialogs';
import files from './tools/files';
import install from './tools/install';
import keyboard from './tools/keyboard';
@@ -37,6 +38,7 @@ import type { LaunchOptions } from 'playwright';
const snapshotTools: Tool[] = [
...common(true),
...console,
...dialogs(true),
...files(true),
...install,
...keyboard(true),
@@ -49,6 +51,7 @@ const snapshotTools: Tool[] = [
const screenshotTools: Tool[] = [
...common(false),
...console,
...dialogs(false),
...files(false),
...install,
...keyboard(false),

127
src/manualPromise.ts Normal file
View File

@@ -0,0 +1,127 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export class ManualPromise<T = void> extends Promise<T> {
private _resolve!: (t: T) => void;
private _reject!: (e: Error) => void;
private _isDone: boolean;
constructor() {
let resolve: (t: T) => void;
let reject: (e: Error) => void;
super((f, r) => {
resolve = f;
reject = r;
});
this._isDone = false;
this._resolve = resolve!;
this._reject = reject!;
}
isDone() {
return this._isDone;
}
resolve(t: T) {
this._isDone = true;
this._resolve(t);
}
reject(e: Error) {
this._isDone = true;
this._reject(e);
}
static override get [Symbol.species]() {
return Promise;
}
override get [Symbol.toStringTag]() {
return 'ManualPromise';
}
}
export class LongStandingScope {
private _terminateError: Error | undefined;
private _closeError: Error | undefined;
private _terminatePromises = new Map<ManualPromise<Error>, string[]>();
private _isClosed = false;
reject(error: Error) {
this._isClosed = true;
this._terminateError = error;
for (const p of this._terminatePromises.keys())
p.resolve(error);
}
close(error: Error) {
this._isClosed = true;
this._closeError = error;
for (const [p, frames] of this._terminatePromises)
p.resolve(cloneError(error, frames));
}
isClosed() {
return this._isClosed;
}
static async raceMultiple<T>(scopes: LongStandingScope[], promise: Promise<T>): Promise<T> {
return Promise.race(scopes.map(s => s.race(promise)));
}
async race<T>(promise: Promise<T> | Promise<T>[]): Promise<T> {
return this._race(Array.isArray(promise) ? promise : [promise], false) as Promise<T>;
}
async safeRace<T>(promise: Promise<T>, defaultValue?: T): Promise<T> {
return this._race([promise], true, defaultValue);
}
private async _race(promises: Promise<any>[], safe: boolean, defaultValue?: any): Promise<any> {
const terminatePromise = new ManualPromise<Error>();
const frames = captureRawStack();
if (this._terminateError)
terminatePromise.resolve(this._terminateError);
if (this._closeError)
terminatePromise.resolve(cloneError(this._closeError, frames));
this._terminatePromises.set(terminatePromise, frames);
try {
return await Promise.race([
terminatePromise.then(e => safe ? defaultValue : Promise.reject(e)),
...promises
]);
} finally {
this._terminatePromises.delete(terminatePromise);
}
}
}
function cloneError(error: Error, frames: string[]) {
const clone = new Error();
clone.name = error.name;
clone.message = error.message;
clone.stack = [error.name + ':' + error.message, ...frames].join('\n');
return clone;
}
function captureRawStack(): string[] {
const stackTraceLimit = Error.stackTraceLimit;
Error.stackTraceLimit = 50;
const error = new Error();
const stack = error.stack || '';
Error.stackTraceLimit = stackTraceLimit;
return stack.split('\n');
}

65
src/tools/dialogs.ts Normal file
View File

@@ -0,0 +1,65 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { z } from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import type { ToolFactory } from './tool';
const handleDialogSchema = z.object({
accept: z.boolean().describe('Whether to accept the dialog.'),
promptText: z.string().optional().describe('The text of the prompt in case of a prompt dialog.'),
});
const handleDialog: ToolFactory = captureSnapshot => ({
capability: 'core',
schema: {
name: 'browser_handle_dialog',
description: 'Handle a dialog',
inputSchema: zodToJsonSchema(handleDialogSchema),
},
handle: async (context, params) => {
const validatedParams = handleDialogSchema.parse(params);
const dialogState = context.modalStates().find(state => state.type === 'dialog');
if (!dialogState)
throw new Error('No dialog visible');
if (validatedParams.accept)
await dialogState.dialog.accept(validatedParams.promptText);
else
await dialogState.dialog.dismiss();
context.clearModalState(dialogState);
const code = [
`// <internal code to handle "${dialogState.dialog.type()}" dialog>`,
];
return {
code,
captureSnapshot,
waitForNetwork: false,
};
},
clearsModalState: 'dialog',
});
export default (captureSnapshot: boolean) => [
handleDialog(captureSnapshot),
];

View File

@@ -32,14 +32,22 @@ export type FileUploadModalState = {
fileChooser: playwright.FileChooser;
};
export type ModalState = FileUploadModalState;
export type DialogModalState = {
type: 'dialog';
description: string;
dialog: playwright.Dialog;
};
export type ModalState = FileUploadModalState | DialogModalState;
export type ToolActionResult = { content?: (ImageContent | TextContent)[] } | undefined | void;
export type ToolResult = {
code: string[];
action?: () => Promise<{ content?: (ImageContent | TextContent)[] } | undefined | void>;
action?: () => Promise<ToolActionResult>;
captureSnapshot: boolean;
waitForNetwork: boolean;
resultOverride?: { content?: (ImageContent | TextContent)[] };
resultOverride?: ToolActionResult;
};
export type Tool = {

View File

@@ -15,8 +15,9 @@
*/
import type * as playwright from 'playwright';
import type { Context } from '../context';
export async function waitForCompletion<R>(page: playwright.Page, callback: () => Promise<R>): Promise<R> {
export async function waitForCompletion<R>(context: Context, page: playwright.Page, callback: () => Promise<R>): Promise<R> {
const requests = new Set<playwright.Request>();
let frameNavigated = false;
let waitCallback: () => void = () => {};
@@ -62,7 +63,7 @@ export async function waitForCompletion<R>(page: playwright.Page, callback: () =
if (!requests.size && !frameNavigated)
waitCallback();
await waitBarrier;
await page.evaluate(() => new Promise(f => setTimeout(f, 1000)));
await context.waitForTimeout(1000);
return result;
} finally {
dispose();