diff --git a/py/selenium/webdriver/common/bidi/bidi.py b/py/selenium/webdriver/common/bidi/bidi.py new file mode 100644 index 0000000000000..65fdff582d936 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/bidi.py @@ -0,0 +1,64 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from dataclasses import dataclass +from dataclasses import fields +from dataclasses import is_dataclass + + +@dataclass +class BidiObject: + def to_json(self): + json = {} + for field in fields(self): + value = getattr(self, field.name) + if value is None: + continue + if is_dataclass(value): + value = value.to_json() + elif isinstance(value, list): + value = [v.to_json() if hasattr(v, "to_json") else v for v in value] + elif isinstance(value, dict): + value = {k: v.to_json() if hasattr(v, "to_json") else v for k, v in value.items()} + key = field.name[1:] if field.name.startswith("_") else field.name + json[key] = value + return json + + @classmethod + def from_json(cls, json): + return cls(**json) + + +@dataclass +class BidiEvent(BidiObject): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def from_json(cls, json): + params = cls.param_class.from_json(json) + return cls(params) + + +@dataclass +class BidiCommand(BidiObject): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def cmd(self): + result = yield self.to_json() + return result diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py new file mode 100644 index 0000000000000..8e93cdd9f52fa --- /dev/null +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -0,0 +1,41 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing +from dataclasses import dataclass + +from .bidi import BidiCommand +from .bidi import BidiObject + +BrowsingContext = str + +Navigation = str + +ReadinessState = typing.Literal["none", "interactive", "complete"] + + +@dataclass +class NavigateParameters(BidiObject): + context: BrowsingContext + url: str + wait: typing.Optional[ReadinessState] = None + + +@dataclass +class Navigate(BidiCommand): + params: NavigateParameters + method: typing.Literal["browsingContext.navigate"] = "browsingContext.navigate" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py new file mode 100644 index 0000000000000..1b398aa2d355b --- /dev/null +++ b/py/selenium/webdriver/common/bidi/network.py @@ -0,0 +1,227 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing +from dataclasses import dataclass + +from selenium.webdriver.common.bidi.cdp import import_devtools + +from . import browsing_context +from . import script +from .bidi import BidiCommand +from .bidi import BidiEvent +from .bidi import BidiObject + +devtools = import_devtools("") +event_class = devtools.util.event_class + +InterceptPhase = typing.Literal["beforeRequestSent", "responseStarted", "authRequired"] + + +@dataclass +class UrlPatternPattern(BidiObject): + _type: typing.Literal["pattern"] = "pattern" + protocol: typing.Optional[str] = None + hostname: typing.Optional[str] = None + port: typing.Optional[str] = None + pathname: typing.Optional[str] = None + search: typing.Optional[str] = None + + +@dataclass +class UrlPatternString(BidiObject): + pattern: str + _type: typing.Literal["string"] = "string" + + +UrlPattern = typing.Union[UrlPatternPattern, UrlPatternString] + + +@dataclass +class AddInterceptParameters(BidiObject): + phases: typing.List[InterceptPhase] + contexts: typing.Optional[typing.List[browsing_context.BrowsingContext]] = None + urlPatterns: typing.Optional[typing.List[UrlPattern]] = None + + +@dataclass +class AddIntercept(BidiCommand): + params: AddInterceptParameters + method: typing.Literal["network.addIntercept"] = "network.addIntercept" + + +Request = str + + +@dataclass +class StringValue(BidiObject): + value: str + _type: typing.Literal["string"] = "string" + + +@dataclass +class Base64Value(BidiObject): + value: str + _type: typing.Literal["base64"] = "base64" + + +BytesValue = typing.Union[StringValue, Base64Value] + + +@dataclass +class Header(BidiObject): + name: str + value: BytesValue + + +SameSite = typing.Literal["strict", "lax", "none"] + + +@dataclass +class Cookie(BidiObject): + name: str + value: BytesValue + domain: str + path: str + size: int + httpOnly: bool + secure: bool + sameSite: SameSite + expiry: typing.Optional[int] = None + + +@dataclass +class FetchTimingInfo(BidiObject): + timeOrigin: float + requestTime: float + redirectStart: float + redirectEnd: float + fetchStart: float + dnsStart: float + dnsEnd: float + connectStart: float + connectEnd: float + tlsStart: float + requestStart: float + responseStart: float + responseEnd: float + + +@dataclass +class RequestData(BidiObject): + request: Request + url: str + method: str + headersSize: int + timings: FetchTimingInfo + headers: typing.Optional[typing.List[Header]] = None + cookies: typing.Optional[typing.List[Cookie]] = None + bodySize: typing.Optional[int] = None + + +Intercept = str + + +@dataclass +class Initiator(BidiObject): + _type: typing.Literal["parser", "script", "preflight", "other"] + columnNumber: typing.Optional[int] = None + lineNumber: typing.Optional[int] = None + stackTrace: typing.Optional[script.StackTrace] = None + request: typing.Optional[Request] = None + + +@dataclass +class BeforeRequestSentParameters(BidiObject): + isBlocked: bool + redirectCount: int + request: RequestData + timestamp: int + initiator: Initiator + context: typing.Optional[browsing_context.BrowsingContext] = None + navigation: typing.Optional[browsing_context.Navigation] = None + intercepts: typing.Optional[typing.List[Intercept]] = None + + +@dataclass +@event_class("network.beforeRequestSent") +class BeforeRequestSent(BidiEvent): + params: BeforeRequestSentParameters + method: typing.Literal["network.beforeRequestSent"] = "network.beforeRequestSent" + + param_class = BeforeRequestSentParameters + + +@dataclass +class CookieHeader(BidiObject): + name: str + value: BytesValue + + +@dataclass +class ContinueRequestParameters(BidiObject): + request: Request + body: typing.Optional[BytesValue] = None + cookies: typing.Optional[typing.List[CookieHeader]] = None + headers: typing.Optional[typing.List[Header]] = None + method: typing.Optional[str] = None + url: typing.Optional[str] = None + + +@dataclass +class ContinueRequest(BidiCommand): + params: ContinueRequestParameters + method: typing.Literal["network.continueRequest"] = "network.continueRequest" + + +@dataclass +class RemoveInterceptParameters(BidiObject): + intercept: Intercept + + +@dataclass +class RemoveIntercept(BidiCommand): + params: RemoveInterceptParameters + method: typing.Literal["network.removeIntercept"] = "network.removeIntercept" + + +@dataclass +class SetCacheBehaviorParameters(BidiObject): + cacheBehavior: typing.Literal["default", "bypass"] + contexts: typing.Optional[typing.List[browsing_context.BrowsingContext]] = None + + +@dataclass +class SetCacheBehavior(BidiCommand): + params: SetCacheBehaviorParameters + method: typing.Literal["network.setCacheBehavior"] = "network.setCacheBehavior" + + +class Network: + def __init__(self, conn): + self.conn = conn + + async def add_intercept(self, params: AddInterceptParameters): + result = await self.conn.execute(AddIntercept(params).cmd()) + return result + + async def continue_request(self, params: ContinueRequestParameters): + result = await self.conn.execute(ContinueRequest(params).cmd()) + return result + + async def remove_intercept(self, params: RemoveInterceptParameters): + await self.conn.execute(RemoveIntercept(params).cmd()) diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 6819a5cf63436..0263634e88fad 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -18,6 +18,7 @@ import typing from dataclasses import dataclass +from .bidi import BidiObject from .session import session_subscribe from .session import session_unsubscribe @@ -108,3 +109,16 @@ def from_json(cls, json): stacktrace=json["stackTrace"], type_=json["type"], ) + + +@dataclass +class StackFrame(BidiObject): + columnNumber: int + functionName: str + lineNumber: int + url: str + + +@dataclass +class StackTrace(BidiObject): + callFrames: typing.Optional[typing.List[StackFrame]] = None diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index dbe5d26644a87..eab41486d52f2 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -26,7 +26,7 @@ def session_subscribe(*events, browsing_contexts=None): if browsing_contexts is None: browsing_contexts = [] if browsing_contexts: - cmd_dict["params"]["browsingContexts"] = browsing_contexts + cmd_dict["params"]["browsing_contexts"] = browsing_contexts _ = yield cmd_dict return None @@ -41,6 +41,6 @@ def session_unsubscribe(*events, browsing_contexts=None): if browsing_contexts is None: browsing_contexts = [] if browsing_contexts: - cmd_dict["params"]["browsingContexts"] = browsing_contexts + cmd_dict["params"]["browsing_contexts"] = browsing_contexts _ = yield cmd_dict return None diff --git a/py/selenium/webdriver/remote/network.py b/py/selenium/webdriver/remote/network.py new file mode 100644 index 0000000000000..f7c746c374f5a --- /dev/null +++ b/py/selenium/webdriver/remote/network.py @@ -0,0 +1,125 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from collections import defaultdict +from contextlib import asynccontextmanager + +import trio + +from selenium.webdriver.common.bidi import network +from selenium.webdriver.common.bidi.browsing_context import Navigate +from selenium.webdriver.common.bidi.browsing_context import NavigateParameters +from selenium.webdriver.common.bidi.cdp import open_cdp +from selenium.webdriver.common.bidi.network import AddInterceptParameters +from selenium.webdriver.common.bidi.network import BeforeRequestSent +from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters +from selenium.webdriver.common.bidi.network import ContinueRequestParameters +from selenium.webdriver.common.bidi.session import session_subscribe +from selenium.webdriver.common.bidi.session import session_unsubscribe + + +class Network: + def __init__(self, driver): + self.driver = driver + self.listeners = {} + self.intercepts = defaultdict(lambda: {"event_name": None, "handlers": []}) + self.bidi_network = None + self.conn = None + self.nursery = None + + self.remove_request_handler = self.remove_intercept + self.clear_request_handlers = self.clear_intercepts + + @asynccontextmanager + async def set_context(self): + ws_url = self.driver.caps.get("webSocketUrl") + async with open_cdp(ws_url) as conn: + self.conn = conn + self.bidi_network = network.Network(conn) + async with trio.open_nursery() as nursery: + self.nursery = nursery + yield + + async def get(self, url, wait="complete"): + params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait=wait) + await self.conn.execute(Navigate(params).cmd()) + + async def add_listener(self, event, callback): + event_name = event.event_class + if event_name in self.listeners: + return + self.listeners[event_name] = self.conn.listen(event) + try: + async for event in self.listeners[event_name]: + request_data = event.params + if request_data.isBlocked: + await callback(request_data) + except trio.ClosedResourceError: + pass + + async def add_handler(self, event, handler, urlPatterns=None): + event_name = event.event_class + phase_name = event_name.split(".")[-1] + + await self.conn.execute(session_subscribe(event_name)) + + params = AddInterceptParameters(phases=[phase_name], urlPatterns=urlPatterns) + result = await self.bidi_network.add_intercept(params) + intercept = result["intercept"] + + self.intercepts[intercept]["event_name"] = event_name + self.intercepts[intercept]["handlers"].append(handler) + self.nursery.start_soon(self.add_listener, event, self.handle_events) + return intercept + + async def add_request_handler(self, handler, urlPatterns=None): + intercept = await self.add_handler(BeforeRequestSent, handler, urlPatterns) + return intercept + + async def handle_events(self, event_params): + if isinstance(event_params, BeforeRequestSentParameters): + json = self.handle_requests(event_params) + params = ContinueRequestParameters(**json) + await self.bidi_network.continue_request(params) + + def handle_requests(self, params): + request = params.request + for intercept in params.intercepts: + for handler in self.intercepts[intercept]["handlers"]: + request = handler(request) + return request + + async def remove_listener(self, event_name): + listener = self.listeners.pop(event_name) + listener.close() + + async def remove_intercept(self, intercept): + await self.bidi_network.remove_intercept( + params=network.RemoveInterceptParameters(intercept), + ) + event_name = self.intercepts.pop(intercept)["event_name"] + remaining = [i for i in self.intercepts.values() if i["event_name"] == event_name] + if len(remaining) == 0: + await self.remove_listener(event_name) + await self.conn.execute(session_unsubscribe(event_name)) + + async def clear_intercepts(self): + for intercept in self.intercepts: + await self.remove_intercept(intercept) + + async def disable_cache(self): + # Bidi 'network.setCacheBehavior' is not implemented in v130 + self.driver.execute_cdp_cmd("Network.setCacheDisabled", {"cacheDisabled": True}) diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index bae7f4e8d28c1..01d3e47aa256a 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -61,6 +61,7 @@ from .file_detector import LocalFileDetector from .locator_converter import LocatorConverter from .mobile import Mobile +from .network import Network from .remote_connection import RemoteConnection from .script_key import ScriptKey from .shadowroot import ShadowRoot @@ -239,6 +240,7 @@ def __init__( self._websocket_connection = None self._script = None + self._network = None def __repr__(self): return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' @@ -1090,6 +1092,13 @@ def script(self): return self._script + @property + def network(self): + if not self._network: + self._network = Network(self) + + return self._network + def _start_bidi(self): if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") diff --git a/py/test/selenium/webdriver/common/bidi_network_tests.py b/py/test/selenium/webdriver/common/bidi_network_tests.py new file mode 100644 index 0000000000000..158d05990a796 --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_network_tests.py @@ -0,0 +1,57 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from selenium.webdriver.common.bidi.network import UrlPatternString + + +@pytest.mark.xfail_firefox +@pytest.mark.xfail_safari +@pytest.mark.xfail_edge +async def test_request_handler(driver, pages): + + url1 = pages.url("simpleTest.html") + url2 = pages.url("clicks.html") + url3 = pages.url("formPage.html") + + pattern1 = [UrlPatternString(url1)] + pattern2 = [UrlPatternString(url2)] + + def request_handler(params): + request = params["request"] + json = {"request": request, "url": url3} + return json + + async with driver.network.set_context(): + # Multiple intercepts + intercept1 = await driver.network.add_request_handler(request_handler, pattern1) + intercept2 = await driver.network.add_request_handler(request_handler, pattern2) + await driver.network.get(url1) + assert driver.title == "We Leave From Here" + await driver.network.get(url2) + assert driver.title == "We Leave From Here" + + # Removal of a single intercept + await driver.network.remove_intercept(intercept2) + await driver.network.get(url2) + assert driver.title == "clicks" + await driver.network.get(url1) + assert driver.title == "We Leave From Here" + + await driver.network.remove_intercept(intercept1) + await driver.network.get(url1) + assert driver.title == "Hello WebDriver"