From 1fcfd71ddd9e269b24027cf42008fab56c73612f Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 6 Jun 2026 16:45:49 +0100 Subject: [PATCH] [py] Add expect_* context managers for BiDi events Adds a Subscription primitive to the BiDi event manager that registers its event handler at creation time, so an event fired by an action inside the with block cannot be missed: with driver.network.expect_response("**/api/**") as response_info: driver.find_element(By.ID, "load").click() response = response_info.value New high-level methods generated via the enhancements manifest: - network.expect_request / network.expect_response (URL glob or predicate filtering, observational Request/Response wrappers) - script.expect_console_message (predicate over ConsoleMessage) - browsing_context.expect_user_prompt (typed prompt params) - browsing_context.expect_download, correlating downloadWillBegin and downloadEnd by navigation id into a new Download object with path(), save_as() and failure() The event manager gains a raw=True option so expect_* handlers receive the full wire-level params instead of the generated dataclasses that drop fields like request/response. Includes 25 unit tests and 6 browser-based integration tests. --- py/generate_bidi.py | 3 +- py/private/_event_manager.py | 180 ++++++- py/private/_network_handlers.py | 15 + py/private/bidi_enhancements_manifest.py | 223 +++++++++ .../common/bidi_browsing_context_tests.py | 38 ++ .../webdriver/common/bidi_network_tests.py | 18 + .../webdriver/common/bidi_script_tests.py | 18 + .../common/bidi_expect_events_tests.py | 464 ++++++++++++++++++ 8 files changed, 955 insertions(+), 4 deletions(-) create mode 100644 py/test/unit/selenium/webdriver/common/bidi_expect_events_tests.py diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 9f88c720dd2b5..9c114812804ba 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -657,7 +657,8 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: local_imports.append("from selenium.webdriver.common.bidi.common import command_builder") if self.events: local_imports.append( - "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" + "from selenium.webdriver.common.bidi._event_manager import " + "EventConfig, Subscription, _EventWrapper, _EventManager" ) code += "\n".join(stdlib_imports) + "\n" diff --git a/py/private/_event_manager.py b/py/private/_event_manager.py index 1dcc8288ce683..410478216a655 100644 --- a/py/private/_event_manager.py +++ b/py/private/_event_manager.py @@ -25,13 +25,18 @@ from __future__ import annotations +import logging +import queue import threading from collections.abc import Callable from dataclasses import dataclass from typing import Any +from selenium.common.exceptions import TimeoutException from selenium.webdriver.common.bidi.session import Session +logger = logging.getLogger(__name__) + @dataclass class EventConfig: @@ -86,6 +91,133 @@ def _camel_to_snake(name: str) -> str: return "".join(result) +_UNSET = object() + + +class Subscription: + """A pending expectation for a single BiDi event. + + The event handler is registered when the subscription is created, so an + event fired by an action inside a ``with`` block is captured without a + race between the action and the wait:: + + with driver.network.expect_response("**/api/**") as response_info: + driver.find_element(By.ID, "load").click() + response = response_info.value + + Exiting the ``with`` block waits for a matching event (raising + :class:`~selenium.common.exceptions.TimeoutException` if none arrives + within ``timeout``) and then removes the event handler. Outside a + ``with`` block, call :meth:`wait` (or read :attr:`value`) to block until + the event arrives, and :meth:`cancel` to stop listening without waiting. + + Exactly the first matching event is captured; matches arriving after the + subscription detaches are silently discarded. + """ + + def __init__( + self, + register: Callable[[Callable], Any], + unregister: Callable[[Any], None], + predicate: Callable[[Any], bool] | None = None, + timeout: float = 30.0, + transform: Callable[[Any], Any] | None = None, + description: str = "event", + ): + self._unregister = unregister + self._predicate = predicate + self._timeout = timeout + self._transform = transform + self._description = description + self._events: queue.SimpleQueue = queue.SimpleQueue() + self._value = _UNSET + self._detached = False + self._detach_lock = threading.Lock() + self._cleanups: list[Callable[[], None]] = [] + self._token = register(self._on_event) + + def _on_event(self, params: Any) -> None: + if self._detached: + return + try: + value = self._transform(params) if self._transform else params + if self._predicate is None or self._predicate(value): + self._events.put(value) + except Exception: + logger.exception("Predicate or transform for %s raised; event dropped", self._description) + + def wait(self, timeout: float | None = None) -> Any: + """Block until a matching event arrives and return it. + + The first matching event is cached: later calls (and :attr:`value`) + return it without waiting again. The event handler is removed once a + match is captured; on timeout it stays registered so the wait can be + retried — call :meth:`cancel` to stop listening early. + + Args: + timeout: Seconds to wait; defaults to the subscription's timeout. + + Raises: + TimeoutException: If no matching event arrives in time. + """ + if self._value is not _UNSET: + return self._value + timeout = self._timeout if timeout is None else timeout + try: + self._value = self._events.get(timeout=timeout) + except queue.Empty: + raise TimeoutException(f"Timed out after {timeout}s waiting for {self._description}") from None + self._detach() + return self._value + + @property + def value(self) -> Any: + """The captured event, waiting for it first if necessary.""" + return self.wait() + + def cancel(self) -> None: + """Stop listening without waiting for an event.""" + self._detach() + + def add_cleanup(self, cleanup: Callable[[], None]) -> None: + """Run ``cleanup`` when the subscription detaches. + + Used by ``expect_*`` helpers that register companion event handlers + (e.g. ``expect_download``) so those are removed alongside this one. + """ + self._cleanups.append(cleanup) + + def _detach(self) -> None: + with self._detach_lock: + if self._detached: + return + self._detached = True + # Unregister and run cleanups outside the lock: they perform BiDi + # I/O and must not block a concurrent detach attempt. + try: + self._unregister(self._token) + except Exception: + logger.exception("Failed to remove event handler for %s", self._description) + for cleanup in self._cleanups: + try: + cleanup() + except Exception: + logger.exception("Subscription cleanup for %s failed", self._description) + + def __enter__(self) -> Subscription: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + if exc_type is None: + try: + self.wait() + finally: + self._detach() + else: + self._detach() + return False + + class _EventManager: """Manages event subscriptions and callbacks.""" @@ -94,6 +226,7 @@ def __init__(self, conn, event_configs: dict[str, EventConfig]): self.event_configs = event_configs self.subscriptions: dict = {} self._event_wrappers = {} # Cache of _EventWrapper objects + self._raw_wrappers = {} # Cache of raw-dict _EventWrapper objects self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} self._available_events = ", ".join(sorted(event_configs.keys())) self._subscription_lock = threading.Lock() @@ -144,10 +277,20 @@ def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> No if entry and callback_id in entry["callbacks"]: entry["callbacks"].remove(callback_id) - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + def add_event_handler( + self, event: str, callback: Callable, contexts: list[str] | None = None, raw: bool = False + ) -> int: event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) + # Use the event wrapper for add_callback. Raw handlers receive the + # unfiltered wire-level params dict instead of the typed dataclass, + # which may not carry every event field. + if raw: + event_wrapper = self._raw_wrappers.get(event_config.bidi_event) + if event_wrapper is None: + event_wrapper = _EventWrapper(event_config.bidi_event, dict) + self._raw_wrappers[event_config.bidi_event] = event_wrapper + else: + event_wrapper = self._event_wrappers.get(event_config.bidi_event) callback_id = self.conn.add_callback(event_wrapper, callback) self.subscribe_to_event(event_config.bidi_event, contexts) self.add_callback_to_tracking(event_config.bidi_event, callback_id) @@ -160,6 +303,37 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: self.remove_callback_from_tracking(event_config.bidi_event, callback_id) self.unsubscribe_from_event(event_config.bidi_event) + def expect( + self, + event: str, + predicate: Callable[[Any], bool] | None = None, + timeout: float = 30.0, + transform: Callable[[Any], Any] | None = None, + raw: bool = False, + contexts: list[str] | None = None, + ) -> Subscription: + """Return a :class:`Subscription` capturing the next matching ``event``. + + Args: + event: The event key to subscribe to. + predicate: Optional filter applied to each (transformed) event; + the first event for which it returns true is captured. + timeout: Default seconds the subscription waits for a match. + transform: Optional conversion applied to the event payload + before the predicate sees it and before it is returned. + raw: When true the handler receives the wire-level params dict + instead of the typed event dataclass. + contexts: Optional browsing context IDs to subscribe to. + """ + return Subscription( + register=lambda callback: self.add_event_handler(event, callback, contexts, raw=raw), + unregister=lambda callback_id: self.remove_event_handler(event, callback_id), + predicate=predicate, + timeout=timeout, + transform=transform, + description=f"event '{event}'", + ) + def clear_event_handlers(self) -> None: """Clear all event handlers.""" with self._subscription_lock: diff --git a/py/private/_network_handlers.py b/py/private/_network_handlers.py index 7eaafdb0a4c5d..2e68938fc5590 100644 --- a/py/private/_network_handlers.py +++ b/py/private/_network_handlers.py @@ -264,6 +264,21 @@ def globs_to_url_patterns(patterns: list | None) -> list[dict] | None: return translated or None +def to_url_predicate(url_or_predicate) -> Callable | None: + """Normalize an ``expect_*`` filter into a predicate over wrapped events. + + A string is treated as a URL glob (``*``, ``**``, ``?``) matched against + the event's ``url`` attribute; a callable is returned unchanged; ``None`` + matches everything. + """ + if url_or_predicate is None: + return None + if callable(url_or_predicate): + return url_or_predicate + regex = glob_to_regex(str(url_or_predicate)) + return lambda event: bool(regex.match(event.url or "")) + + class Request: """Wraps a BiDi network request event and provides request action methods. diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 5abbae30b8e8a..6212da8c18e7c 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -268,6 +268,97 @@ class SetClientWindowStateParameters: cmd = command_builder("browsingContext.setViewport", params) result = self._conn.execute(cmd) return result''', + ''' def expect_user_prompt(self, predicate=None, timeout=30, contexts=None): + """Return a subscription capturing the next matching user prompt. + + The handler is registered immediately, so a prompt opened by an + action inside the ``with`` block cannot be missed:: + + with driver.browsing_context.expect_user_prompt() as prompt_info: + driver.find_element(By.ID, "delete").click() + prompt = prompt_info.value + driver.browsing_context.handle_user_prompt(prompt.context, accept=True) + + Args: + predicate: Optional filter over the + :class:`UserPromptOpenedParameters`; the first prompt for + which it returns true is captured. ``None`` matches every + prompt. + timeout: Seconds to wait for a match before raising + ``TimeoutException``. + contexts: Optional browsing context IDs to subscribe to. + + Returns: + A :class:`Subscription` whose ``value`` is the matching + :class:`UserPromptOpenedParameters`. + """ + return self._event_manager.expect( + "user_prompt_opened", + predicate=predicate, + timeout=timeout, + contexts=contexts, + )''', + ''' def expect_download(self, timeout=30, contexts=None): + """Return a subscription capturing the next finished download. + + The handlers are registered immediately, so a download triggered by + an action inside the ``with`` block cannot be missed. The + ``browsingContext.downloadWillBegin`` and ``downloadEnd`` events are + correlated by navigation ID into a single :class:`Download`:: + + with driver.browsing_context.expect_download() as download_info: + driver.find_element(By.ID, "export").click() + download = download_info.value + download.save_as("/tmp/" + download.suggested_filename) + + Args: + timeout: Seconds to wait for the download to finish before + raising ``TimeoutException``. + contexts: Optional browsing context IDs to subscribe to. + + Returns: + A :class:`Subscription` whose ``value`` is the finished + :class:`Download`. + """ + suggested_filenames = {} + + def _record_begin(params): + if isinstance(params, dict): + suggested_filenames[params.get("navigation")] = params.get("suggestedFilename") + + begin_callback_id = self._event_manager.add_event_handler( + "download_will_begin", _record_begin, contexts, raw=True + ) + + def _assemble(params): + # downloadWillBegin precedes downloadEnd for the same navigation, + # so the suggested filename is recorded by the time we assemble. + params = params if isinstance(params, dict) else {} + navigation = params.get("navigation") + return Download( + url=params.get("url"), + suggested_filename=suggested_filenames.get(navigation), + filepath=params.get("filepath"), + status=params.get("status"), + context=params.get("context"), + navigation=navigation, + ) + + try: + subscription = self._event_manager.expect( + "download_end", + timeout=timeout, + transform=_assemble, + raw=True, + contexts=contexts, + ) + except Exception: + self._event_manager.remove_event_handler("download_will_begin", begin_callback_id) + raise + subscription.add_cleanup( + lambda: self._event_manager.remove_event_handler("download_will_begin", begin_callback_id) + ) + return subscription''', ], # Non-CDDL download event dataclasses (Chromium-specific) "extra_dataclasses": [ @@ -309,6 +400,40 @@ def from_json(cls, params: dict) -> DownloadEndParams: filepath=params.get("filepath"), ) return cls(download_params=dp)''', + '''@dataclass +class Download: + """A finished browser download, assembled from the + ``browsingContext.downloadWillBegin`` and ``downloadEnd`` events.""" + + url: str | None = None + suggested_filename: str | None = None + filepath: str | None = None + status: str | None = None + context: Any | None = None + navigation: Any | None = None + + def path(self): + """The on-disk path of the downloaded file, or ``None``.""" + from pathlib import Path + + return Path(self.filepath) if self.filepath else None + + def save_as(self, destination): + """Copy the downloaded file to ``destination`` and return the new path. + + Raises: + ValueError: If the download produced no file on disk (for + example when it was canceled). + """ + if not self.filepath: + raise ValueError(f"Download has no file on disk (status={self.status!r})") + import shutil + + return shutil.copy(self.filepath, destination) + + def failure(self) -> str | None: + """``None`` when the download completed, otherwise the status.""" + return None if self.status == "complete" else self.status''', ], # Download events are now in the CDDL spec, so no extra_events needed }, @@ -1014,6 +1139,36 @@ def __init__(self2, realm, origin, type_, context): ''' def clear_dom_mutation_handlers(self) -> None: """Remove all DOM mutation handlers.""" self._dom_mutation_handlers.clear_handlers()''', + ''' def expect_console_message(self, predicate=None, timeout=30): + """Return a subscription capturing the next matching console message. + + The handler is registered immediately, so a message logged by an + action inside the ``with`` block cannot be missed:: + + with driver.script.expect_console_message( + lambda msg: msg.level == "error" + ) as message_info: + driver.find_element(By.ID, "trigger").click() + message = message_info.value + + Args: + predicate: Optional filter over the :class:`ConsoleMessage`; the + first message for which it returns true is captured. ``None`` + matches every console message. + timeout: Seconds to wait for a match before raising + ``TimeoutException``. + + Returns: + A :class:`Subscription` whose ``value`` is the matching + :class:`ConsoleMessage`. + """ + return Subscription( + register=lambda callback: self._log_handlers.add_handler(callback, LogHandlerRegistry.CONSOLE), + unregister=self._log_handlers.remove_handler, + predicate=predicate, + timeout=timeout, + description="console message", + )''', ], }, "network": { @@ -1069,6 +1224,7 @@ def to_bidi_dict(self) -> dict: Response, ResponseHandlerRegistry, looks_like_url_glob, + to_url_predicate, )""", ], # Override auth_required to use raw dict so _auth_callback receives all @@ -1412,6 +1568,73 @@ def _auth_callback(params): intercept_id = self._handler_intercepts.pop(callback_id, None) if intercept_id: self._remove_intercept(intercept_id)''', + ''' def expect_request(self, url_or_predicate=None, timeout=30, contexts=None): + """Return a subscription capturing the next matching request. + + The handler is registered immediately, so a request triggered by an + action inside the ``with`` block cannot be missed:: + + with driver.network.expect_request("**/api/**") as request_info: + driver.find_element(By.ID, "load").click() + request = request_info.value + + The captured :class:`Request` is observational: the request is not + intercepted, so its mutation methods must not be used. + + Args: + url_or_predicate: A URL glob (``*``, ``**``, ``?``) matched + against the request URL, or a predicate over the + :class:`Request`. ``None`` matches every request. + timeout: Seconds to wait for a match before raising + ``TimeoutException``. + contexts: Optional browsing context IDs to subscribe to. + + Returns: + A :class:`Subscription` whose ``value`` is the matching + :class:`Request`. + """ + return self._event_manager.expect( + "before_request_sent", + predicate=to_url_predicate(url_or_predicate), + timeout=timeout, + transform=lambda params: Request(self._conn, params), + raw=True, + contexts=contexts, + )''', + ''' def expect_response(self, url_or_predicate=None, timeout=30, contexts=None): + """Return a subscription capturing the next matching completed response. + + The handler is registered immediately, so a response triggered by an + action inside the ``with`` block cannot be missed:: + + with driver.network.expect_response("**/api/**") as response_info: + driver.find_element(By.ID, "load").click() + response = response_info.value + assert response.status == 200 + + The captured :class:`Response` is observational: the response has + already completed, so its mutation methods must not be used. + + Args: + url_or_predicate: A URL glob (``*``, ``**``, ``?``) matched + against the response URL, or a predicate over the + :class:`Response`. ``None`` matches every response. + timeout: Seconds to wait for a match before raising + ``TimeoutException``. + contexts: Optional browsing context IDs to subscribe to. + + Returns: + A :class:`Subscription` whose ``value`` is the matching + :class:`Response`. + """ + return self._event_manager.expect( + "response_completed", + predicate=to_url_predicate(url_or_predicate), + timeout=timeout, + transform=lambda params: Response(self._conn, params), + raw=True, + contexts=contexts, + )''', ], }, "storage": { diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 86e3d11af0341..48326898f794c 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -1191,3 +1191,41 @@ def test_no_event_after_handler_removal(driver): new_events = len(helper.events_received) - events_before assert new_events == 0, f"Expected 0 new events after removal, got {new_events}" + + +def test_expect_user_prompt_captures_prompt(driver, pages): + """The prompt handler is registered before the action that opens it.""" + context_id = driver.current_window_handle + create_alert_page(driver, pages) + + with driver.browsing_context.expect_user_prompt() as prompt_info: + driver.find_element(By.ID, "alert").click() + + prompt = prompt_info.value + assert prompt.type == "alert" + assert prompt.context == context_id + + driver.browsing_context.handle_user_prompt(context=prompt.context, accept=True) + assert "Alerts" in driver.title + + +@pytest.mark.xfail_firefox +def test_expect_download_captures_finished_download(driver, pages, tmp_path): + """DownloadWillBegin and downloadEnd are correlated into one Download.""" + try: + driver.browser.set_download_behavior(allowed=True, destination_folder=tmp_path) + url = pages.url("downloads/download.html") + driver.browsing_context.navigate(context=driver.current_window_handle, url=url, wait=ReadinessState.COMPLETE) + + with driver.browsing_context.expect_download() as download_info: + driver.find_element(By.ID, "file-1").click() + + download = download_info.value + assert download.failure() is None + assert download.suggested_filename == "file_1.txt" + assert download.path() is not None and download.path().exists() + + download.save_as(tmp_path / "copied.txt") + assert (tmp_path / "copied.txt").exists() + finally: + driver.browser.set_download_behavior(allowed=None) diff --git a/py/test/selenium/webdriver/common/bidi_network_tests.py b/py/test/selenium/webdriver/common/bidi_network_tests.py index 83fb96c594a33..f108b9cb61ea9 100644 --- a/py/test/selenium/webdriver/common/bidi_network_tests.py +++ b/py/test/selenium/webdriver/common/bidi_network_tests.py @@ -503,3 +503,21 @@ def test_extra_headers_compose_with_request_handlers(driver, pages): finally: driver.network.remove_request_handler(handler_id) driver.network.clear_extra_headers() + + +def test_expect_request_captures_request(driver, pages): + with driver.network.expect_request("**/formPage.html") as request_info: + _navigate(driver, pages.url("formPage.html")) + + request = request_info.value + assert "formPage.html" in request.url + assert request.method == "GET" + + +def test_expect_response_captures_response(driver, pages): + with driver.network.expect_response("**/formPage.html") as response_info: + _navigate(driver, pages.url("formPage.html")) + + response = response_info.value + assert "formPage.html" in response.url + assert response.status == 200 diff --git a/py/test/selenium/webdriver/common/bidi_script_tests.py b/py/test/selenium/webdriver/common/bidi_script_tests.py index a66842d06d693..7ae43f6d463af 100644 --- a/py/test/selenium/webdriver/common/bidi_script_tests.py +++ b/py/test/selenium/webdriver/common/bidi_script_tests.py @@ -1621,3 +1621,21 @@ def test_execute_pinned_script_reports_error(self, driver, pages): assert "missingFunction" in result.error.message finally: driver.script.unpin(pinned) + + +def test_expect_console_message_captures_message(driver, pages): + pages.load("blank.html") + + with driver.script.expect_console_message() as message_info: + driver.execute_script("console.log('expected message');") + + assert message_info.value.text == "expected message" + + +def test_expect_console_message_predicate_filters(driver, pages): + pages.load("blank.html") + + with driver.script.expect_console_message(lambda message: message.text == "needle") as message_info: + driver.execute_script("console.log('haystack'); console.log('needle');") + + assert message_info.value.text == "needle" diff --git a/py/test/unit/selenium/webdriver/common/bidi_expect_events_tests.py b/py/test/unit/selenium/webdriver/common/bidi_expect_events_tests.py new file mode 100644 index 0000000000000..48efa66583df2 --- /dev/null +++ b/py/test/unit/selenium/webdriver/common/bidi_expect_events_tests.py @@ -0,0 +1,464 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import threading + +import pytest + +from selenium.common.exceptions import TimeoutException +from selenium.webdriver.common.bidi.browsing_context import BrowsingContext, Download +from selenium.webdriver.common.bidi.network import Network, Request, Response +from selenium.webdriver.common.bidi.script import ConsoleMessage, Script + + +class FakeConnection: + """Mimics WebSocketConnection: callbacks receive ``event.from_json(params)``.""" + + def __init__(self): + self.commands = [] + self.added_callbacks = [] + self.removed_callbacks = [] + self._next_callback_id = 1 + + def add_callback(self, event_wrapper, callback): + callback_id = self._next_callback_id + self._next_callback_id += 1 + + def _dispatch(params): + callback(event_wrapper.from_json(params)) + + self.added_callbacks.append((callback_id, event_wrapper.event_class, _dispatch)) + return callback_id + + def remove_callback(self, event_wrapper, callback_id): + self.removed_callbacks.append((callback_id, event_wrapper.event_class)) + + def execute(self, cmd): + payload = next(cmd) + self.commands.append(payload) + + if payload["method"] == "session.subscribe": + response = {"subscription": f"subscription-{len(self.commands)}"} + else: + response = {} + + try: + cmd.send(response) + except StopIteration as exc: + return exc.value + + raise AssertionError("BiDi command generator did not finish") + + def commands_named(self, method): + return [c for c in self.commands if c["method"] == method] + + +def dispatch_event_to(conn, event, bidi_event): + """Invoke every subscribed callback for a BiDi event, as the WebSocket would.""" + callbacks = [callback for _, event_class, callback in conn.added_callbacks if event_class == bidi_event] + assert callbacks, f"no event callback registered for {bidi_event}" + for callback in callbacks: + callback(event) + + +def make_before_request_event(url="https://example.com/api/data", method="GET"): + return { + "context": "ctx-1", + "isBlocked": False, + "redirectCount": 0, + "request": { + "request": "req-1", + "url": url, + "method": method, + "headers": [{"name": "accept", "value": {"type": "string", "value": "*/*"}}], + "cookies": [], + "destination": "document", + }, + "timestamp": 1, + } + + +def make_response_completed_event(url="https://example.com/api/data", status=200): + return { + "context": "ctx-1", + "isBlocked": False, + "redirectCount": 0, + "request": { + "request": "req-1", + "url": url, + "method": "GET", + "headers": [], + "cookies": [], + }, + "response": { + "url": url, + "status": status, + "statusText": "OK", + "headers": [{"name": "content-type", "value": {"type": "string", "value": "application/json"}}], + "mimeType": "application/json", + }, + "timestamp": 1, + } + + +def make_console_log_entry(text="hello", level="info"): + return { + "type": "console", + "method": "log", + "level": level, + "text": text, + "args": [{"type": "string", "value": text}], + "source": {"realm": "realm-1", "context": "ctx-1"}, + "timestamp": 1, + } + + +def make_user_prompt_opened_event(message="Sure?", prompt_type="confirm"): + return { + "context": "ctx-1", + "handler": "dismiss", + "message": message, + "type": prompt_type, + "userContext": "default", + } + + +def test_expect_request_captures_matching_request(): + conn = FakeConnection() + network = Network(conn) + + with network.expect_request() as request_info: + dispatch_event_to(conn, make_before_request_event(), "network.beforeRequestSent") + + request = request_info.value + assert isinstance(request, Request) + assert request.url == "https://example.com/api/data" + assert request.method == "GET" + + +def test_expect_request_url_glob_skips_non_matching(): + conn = FakeConnection() + network = Network(conn) + + with network.expect_request("**/api/**") as request_info: + dispatch_event_to( + conn, make_before_request_event(url="https://example.com/styles.css"), "network.beforeRequestSent" + ) + dispatch_event_to( + conn, make_before_request_event(url="https://example.com/api/users"), "network.beforeRequestSent" + ) + + assert request_info.value.url == "https://example.com/api/users" + + +def test_expect_request_predicate_filters(): + conn = FakeConnection() + network = Network(conn) + + with network.expect_request(lambda request: request.method == "POST") as request_info: + dispatch_event_to(conn, make_before_request_event(method="GET"), "network.beforeRequestSent") + dispatch_event_to(conn, make_before_request_event(method="POST"), "network.beforeRequestSent") + + assert request_info.value.method == "POST" + + +def test_expect_response_captures_completed_response(): + conn = FakeConnection() + network = Network(conn) + + with network.expect_response("**/api/**") as response_info: + dispatch_event_to(conn, make_response_completed_event(status=201), "network.responseCompleted") + + response = response_info.value + assert isinstance(response, Response) + assert response.status == 201 + assert response.mime_type == "application/json" + + +def test_expect_request_subscribes_before_action(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request() + assert any(event_class == "network.beforeRequestSent" for _, event_class, _ in conn.added_callbacks) + assert conn.commands_named("session.subscribe") + subscription.cancel() + + +def test_wait_times_out_without_event(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request(timeout=0.05) + with pytest.raises(TimeoutException): + subscription.wait() + subscription.cancel() + + +def test_with_block_raises_timeout_on_exit(): + conn = FakeConnection() + network = Network(conn) + + with pytest.raises(TimeoutException): + with network.expect_request(timeout=0.05): + pass + + +def test_capture_unsubscribes_handler(): + conn = FakeConnection() + network = Network(conn) + + with network.expect_request() as request_info: + dispatch_event_to(conn, make_before_request_event(), "network.beforeRequestSent") + + assert request_info.value is not None + assert any(event_class == "network.beforeRequestSent" for _, event_class in conn.removed_callbacks) + assert conn.commands_named("session.unsubscribe") + + +def test_cancel_unsubscribes_without_waiting(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request() + subscription.cancel() + + assert any(event_class == "network.beforeRequestSent" for _, event_class in conn.removed_callbacks) + assert conn.commands_named("session.unsubscribe") + + +def test_value_is_cached_after_capture(): + conn = FakeConnection() + network = Network(conn) + + with network.expect_request() as request_info: + dispatch_event_to(conn, make_before_request_event(), "network.beforeRequestSent") + + assert request_info.value is request_info.value + + +def test_exception_in_with_block_detaches_and_propagates(): + conn = FakeConnection() + network = Network(conn) + + with pytest.raises(RuntimeError, match="boom"): + with network.expect_request(): + raise RuntimeError("boom") + + assert any(event_class == "network.beforeRequestSent" for _, event_class in conn.removed_callbacks) + + +def test_events_after_capture_are_dropped(): + conn = FakeConnection() + network = Network(conn) + + with network.expect_request() as request_info: + dispatch_event_to(conn, make_before_request_event(url="https://example.com/first"), "network.beforeRequestSent") + + # The fake connection does not actually remove callbacks, so a late event + # still reaches the subscription; the detached guard must drop it. + dispatch_event_to(conn, make_before_request_event(url="https://example.com/late"), "network.beforeRequestSent") + assert request_info.value.url == "https://example.com/first" + + +def test_wait_blocks_until_event_from_other_thread(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request() + timer = threading.Timer( + 0.05, dispatch_event_to, args=(conn, make_before_request_event(), "network.beforeRequestSent") + ) + timer.start() + try: + assert subscription.wait(timeout=2).url == "https://example.com/api/data" + finally: + timer.join() + + +def test_expect_console_message_captures_message(): + conn = FakeConnection() + script = Script(conn) + + with script.expect_console_message() as message_info: + dispatch_event_to(conn, make_console_log_entry(text="ready"), "log.entryAdded") + + message = message_info.value + assert isinstance(message, ConsoleMessage) + assert message.text == "ready" + + +def test_expect_console_message_predicate_filters(): + conn = FakeConnection() + script = Script(conn) + + with script.expect_console_message(lambda message: message.level == "error") as message_info: + dispatch_event_to(conn, make_console_log_entry(text="noise", level="info"), "log.entryAdded") + dispatch_event_to(conn, make_console_log_entry(text="boom", level="error"), "log.entryAdded") + + assert message_info.value.text == "boom" + + +def test_expect_user_prompt_captures_typed_params(): + conn = FakeConnection() + browsing_context = BrowsingContext(conn) + + with browsing_context.expect_user_prompt() as prompt_info: + dispatch_event_to(conn, make_user_prompt_opened_event(), "browsingContext.userPromptOpened") + + prompt = prompt_info.value + assert prompt.message == "Sure?" + assert prompt.type == "confirm" + assert prompt.context == "ctx-1" + + +def test_expect_download_correlates_begin_and_end(tmp_path): + downloaded = tmp_path / "report.csv" + downloaded.write_text("a,b\n1,2\n") + + conn = FakeConnection() + browsing_context = BrowsingContext(conn) + + with browsing_context.expect_download() as download_info: + dispatch_event_to( + conn, + { + "context": "ctx-1", + "navigation": "nav-1", + "suggestedFilename": "report.csv", + "url": "https://example.com/report", + }, + "browsingContext.downloadWillBegin", + ) + dispatch_event_to( + conn, + { + "context": "ctx-1", + "navigation": "nav-1", + "status": "complete", + "url": "https://example.com/report", + "filepath": str(downloaded), + }, + "browsingContext.downloadEnd", + ) + + download = download_info.value + assert isinstance(download, Download) + assert download.suggested_filename == "report.csv" + assert download.failure() is None + assert download.path() == downloaded + + saved = download.save_as(tmp_path / "saved.csv") + assert (tmp_path / "saved.csv").read_text() == "a,b\n1,2\n" + assert str(saved).endswith("saved.csv") + + +def test_expect_download_removes_both_handlers_after_capture(): + conn = FakeConnection() + browsing_context = BrowsingContext(conn) + + with browsing_context.expect_download() as download_info: + dispatch_event_to( + conn, + {"context": "ctx-1", "navigation": "nav-1", "status": "canceled", "url": "https://example.com/report"}, + "browsingContext.downloadEnd", + ) + + download = download_info.value + assert download.failure() == "canceled" + # No downloadWillBegin was dispatched, so there is no suggested filename. + assert download.suggested_filename is None + removed_events = [event_class for _, event_class in conn.removed_callbacks] + assert "browsingContext.downloadEnd" in removed_events + assert "browsingContext.downloadWillBegin" in removed_events + + +def test_download_save_as_without_file_raises(): + download = Download(status="canceled") + assert download.failure() == "canceled" + assert download.path() is None + with pytest.raises(ValueError, match="no file on disk"): + download.save_as("/tmp/nope") + + +def test_wait_with_zero_timeout_returns_queued_event(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request() + dispatch_event_to(conn, make_before_request_event(), "network.beforeRequestSent") + + assert subscription.wait(timeout=0).url == "https://example.com/api/data" + + +def test_wait_with_zero_timeout_raises_immediately_without_event(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request() + with pytest.raises(TimeoutException): + subscription.wait(timeout=0) + subscription.cancel() + + +def test_predicate_exception_drops_event_and_keeps_listening(): + conn = FakeConnection() + network = Network(conn) + + def explosive_predicate(request): + if request.url.endswith("boom"): + raise RuntimeError("predicate blew up") + return True + + with network.expect_request(explosive_predicate) as request_info: + dispatch_event_to(conn, make_before_request_event(url="https://example.com/boom"), "network.beforeRequestSent") + dispatch_event_to(conn, make_before_request_event(url="https://example.com/fine"), "network.beforeRequestSent") + + assert request_info.value.url == "https://example.com/fine" + + +def test_cancel_is_idempotent(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request() + subscription.cancel() + subscription.cancel() + + removed = [event_class for _, event_class in conn.removed_callbacks if event_class == "network.beforeRequestSent"] + assert len(removed) == 1 + + +def test_concurrent_cancel_detaches_exactly_once(): + conn = FakeConnection() + network = Network(conn) + + subscription = network.expect_request() + barrier = threading.Barrier(2) + + def cancel_after_barrier(): + barrier.wait() + subscription.cancel() + + threads = [threading.Thread(target=cancel_after_barrier) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + removed = [event_class for _, event_class in conn.removed_callbacks if event_class == "network.beforeRequestSent"] + assert len(removed) == 1