diff --git a/playwright/_impl/_helper.py b/playwright/_impl/_helper.py index d0737be07..538d5533a 100644 --- a/playwright/_impl/_helper.py +++ b/playwright/_impl/_helper.py @@ -34,7 +34,7 @@ Union, cast, ) -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse from playwright._impl._api_structures import NameValue from playwright._impl._errors import ( @@ -157,6 +157,10 @@ def url_matches( base_url = re.sub(r"^http", "ws", base_url) if base_url: match = urljoin(base_url, match) + parsed = urlparse(match) + if parsed.path == "": + parsed = parsed._replace(path="/") + match = parsed.geturl() if isinstance(match, str): match = glob_to_regex(match) if isinstance(match, Pattern): diff --git a/tests/async/test_route_web_socket.py b/tests/async/test_route_web_socket.py index 2ebda4b9e..465832adf 100644 --- a/tests/async/test_route_web_socket.py +++ b/tests/async/test_route_web_socket.py @@ -346,3 +346,36 @@ async def _handle_ws(ws: WebSocketRoute) -> None: f"message: data=echo origin=ws://localhost:{server.PORT} lastEventId=", ], ) + + +async def test_should_work_with_no_trailing_slash(page: Page, server: Server) -> None: + log: list[str] = [] + + async def handle_ws(ws: WebSocketRoute) -> None: + def on_message(message: Union[str, bytes]) -> None: + assert isinstance(message, str) + log.append(message) + ws.send("response") + + ws.on_message(on_message) + + # No trailing slash in the route pattern + await page.route_web_socket(f"ws://localhost:{server.PORT}", handle_ws) + + await page.goto("about:blank") + await page.evaluate( + """({ port }) => { + window.log = []; + // No trailing slash in WebSocket URL + window.ws = new WebSocket('ws://localhost:' + port); + window.ws.addEventListener('message', event => window.log.push(event.data)); + }""", + {"port": server.PORT}, + ) + + await assert_equal( + lambda: page.evaluate("window.ws.readyState"), 1 # WebSocket.OPEN + ) + await page.evaluate("window.ws.send('query')") + await assert_equal(lambda: log, ["query"]) + await assert_equal(lambda: page.evaluate("window.log"), ["response"]) diff --git a/tests/sync/test_route_web_socket.py b/tests/sync/test_route_web_socket.py index a22a6e883..2e97ebd8d 100644 --- a/tests/sync/test_route_web_socket.py +++ b/tests/sync/test_route_web_socket.py @@ -340,3 +340,34 @@ def _handle_ws(ws: WebSocketRoute) -> None: f"message: data=echo origin=ws://localhost:{server.PORT} lastEventId=", ], ) + + +def test_should_work_with_no_trailing_slash(page: Page, server: Server) -> None: + log: list[str] = [] + + async def handle_ws(ws: WebSocketRoute) -> None: + def on_message(message: Union[str, bytes]) -> None: + assert isinstance(message, str) + log.append(message) + ws.send("response") + + ws.on_message(on_message) + + # No trailing slash in the route pattern + page.route_web_socket(f"ws://localhost:{server.PORT}", handle_ws) + + page.goto("about:blank") + page.evaluate( + """({ port }) => { + window.log = []; + // No trailing slash in WebSocket URL + window.ws = new WebSocket('ws://localhost:' + port); + window.ws.addEventListener('message', event => window.log.push(event.data)); + }""", + {"port": server.PORT}, + ) + + assert_equal(lambda: page.evaluate("window.ws.readyState"), 1) # WebSocket.OPEN + page.evaluate("window.ws.send('query')") + assert_equal(lambda: log, ["query"]) + assert_equal(lambda: page.evaluate("window.log"), ["response"])