diff --git a/packages/browser/src/client/tester/context.ts b/packages/browser/src/client/tester/context.ts index 2e65d38420c6..cce848b5fbcb 100644 --- a/packages/browser/src/client/tester/context.ts +++ b/packages/browser/src/client/tester/context.ts @@ -1,5 +1,4 @@ import type { Options as TestingLibraryOptions, UserEvent as TestingLibraryUserEvent } from '@testing-library/user-event' -import type { BrowserRPC } from '@vitest/browser/client' import type { RunnerTask } from 'vitest' import type { BrowserPage, @@ -19,15 +18,11 @@ import { convertElementToCssSelector, ensureAwaited, getBrowserState, getWorkerS const state = () => getWorkerState() // @ts-expect-error not typed global const provider = __vitest_browser_runner__.provider -function filepath() { - return getWorkerState().filepath || getWorkerState().current?.file?.filepath || undefined -} -const rpc = () => getWorkerState().rpc as any as BrowserRPC const sessionId = getBrowserState().sessionId const channel = new BroadcastChannel(`vitest:${sessionId}`) function triggerCommand(command: string, ...args: any[]) { - return rpc().triggerCommand(sessionId, command, filepath(), args) + return getBrowserState().commands.triggerCommand(command, args) } export function createUserEvent(__tl_user_event_base__?: TestingLibraryUserEvent, options?: TestingLibraryOptions): UserEvent { @@ -52,6 +47,10 @@ export function createUserEvent(__tl_user_event_base__?: TestingLibraryUserEvent return createUserEvent() }, async cleanup() { + // avoid cleanup rpc call if there is nothing to cleanup + if (!keyboard.unreleased.length) { + return + } return ensureAwaited(async () => { await triggerCommand('__vitest_cleanup', keyboard) keyboard.unreleased = [] @@ -106,9 +105,7 @@ export function createUserEvent(__tl_user_event_base__?: TestingLibraryUserEvent }) }, tab(options: UserEventTabOptions = {}) { - return ensureAwaited(() => { - return triggerCommand('__vitest_tab', options) - }) + return ensureAwaited(() => triggerCommand('__vitest_tab', options)) }, async keyboard(text: string) { return ensureAwaited(async () => { diff --git a/packages/browser/src/client/tester/locators/index.ts b/packages/browser/src/client/tester/locators/index.ts index 4daf93f392d9..212b22038df4 100644 --- a/packages/browser/src/client/tester/locators/index.ts +++ b/packages/browser/src/client/tester/locators/index.ts @@ -1,4 +1,3 @@ -import type { BrowserRPC } from '@vitest/browser/client' import type { LocatorByRoleOptions, LocatorOptions, @@ -8,8 +7,6 @@ import type { UserEventFillOptions, UserEventHoverOptions, } from '@vitest/browser/context' -import type { WorkerGlobalState } from 'vitest' -import type { BrowserRunnerState } from '../../utils' import { page, server } from '@vitest/browser/context' import { getByAltTextSelector, @@ -22,7 +19,7 @@ import { Ivya, type ParsedSelector, } from 'ivya' -import { ensureAwaited, getBrowserState, getWorkerState } from '../../utils' +import { ensureAwaited, getBrowserState } from '../../utils' import { getElementError } from '../public-utils' // we prefer using playwright locators because they are more powerful and support Shadow DOM @@ -205,27 +202,10 @@ export abstract class Locator { return this.selector } - private get state(): BrowserRunnerState { - return getBrowserState() - } - - private get worker(): WorkerGlobalState { - return getWorkerState() - } - - private get rpc(): BrowserRPC { - return this.worker.rpc as any as BrowserRPC - } - protected triggerCommand(command: string, ...args: any[]): Promise { - const filepath = this.worker.filepath - || this.worker.current?.file?.filepath - || undefined - - return ensureAwaited(() => this.rpc.triggerCommand( - this.state.sessionId, + const commands = getBrowserState().commands + return ensureAwaited(() => commands.triggerCommand( command, - filepath, args, )) } diff --git a/packages/browser/src/client/tester/tester.ts b/packages/browser/src/client/tester/tester.ts index f6b667691d0e..0d981dff2098 100644 --- a/packages/browser/src/client/tester/tester.ts +++ b/packages/browser/src/client/tester/tester.ts @@ -1,7 +1,7 @@ import { channel, client, onCancel } from '@vitest/browser/client' -import { page, userEvent } from '@vitest/browser/context' +import { page, server, userEvent } from '@vitest/browser/context' import { collectTests, setupCommonEnv, SpyModule, startCoverageInsideWorker, startTests, stopCoverageInsideWorker } from 'vitest/browser' -import { executor, getBrowserState, getConfig, getWorkerState } from '../utils' +import { CommandsManager, executor, getBrowserState, getConfig, getWorkerState } from '../utils' import { setupDialogsSpy } from './dialog' import { setupExpectDom } from './expect-element' import { setupConsoleLogSpy } from './logger' @@ -34,6 +34,8 @@ async function prepareTestEnvironment(files: string[]) { state.onCancel = onCancel state.rpc = rpc as any + getBrowserState().commands = new CommandsManager() + // TODO: expose `worker` const interceptor = createModuleMockerInterceptor() const mocker = new VitestBrowserClientMocker( @@ -69,6 +71,8 @@ async function prepareTestEnvironment(files: string[]) { runner, config, state, + rpc, + commands: getBrowserState().commands, } } @@ -113,12 +117,34 @@ async function executeTests(method: 'run' | 'collect', files: string[]) { debug('runner resolved successfully') - const { config, runner, state } = preparedData + const { config, runner, state, commands, rpc } = preparedData state.durations.prepare = performance.now() - state.durations.prepare debug('prepare time', state.durations.prepare, 'ms') + let contextSwitched = false + + // webdiverio context depends on the iframe state, so we need to switch the context, + // we delay this in case the user doesn't use any userEvent commands to avoid the overhead + if (server.provider === 'webdriverio') { + let switchPromise: Promise | null = null + + commands.onCommand(async () => { + if (switchPromise) { + await switchPromise + } + // if this is the first command, make sure we switched the command context to an iframe + if (!contextSwitched) { + switchPromise = rpc.wdioSwitchContext('iframe').finally(() => { + switchPromise = null + contextSwitched = true + }) + await switchPromise + } + }) + } + try { await Promise.all([ setupCommonEnv(config), @@ -151,6 +177,9 @@ async function executeTests(method: 'run' | 'collect', files: string[]) { // need to cleanup for each tester // since playwright keyboard API is stateful on page instance level await userEvent.cleanup() + if (contextSwitched) { + await rpc.wdioSwitchContext('parent') + } } catch (error: any) { await client.rpc.onUnhandledError({ diff --git a/packages/browser/src/client/utils.ts b/packages/browser/src/client/utils.ts index 353ec5867f45..3b8aed043edb 100644 --- a/packages/browser/src/client/utils.ts +++ b/packages/browser/src/client/utils.ts @@ -1,4 +1,5 @@ import type { SerializedConfig, WorkerGlobalState } from 'vitest' +import type { BrowserRPC } from './client' export async function importId(id: string): Promise { const name = `/@id/${id}`.replace(/\\/g, '/') @@ -77,6 +78,7 @@ export interface BrowserRunnerState { method: 'run' | 'collect' runTests?: (tests: string[]) => Promise createTesters?: (files: string[]) => Promise + commands: CommandsManager cdp?: { on: (event: string, listener: (payload: any) => void) => void once: (event: string, listener: (payload: any) => void) => void @@ -194,3 +196,22 @@ function getParent(el: Element) { } return parent } + +export class CommandsManager { + private _listeners: ((command: string, args: any[]) => void)[] = [] + + public onCommand(listener: (command: string, args: any[]) => void): void { + this._listeners.push(listener) + } + + public async triggerCommand(command: string, args: any[]): Promise { + const state = getWorkerState() + const rpc = state.rpc as any as BrowserRPC + const { sessionId } = getBrowserState() + const filepath = state.filepath || state.current?.file?.filepath + if (this._listeners.length) { + await Promise.all(this._listeners.map(listener => listener(command, args))) + } + return rpc.triggerCommand(sessionId, command, filepath, args) + } +} diff --git a/packages/browser/src/node/commands/keyboard.ts b/packages/browser/src/node/commands/keyboard.ts index 34b2fcef579a..77428355ccd5 100644 --- a/packages/browser/src/node/commands/keyboard.ts +++ b/packages/browser/src/node/commands/keyboard.ts @@ -54,6 +54,9 @@ export const keyboardCleanup: UserEventCommand<(state: KeyboardState) => Promise state, ) => { const { provider, sessionId } = context + if (!state.unreleased) { + return + } if (provider instanceof PlaywrightBrowserProvider) { const page = provider.getPage(sessionId) for (const key of state.unreleased) { diff --git a/packages/browser/src/node/plugins/pluginContext.ts b/packages/browser/src/node/plugins/pluginContext.ts index 5929822ae557..f1547ab9bfd9 100644 --- a/packages/browser/src/node/plugins/pluginContext.ts +++ b/packages/browser/src/node/plugins/pluginContext.ts @@ -32,15 +32,13 @@ async function generateContextFile( globalServer: ParentBrowserProject, ) { const commands = Object.keys(globalServer.commands) - const filepathCode - = '__vitest_worker__.filepath || __vitest_worker__.current?.file?.filepath || undefined' const provider = [...globalServer.children][0].provider || { name: 'preview' } const providerName = provider.name const commandsCode = commands .filter(command => !command.startsWith('__vitest')) .map((command) => { - return ` ["${command}"]: (...args) => rpc().triggerCommand(sessionId, "${command}", filepath(), args),` + return ` ["${command}"]: (...args) => __vitest_browser_runner__.commands.triggerCommand("${command}", args),` }) .join('\n') @@ -53,9 +51,6 @@ async function generateContextFile( return ` import { page, createUserEvent, cdp } from '${distContextPath}' ${userEventNonProviderImport} -const filepath = () => ${filepathCode} -const rpc = () => __vitest_worker__.rpc -const sessionId = __vitest_browser_runner__.sessionId export const server = { platform: ${JSON.stringify(process.platform)}, diff --git a/packages/browser/src/node/providers/webdriver.ts b/packages/browser/src/node/providers/webdriver.ts index 9c4567532742..8ae1334b5f99 100644 --- a/packages/browser/src/node/providers/webdriver.ts +++ b/packages/browser/src/node/providers/webdriver.ts @@ -37,7 +37,7 @@ export class WebdriverBrowserProvider implements BrowserProvider { this.options = options as RemoteOptions } - async beforeCommand(): Promise { + async switchToTestFrame(): Promise { const page = this.browser! const iframe = await page.findElement( 'css selector', @@ -46,7 +46,7 @@ export class WebdriverBrowserProvider implements BrowserProvider { await page.switchToFrame(iframe) } - async afterCommand(): Promise { + async switchToMainFrame(): Promise { await this.browser!.switchToParentFrame() } diff --git a/packages/browser/src/node/rpc.ts b/packages/browser/src/node/rpc.ts index 5230d219f36f..aa631930bf1e 100644 --- a/packages/browser/src/node/rpc.ts +++ b/packages/browser/src/node/rpc.ts @@ -3,6 +3,7 @@ import type { ErrorWithDiff } from 'vitest' import type { BrowserCommandContext, ResolveSnapshotPathHandlerContext, TestProject } from 'vitest/node' import type { WebSocket } from 'ws' import type { ParentBrowserProject } from './projectParent' +import type { WebdriverBrowserProvider } from './providers/webdriver' import type { BrowserServerState } from './state' import type { WebSocketBrowserEvents, WebSocketBrowserHandlers } from './types' import { existsSync, promises as fs } from 'node:fs' @@ -203,6 +204,21 @@ export function setupBrowserRpc(globalServer: ParentBrowserProject): void { getCountOfFailedTests() { return vitest.state.getCountOfFailedTests() }, + async wdioSwitchContext(direction) { + const provider = project.browser!.provider as WebdriverBrowserProvider + if (!provider) { + throw new Error('Commands are only available for browser tests.') + } + if (provider.name !== 'webdriverio') { + throw new Error('Switch context is only available for WebDriverIO provider.') + } + if (direction === 'iframe') { + await provider.switchToTestFrame() + } + else { + await provider.switchToMainFrame() + } + }, async triggerCommand(sessionId, command, testPath, payload) { debug?.('[%s] Triggering command "%s"', sessionId, command) const provider = project.browser!.provider @@ -213,7 +229,6 @@ export function setupBrowserRpc(globalServer: ParentBrowserProject): void { if (!commands || !commands[command]) { throw new Error(`Unknown command "${command}".`) } - await provider.beforeCommand?.(command, payload) const context = Object.assign( { testPath, @@ -224,14 +239,7 @@ export function setupBrowserRpc(globalServer: ParentBrowserProject): void { }, provider.getCommandsContext(sessionId), ) as any as BrowserCommandContext - let result - try { - result = await commands[command](context, ...payload) - } - finally { - await provider.afterCommand?.(command, payload) - } - return result + return await commands[command](context, ...payload) }, finishBrowserTests(sessionId: string) { debug?.('[%s] Finishing browser tests for session', sessionId) diff --git a/packages/browser/src/node/types.ts b/packages/browser/src/node/types.ts index 9095aa7964a7..53637d71429f 100644 --- a/packages/browser/src/node/types.ts +++ b/packages/browser/src/node/types.ts @@ -39,6 +39,7 @@ export interface WebSocketBrowserHandlers { getBrowserFileSourceMap: ( id: string ) => SourceMap | null | { mappings: '' } | undefined + wdioSwitchContext: (direction: 'iframe' | 'parent') => void // cdp sendCdpEvent: (sessionId: string, event: string, payload?: Record) => unknown diff --git a/packages/vitest/src/node/types/browser.ts b/packages/vitest/src/node/types/browser.ts index 3c51fbd385bd..8b3fc22d9668 100644 --- a/packages/vitest/src/node/types/browser.ts +++ b/packages/vitest/src/node/types/browser.ts @@ -24,8 +24,6 @@ export interface BrowserProvider { */ supportsParallelism: boolean getSupportedBrowsers: () => readonly string[] - beforeCommand?: (command: string, args: unknown[]) => Awaitable - afterCommand?: (command: string, args: unknown[]) => Awaitable getCommandsContext: (sessionId: string) => Record openPage: (sessionId: string, url: string, beforeNavigate?: () => Promise) => Promise getCDPSession?: (sessionId: string) => Promise