diff --git a/README.md b/README.md index dd89cc2a2..e99460db3 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,9 @@ Playwright is a Python library to automate [Chromium](https://www.chromium.org/H | | Linux | macOS | Windows | | :--- | :---: | :---: | :---: | -| Chromium 130.0.6723.19 | ✅ | ✅ | ✅ | +| Chromium 130.0.6723.31 | ✅ | ✅ | ✅ | | WebKit 18.0 | ✅ | ✅ | ✅ | -| Firefox 130.0 | ✅ | ✅ | ✅ | +| Firefox 131.0 | ✅ | ✅ | ✅ | ## Documentation diff --git a/playwright/_impl/_browser_context.py b/playwright/_impl/_browser_context.py index b713162b7..ff8413a0c 100644 --- a/playwright/_impl/_browser_context.py +++ b/playwright/_impl/_browser_context.py @@ -260,7 +260,7 @@ async def _on_route(self, route: Route) -> None: try: # If the page is closed or unrouteAll() was called without waiting and interception disabled, # the method will throw an error - silence it. - await route._internal_continue(is_internal=True) + await route._inner_continue(True) except Exception: pass diff --git a/playwright/_impl/_connection.py b/playwright/_impl/_connection.py index eb4d182d3..0255dae52 100644 --- a/playwright/_impl/_connection.py +++ b/playwright/_impl/_connection.py @@ -132,6 +132,7 @@ def __init__( self._channel: Channel = Channel(self._connection, self) self._initializer = initializer self._was_collected = False + self._is_internal_type = False self._connection._objects[guid] = self if self._parent: @@ -156,6 +157,9 @@ def _adopt(self, child: "ChannelOwner") -> None: self._objects[child._guid] = child child._parent = self + def mark_as_internal_type(self) -> None: + self._is_internal_type = True + def _set_event_to_subscription_mapping(self, mapping: Dict[str, str]) -> None: self._event_to_subscription_mapping = mapping @@ -353,7 +357,7 @@ def _send_message_to_server( "params": self._replace_channels_with_guids(params), "metadata": metadata, } - if self._tracing_count > 0 and frames and object._guid != "localUtils": + if self._tracing_count > 0 and frames and not object._is_internal_type: self.local_utils.add_stack_to_tracing_no_reply(id, frames) self._transport.send(message) diff --git a/playwright/_impl/_fetch.py b/playwright/_impl/_fetch.py index a4de751bd..93144ac55 100644 --- a/playwright/_impl/_fetch.py +++ b/playwright/_impl/_fetch.py @@ -18,7 +18,6 @@ import typing from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast -from urllib.parse import parse_qs import playwright._impl._network as network from playwright._impl._api_structures import ( @@ -405,7 +404,8 @@ async def _inner_fetch( "fetch", { "url": url, - "params": params_to_protocol(params), + "params": object_to_array(params) if isinstance(params, dict) else None, + "encodedParams": params if isinstance(params, str) else None, "method": method, "headers": serialized_headers, "postData": post_data, @@ -430,23 +430,6 @@ async def storage_state( return result -def params_to_protocol(params: Optional[ParamsType]) -> Optional[List[NameValue]]: - if not params: - return None - if isinstance(params, dict): - return object_to_array(params) - if params.startswith("?"): - params = params[1:] - parsed = parse_qs(params) - if not parsed: - return None - out = [] - for name, values in parsed.items(): - for value in values: - out.append(NameValue(name=name, value=value)) - return out - - def file_payload_to_json(payload: FilePayload) -> ServerFilePayload: return ServerFilePayload( name=payload["name"], diff --git a/playwright/_impl/_local_utils.py b/playwright/_impl/_local_utils.py index 7172ee58a..26a3417c4 100644 --- a/playwright/_impl/_local_utils.py +++ b/playwright/_impl/_local_utils.py @@ -25,6 +25,7 @@ def __init__( self, parent: ChannelOwner, type: str, guid: str, initializer: Dict ) -> None: super().__init__(parent, type, guid, initializer) + self.mark_as_internal_type() self.devices = { device["name"]: parse_device_descriptor(device["descriptor"]) for device in initializer["deviceDescriptors"] diff --git a/playwright/_impl/_network.py b/playwright/_impl/_network.py index ca7123c82..fff605808 100644 --- a/playwright/_impl/_network.py +++ b/playwright/_impl/_network.py @@ -47,7 +47,6 @@ ) from playwright._impl._connection import ( ChannelOwner, - Connection, from_channel, from_nullable_channel, ) @@ -318,6 +317,7 @@ def __init__( self, parent: ChannelOwner, type: str, guid: str, initializer: Dict ) -> None: super().__init__(parent, type, guid, initializer) + self.mark_as_internal_type() self._handling_future: Optional[asyncio.Future["bool"]] = None self._context: "BrowserContext" = cast("BrowserContext", None) self._did_throw = False @@ -350,7 +350,6 @@ async def abort(self, errorCode: str = None) -> None: "abort", { "errorCode": errorCode, - "requestUrl": self.request._initializer["url"], }, ) ) @@ -433,7 +432,6 @@ async def _inner_fulfill( if length and "content-length" not in headers: headers["content-length"] = str(length) params["headers"] = serialize_headers(headers) - params["requestUrl"] = self.request._initializer["url"] await self._race_with_page_close(self._channel.send("fulfill", params)) @@ -492,43 +490,28 @@ async def continue_( async def _inner() -> None: self.request._apply_fallback_overrides(overrides) - await self._internal_continue() + await self._inner_continue(False) return await self._handle_route(_inner) - def _internal_continue( - self, is_internal: bool = False - ) -> Coroutine[Any, Any, None]: - async def continue_route() -> None: - try: - params: Dict[str, Any] = {} - params["url"] = self.request._fallback_overrides.url - params["method"] = self.request._fallback_overrides.method - params["headers"] = self.request._fallback_overrides.headers - if self.request._fallback_overrides.post_data_buffer is not None: - params["postData"] = base64.b64encode( - self.request._fallback_overrides.post_data_buffer - ).decode() - params = locals_to_params(params) - - if "headers" in params: - params["headers"] = serialize_headers(params["headers"]) - params["requestUrl"] = self.request._initializer["url"] - params["isFallback"] = is_internal - await self._connection.wrap_api_call( - lambda: self._race_with_page_close( - self._channel.send( - "continue", - params, - ) - ), - is_internal, - ) - except Exception as e: - if not is_internal: - raise e - - return continue_route() + async def _inner_continue(self, is_fallback: bool = False) -> None: + options = self.request._fallback_overrides + await self._race_with_page_close( + self._channel.send( + "continue", + { + "url": options.url, + "method": options.method, + "headers": serialize_headers(options.headers) + if options.headers + else None, + "postData": base64.b64encode(options.post_data_buffer).decode() + if options.post_data_buffer is not None + else None, + "isFallback": is_fallback, + }, + ) + ) async def _redirected_navigation_request(self, url: str) -> None: await self._handle_route( @@ -586,7 +569,7 @@ def close(self, code: int = None, reason: str = None) -> None: }, ) ) - except: + except Exception: pass def send(self, message: Union[str, bytes]) -> None: @@ -636,7 +619,7 @@ def _channel_message_from_page(self, event: Dict) -> None: elif self._connected: try: asyncio.create_task(self._channel.send("sendToServer", event)) - except: + except Exception: pass def _channel_message_from_server(self, event: Dict) -> None: @@ -649,7 +632,7 @@ def _channel_message_from_server(self, event: Dict) -> None: else: try: asyncio.create_task(self._channel.send("sendToPage", event)) - except: + except Exception: pass def _channel_close_page(self, event: Dict) -> None: @@ -658,7 +641,7 @@ def _channel_close_page(self, event: Dict) -> None: else: try: asyncio.create_task(self._channel.send("closeServer", event)) - except: + except Exception: pass def _channel_close_server(self, event: Dict) -> None: @@ -667,7 +650,7 @@ def _channel_close_server(self, event: Dict) -> None: else: try: asyncio.create_task(self._channel.send("closePage", event)) - except: + except Exception: pass @property @@ -679,7 +662,7 @@ async def close(self, code: int = None, reason: str = None) -> None: await self._channel.send( "closePage", {"code": code, "reason": reason, "wasClean": True} ) - except: + except Exception: pass def connect_to_server(self) -> "WebSocketRoute": @@ -697,7 +680,7 @@ def send(self, message: Union[str, bytes]) -> None: "sendToPage", {"message": message, "isBase64": False} ) ) - except: + except Exception: pass else: try: @@ -710,7 +693,7 @@ def send(self, message: Union[str, bytes]) -> None: }, ) ) - except: + except Exception: pass def on_message(self, handler: Callable[[Union[str, bytes]], Any]) -> None: @@ -758,9 +741,9 @@ def matches(self, ws_url: str) -> bool: return self.matcher.matches(ws_url) async def handle(self, websocket_route: "WebSocketRoute") -> None: - maybe_future = self.handler(websocket_route) - if maybe_future: - breakpoint() + coro_or_future = self.handler(websocket_route) + if asyncio.iscoroutine(coro_or_future): + await coro_or_future await websocket_route._after_handle() diff --git a/playwright/_impl/_page.py b/playwright/_impl/_page.py index 0dad1e19a..15195b28b 100644 --- a/playwright/_impl/_page.py +++ b/playwright/_impl/_page.py @@ -325,7 +325,7 @@ async def _on_web_socket_route(self, web_socket_route: WebSocketRoute) -> None: if route_handler: await route_handler.handle(web_socket_route) else: - web_socket_route.connect_to_server() + await self._browser_context._on_web_socket_route(web_socket_route) def _on_binding(self, binding_call: "BindingCall") -> None: func = self._bindings.get(binding_call._initializer["name"]) diff --git a/playwright/_impl/_tracing.py b/playwright/_impl/_tracing.py index b2d4b5df9..f4c6b31b1 100644 --- a/playwright/_impl/_tracing.py +++ b/playwright/_impl/_tracing.py @@ -41,13 +41,10 @@ async def start( params = locals_to_params(locals()) self._include_sources = bool(sources) - async def _inner_start() -> str: - await self._channel.send("tracingStart", params) - return await self._channel.send( - "tracingStartChunk", {"title": title, "name": name} - ) - - trace_name = await self._connection.wrap_api_call(_inner_start, True) + await self._channel.send("tracingStart", params) + trace_name = await self._channel.send( + "tracingStartChunk", {"title": title, "name": name} + ) await self._start_collecting_stacks(trace_name) async def start_chunk(self, title: str = None, name: str = None) -> None: @@ -64,14 +61,11 @@ async def _start_collecting_stacks(self, trace_name: str) -> None: ) async def stop_chunk(self, path: Union[pathlib.Path, str] = None) -> None: - await self._connection.wrap_api_call(lambda: self._do_stop_chunk(path), True) + await self._do_stop_chunk(path) async def stop(self, path: Union[pathlib.Path, str] = None) -> None: - async def _inner() -> None: - await self._do_stop_chunk(path) - await self._channel.send("tracingStop") - - await self._connection.wrap_api_call(_inner, True) + await self._do_stop_chunk(path) + await self._channel.send("tracingStop") async def _do_stop_chunk(self, file_path: Union[pathlib.Path, str] = None) -> None: self._reset_stack_counter() diff --git a/setup.py b/setup.py index 5492f879e..a7da81984 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ InWheel = None from wheel.bdist_wheel import bdist_wheel as BDistWheelCommand -driver_version = "1.48.0-alpha-1727434891000" +driver_version = "1.48.0-beta-1728034490000" def extractall(zip: zipfile.ZipFile, path: str) -> None: diff --git a/tests/async/test_page_request_gc.py b/tests/async/test_page_request_gc.py index d9dc50339..7d0cce9ef 100644 --- a/tests/async/test_page_request_gc.py +++ b/tests/async/test_page_request_gc.py @@ -25,8 +25,10 @@ async def test_should_work(page: Page, server: Server) -> None: ) await page.request_gc() assert await page.evaluate("() => globalThis.weakRef.deref()") == {"hello": "world"} + await page.request_gc() assert await page.evaluate("() => globalThis.weakRef.deref()") == {"hello": "world"} + await page.evaluate("() => globalThis.objectToDestroy = null") await page.request_gc() assert await page.evaluate("() => globalThis.weakRef.deref()") is None diff --git a/tests/async/test_route_web_socket.py b/tests/async/test_route_web_socket.py index 6de9f6441..f2870612a 100644 --- a/tests/async/test_route_web_socket.py +++ b/tests/async/test_route_web_socket.py @@ -14,87 +14,135 @@ import asyncio import re -from typing import Dict - -from playwright.async_api import Page, WebSocketRoute -from tests.server import Server - -# test('should work with ws.close', async ({ page, server }) => { -# const { promise, resolve } = withResolvers(); -# await page.routeWebSocket(/.*/, async ws => { -# ws.connectToServer(); -# resolve(ws); -# }); - -# const wsPromise = server.waitForWebSocket(); -# await setupWS(page, server.PORT, 'blob'); -# const ws = await wsPromise; - -# const route = await promise; -# route.send('hello'); -# await expect.poll(() => page.evaluate(() => window.log)).toEqual([ -# 'open', -# `message: data=hello origin=ws://localhost:${server.PORT} lastEventId=`, -# ]); - -# const closedPromise = new Promise(f => ws.once('close', (code, reason) => f({ code, reason: reason.toString() }))); -# await route.close({ code: 3009, reason: 'oops' }); -# await expect.poll(() => page.evaluate(() => window.log)).toEqual([ -# 'open', -# `message: data=hello origin=ws://localhost:${server.PORT} lastEventId=`, -# 'close code=3009 reason=oops wasClean=true', -# ]); -# expect(await closedPromise).toEqual({ code: 3009, reason: 'oops' }); -# }); - - -async def test_should_work_with_ws_close(page: Page, server: Server): - future = asyncio.Future() - - def _handle_ws(ws: WebSocketRoute): - ws.connect_to_server +from typing import Any, Awaitable, Callable, Literal, Tuple, Union + +from playwright.async_api import Frame, Page, WebSocketRoute +from tests.server import Server, WebSocketProtocol + + +async def assert_equal( + actual_cb: Callable[[], Union[Any, Awaitable[Any]]], expected: Any +) -> None: + __tracebackhide__ = True + start_time = asyncio.get_event_loop().time() + attempts = 0 + while True: + actual = actual_cb() + if asyncio.iscoroutine(actual): + actual = await actual + if actual == expected: + return + attempts += 1 + if asyncio.get_event_loop().time() - start_time > 5: + raise TimeoutError(f"Timed out after 10 seconds. Last actual was: {actual}") + await asyncio.sleep(0.2) + + +async def setup_ws( + target: Union[Page, Frame], + port: int, + protocol: Union[Literal["blob"], Literal["arraybuffer"]], +) -> None: + await target.goto("about:blank") + await target.evaluate( + """({ port, binaryType }) => { + window.log = []; + window.ws = new WebSocket('ws://localhost:' + port + '/ws'); + window.ws.binaryType = binaryType; + window.ws.addEventListener('open', () => window.log.push('open')); + window.ws.addEventListener('close', event => window.log.push(`close code=${event.code} reason=${event.reason} wasClean=${event.wasClean}`)); + window.ws.addEventListener('error', event => window.log.push(`error`)); + window.ws.addEventListener('message', async event => { + let data; + if (typeof event.data === 'string') + data = event.data; + else if (event.data instanceof Blob) + data = 'blob:' + await event.data.text(); + else + data = 'arraybuffer:' + await (new Blob([event.data])).text(); + window.log.push(`message: data=${data} origin=${event.origin} lastEventId=${event.lastEventId}`); + }); + window.wsOpened = new Promise(f => window.ws.addEventListener('open', () => f())); + }""", + {"port": port, "binaryType": protocol}, + ) + + +async def test_should_work_with_ws_close(page: Page, server: Server) -> None: + future: asyncio.Future[WebSocketRoute] = asyncio.Future() + + def _handle_ws(ws: WebSocketRoute) -> None: + ws.connect_to_server() future.set_result(ws) await page.route_web_socket(re.compile(".*"), _handle_ws) ws_task = server.wait_for_web_socket() + await setup_ws(page, server.PORT, "blob") + ws = await ws_task + + route = await future + route.send("hello") + await assert_equal( + lambda: page.evaluate("window.log"), + [ + "open", + f"message: data=hello origin=ws://localhost:{server.PORT} lastEventId=", + ], + ) + + closed_promise: asyncio.Future[Tuple[int, str]] = asyncio.Future() + ws.events.once( + "close", lambda code, reason: closed_promise.set_result((code, reason)) + ) + await route.close(code=3009, reason="oops") + await assert_equal( + lambda: page.evaluate("window.log"), + [ + "open", + f"message: data=hello origin=ws://localhost:{server.PORT} lastEventId=", + "close code=3009 reason=oops wasClean=true", + ], + ) + assert await closed_promise == (3009, "oops") + + +async def test_should_pattern_match(page: Page, server: Server) -> None: + await page.route_web_socket( + re.compile(r".*/ws$"), lambda ws: ws.connect_to_server() + ) + await page.route_web_socket( + "**/mock-ws", lambda ws: ws.on_message(lambda message: ws.send("mock-response")) + ) + ws_task = server.wait_for_web_socket() + await page.goto("about:blank") + await page.evaluate( + """async ({ port }) => { + window.log = []; + window.ws1 = new WebSocket('ws://localhost:' + port + '/ws'); + window.ws1.addEventListener('message', event => window.log.push(`ws1:${event.data}`)); + window.ws2 = new WebSocket('ws://localhost:' + port + '/something/something/mock-ws'); + window.ws2.addEventListener('message', event => window.log.push(`ws2:${event.data}`)); + await Promise.all([ + new Promise(f => window.ws1.addEventListener('open', f)), + new Promise(f => window.ws2.addEventListener('open', f)), + ]); + }""", + {"port": server.PORT}, + ) + + ws = await ws_task + ws.events.on("message", lambda payload, isBinary: ws.sendMessage(b"response")) + + await page.evaluate("window.ws1.send('request')") + await assert_equal(lambda: page.evaluate("window.log"), ["ws1:response"]) + + await page.evaluate("window.ws2.send('request')") + await assert_equal( + lambda: page.evaluate("window.log"), ["ws1:response", "ws2:mock-response"] + ) -# test('should pattern match', async ({ page, server }) => { -# await page.routeWebSocket(/.*\/ws$/, async ws => { -# ws.connectToServer(); -# }); - -# await page.routeWebSocket('**/mock-ws', ws => { -# ws.onMessage(message => { -# ws.send('mock-response'); -# }); -# }); - -# const wsPromise = server.waitForWebSocket(); - -# await page.goto('about:blank'); -# await page.evaluate(async ({ port }) => { -# window.log = []; -# (window as any).ws1 = new WebSocket('ws://localhost:' + port + '/ws'); -# (window as any).ws1.addEventListener('message', event => window.log.push(`ws1:${event.data}`)); -# (window as any).ws2 = new WebSocket('ws://localhost:' + port + '/something/something/mock-ws'); -# (window as any).ws2.addEventListener('message', event => window.log.push(`ws2:${event.data}`)); -# await Promise.all([ -# new Promise(f => (window as any).ws1.addEventListener('open', f)), -# new Promise(f => (window as any).ws2.addEventListener('open', f)), -# ]); -# }, { port: server.PORT }); - -# const ws = await wsPromise; -# ws.on('message', () => ws.send('response')); - -# await page.evaluate(() => (window as any).ws1.send('request')); -# await expect.poll(() => page.evaluate(() => window.log)).toEqual([`ws1:response`]); - -# await page.evaluate(() => (window as any).ws2.send('request')); -# await expect.poll(() => page.evaluate(() => window.log)).toEqual([`ws1:response`, `ws2:mock-response`]); -# }); # test('should work with server', async ({ page, server }) => { # const { promise, resolve } = withResolvers(); @@ -182,43 +230,179 @@ def _handle_ws(ws: WebSocketRoute): # await expect.poll(() => log).toEqual(['message: fake', 'message: modified', 'message: pass-client', 'message: pass-client-2', 'close: code=3009 reason=problem']); # }); -# test('should work without server', async ({ page, server }) => { -# const { promise, resolve } = withResolvers(); -# await page.routeWebSocket(/.*/, ws => { -# ws.onMessage(message => { -# switch (message) { -# case 'to-respond': -# ws.send('response'); -# return; -# } -# }); -# resolve(ws); -# }); -# await setupWS(page, server.PORT, 'blob'); +async def test_should_work_with_server(page: Page, server: Server) -> None: + future: asyncio.Future[WebSocketRoute] = asyncio.Future() -# await page.evaluate(async () => { -# await window.wsOpened; -# window.ws.send('to-respond'); -# window.ws.send('to-block'); -# window.ws.send('to-respond'); -# }); + async def _handle_ws(ws: WebSocketRoute) -> None: + server = ws.connect_to_server() -# await expect.poll(() => page.evaluate(() => window.log)).toEqual([ -# 'open', -# `message: data=response origin=ws://localhost:${server.PORT} lastEventId=`, -# `message: data=response origin=ws://localhost:${server.PORT} lastEventId=`, -# ]); + def _ws_on_message(message: Union[str, bytes]) -> None: + if message == "to-respond": + ws.send("response") + return + if message == "to-block": + return + if message == "to-modify": + server.send("modified") + return + server.send(message) -# const route = await promise; -# route.send('another'); -# await route.close({ code: 3008, reason: 'oops' }); + ws.on_message(_ws_on_message) -# await expect.poll(() => page.evaluate(() => window.log)).toEqual([ -# 'open', -# `message: data=response origin=ws://localhost:${server.PORT} lastEventId=`, -# `message: data=response origin=ws://localhost:${server.PORT} lastEventId=`, -# `message: data=another origin=ws://localhost:${server.PORT} lastEventId=`, -# 'close code=3008 reason=oops wasClean=true', -# ]); -# }); + def _server_on_message(message: Union[str, bytes]) -> None: + if message == "to-block": + return + if message == "to-modify": + ws.send("modified") + return + ws.send(message) + + server.on_message(_server_on_message) + server.send("fake") + future.set_result(ws) + + await page.route_web_socket(re.compile(".*"), _handle_ws) + ws_task = server.wait_for_web_socket() + log = [] + + def _once_web_socket_connection(ws: WebSocketProtocol) -> None: + ws.events.on( + "message", lambda data, is_binary: log.append(f"message: {data.decode()}") + ) + ws.events.on( + "close", + lambda code, reason: log.append(f"close: code={code} reason={reason}"), + ) + + server.once_web_socket_connection(_once_web_socket_connection) + + await setup_ws(page, server.PORT, "blob") + ws = await ws_task + await assert_equal(lambda: log, ["message: fake"]) + + ws.sendMessage(b"to-modify") + ws.sendMessage(b"to-block") + ws.sendMessage(b"pass-server") + await assert_equal( + lambda: page.evaluate("window.log"), + [ + "open", + f"message: data=modified origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=pass-server origin=ws://localhost:{server.PORT} lastEventId=", + ], + ) + + await page.evaluate( + """() => { + window.ws.send('to-respond'); + window.ws.send('to-modify'); + window.ws.send('to-block'); + window.ws.send('pass-client'); + }""" + ) + await assert_equal( + lambda: log, ["message: fake", "message: modified", "message: pass-client"] + ) + await assert_equal( + lambda: page.evaluate("window.log"), + [ + "open", + f"message: data=modified origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=pass-server origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=response origin=ws://localhost:{server.PORT} lastEventId=", + ], + ) + + route = await future + route.send("another") + await assert_equal( + lambda: page.evaluate("window.log"), + [ + "open", + f"message: data=modified origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=pass-server origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=response origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=another origin=ws://localhost:{server.PORT} lastEventId=", + ], + ) + + await page.evaluate( + """() => { + window.ws.send('pass-client-2'); + }""" + ) + await assert_equal( + lambda: log, + [ + "message: fake", + "message: modified", + "message: pass-client", + "message: pass-client-2", + ], + ) + + await page.evaluate( + """() => { + window.ws.close(3009, 'problem'); + }""" + ) + await assert_equal( + lambda: log, + [ + "message: fake", + "message: modified", + "message: pass-client", + "message: pass-client-2", + "close: code=3009 reason=problem", + ], + ) + + +async def test_should_work_without_server(page: Page, server: Server) -> None: + future: asyncio.Future[WebSocketRoute] = asyncio.Future() + + async def _handle_ws(ws: WebSocketRoute) -> None: + def _ws_on_message(message: Union[str, bytes]) -> None: + if message == "to-respond": + ws.send("response") + + ws.on_message(_ws_on_message) + future.set_result(ws) + + await page.route_web_socket(re.compile(".*"), _handle_ws) + await setup_ws(page, server.PORT, "blob") + + await page.evaluate( + """async () => { + await window.wsOpened; + window.ws.send('to-respond'); + window.ws.send('to-block'); + window.ws.send('to-respond'); + }""" + ) + + await assert_equal( + lambda: page.evaluate("window.log"), + [ + "open", + f"message: data=response origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=response origin=ws://localhost:{server.PORT} lastEventId=", + ], + ) + + route = await future + route.send("another") + # wait for the message to be processed + await page.wait_for_timeout(100) + await route.close(code=3008, reason="oops") + await assert_equal( + lambda: page.evaluate("window.log"), + [ + "open", + f"message: data=response origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=response origin=ws://localhost:{server.PORT} lastEventId=", + f"message: data=another origin=ws://localhost:{server.PORT} lastEventId=", + "close code=3008 reason=oops wasClean=true", + ], + ) diff --git a/tests/server.py b/tests/server.py index 23d7ff374..f963d8a60 100644 --- a/tests/server.py +++ b/tests/server.py @@ -32,6 +32,7 @@ Set, Tuple, TypeVar, + Union, cast, ) from urllib.parse import urlparse @@ -39,6 +40,7 @@ from autobahn.twisted.resource import WebSocketResource from autobahn.twisted.websocket import WebSocketServerFactory, WebSocketServerProtocol from OpenSSL import crypto +from pyee import EventEmitter from twisted.internet import reactor as _twisted_reactor from twisted.internet import ssl from twisted.internet.selectreactor import SelectReactor @@ -197,6 +199,11 @@ async def wait_for_request(self, path: str) -> TestServerRequest: self.request_subscribers[path] = future return await future + def wait_for_web_socket(self) -> 'asyncio.Future["WebSocketProtocol"]': + future: asyncio.Future[WebSocketProtocol] = asyncio.Future() + self.once_web_socket_connection(future.set_result) + return future + @contextlib.contextmanager def expect_request( self, path: str @@ -280,6 +287,21 @@ def listen(self, factory: http.HTTPFactory) -> None: class WebSocketProtocol(WebSocketServerProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.events = EventEmitter() + + def onClose(self, wasClean: bool, code: int, reason: str) -> None: + super().onClose(wasClean, code, reason) + self.events.emit( + "close", + code, + reason, + ) + + def onMessage(self, payload: Union[str, bytes], isBinary: bool) -> None: + self.events.emit("message", payload, isBinary) + def onOpen(self) -> None: for handler in self.factory.server_instance._ws_handlers.copy(): self.factory.server_instance._ws_handlers.remove(handler)