Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
mxschmitt committed Oct 2, 2024
1 parent d9cdfbb commit 3ed5b5c
Show file tree
Hide file tree
Showing 17 changed files with 1,271 additions and 97 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Playwright is a Python library to automate [Chromium](https://www.chromium.org/H

| | Linux | macOS | Windows |
| :--- | :---: | :---: | :---: |
| Chromium <!-- GEN:chromium-version -->129.0.6668.29<!-- GEN:stop --> ||||
| Chromium <!-- GEN:chromium-version -->130.0.6723.19<!-- GEN:stop --> ||||
| WebKit <!-- GEN:webkit-version -->18.0<!-- GEN:stop --> ||||
| Firefox <!-- GEN:firefox-version -->130.0<!-- GEN:stop --> ||||

Expand Down
53 changes: 51 additions & 2 deletions playwright/_impl/_browser_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,22 @@
TimeoutSettings,
URLMatch,
URLMatcher,
WebSocketRouteHandlerCallback,
async_readfile,
async_writefile,
locals_to_params,
parse_error,
prepare_record_har_options,
to_impl,
)
from playwright._impl._network import Request, Response, Route, serialize_headers
from playwright._impl._network import (
Request,
Response,
Route,
WebSocketRoute,
WebSocketRouteHandler,
serialize_headers,
)
from playwright._impl._page import BindingCall, Page, Worker
from playwright._impl._str_utils import escape_regex_flags
from playwright._impl._tracing import Tracing
Expand Down Expand Up @@ -106,6 +114,7 @@ def __init__(
self._browser._contexts.append(self)
self._pages: List[Page] = []
self._routes: List[RouteHandler] = []
self._web_socket_routes: List[WebSocketRouteHandler] = []
self._bindings: Dict[str, Any] = {}
self._timeout_settings = TimeoutSettings(None)
self._owner_page: Optional[Page] = None
Expand All @@ -132,7 +141,14 @@ def __init__(
)
),
)

self._channel.on(
"webSocketRoute",
lambda params: self._loop.create_task(
self._on_web_socket_route(
from_channel(params["webSocketRoute"]),
)
),
)
self._channel.on(
"backgroundPage",
lambda params: self._on_background_page(from_channel(params["page"])),
Expand Down Expand Up @@ -248,6 +264,20 @@ async def _on_route(self, route: Route) -> None:
except Exception:
pass

async def _on_web_socket_route(self, web_socket_route: WebSocketRoute) -> None:
    """Dispatch an intercepted WebSocket to the first matching route handler.

    Registered handlers are scanned in list order and the first one whose
    matcher accepts the socket URL handles it. When nothing matches, the
    socket is connected straight through to the real server.
    """
    for candidate in self._web_socket_routes:
        if candidate.matches(web_socket_route.url):
            await candidate.handle(web_socket_route)
            return
    # No handler claimed this URL: fall back to a plain pass-through.
    web_socket_route.connect_to_server()

def _on_binding(self, binding_call: BindingCall) -> None:
func = self._bindings.get(binding_call._initializer["name"])
if func is None:
Expand Down Expand Up @@ -418,6 +448,17 @@ async def _unroute_internal(
return
await asyncio.gather(*map(lambda router: router.stop(behavior), removed)) # type: ignore

async def route_web_socket(
    self, url: URLMatch, handler: WebSocketRouteHandlerCallback
) -> None:
    """Register ``handler`` for WebSockets whose URL matches ``url``.

    The URL matcher is built against the context's ``baseURL`` option. After
    the handler is recorded, the interception patterns are pushed over the
    channel.
    """
    matcher = URLMatcher(self._options.get("baseURL"), url)
    # Prepend, so the most recently registered handler is found first when
    # the handler list is scanned front to back.
    self._web_socket_routes.insert(0, WebSocketRouteHandler(matcher, handler))
    await self._update_web_socket_interception_patterns()

def _dispose_har_routers(self) -> None:
for router in self._har_routers:
router.dispose()
Expand Down Expand Up @@ -488,6 +529,14 @@ async def _update_interception_patterns(self) -> None:
"setNetworkInterceptionPatterns", {"patterns": patterns}
)

async def _update_web_socket_interception_patterns(self) -> None:
    """Send the current WebSocket interception patterns over the channel.

    The patterns are derived from every registered WebSocket route handler.
    """
    payload = {
        "patterns": WebSocketRouteHandler.prepare_interception_patterns(
            self._web_socket_routes
        )
    }
    await self._channel.send("setWebSocketInterceptionPatterns", payload)

def expect_event(
self,
event: str,
Expand Down
3 changes: 2 additions & 1 deletion playwright/_impl/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@

if TYPE_CHECKING: # pragma: no cover
from playwright._impl._api_structures import HeadersArray
from playwright._impl._network import Request, Response, Route
from playwright._impl._network import Request, Response, Route, WebSocketRoute

# A URL can be matched by a string, a compiled regular expression, or a
# predicate that receives the URL string and returns a bool.
URLMatch = Union[str, Pattern[str], Callable[[str], bool]]
# Same idea, but the predicate form receives the Request / Response object.
URLMatchRequest = Union[str, Pattern[str], Callable[["Request"], bool]]
URLMatchResponse = Union[str, Pattern[str], Callable[["Response"], bool]]
# Route handlers may accept just the Route, or the Route and its Request.
RouteHandlerCallback = Union[
Callable[["Route"], Any], Callable[["Route", "Request"], Any]
]
# WebSocket route handlers always receive exactly one WebSocketRoute argument.
WebSocketRouteHandlerCallback = Callable[["WebSocketRoute"], Any]

ColorScheme = Literal["dark", "light", "no-preference", "null"]
ForcedColors = Literal["active", "none", "null"]
Expand Down
218 changes: 217 additions & 1 deletion playwright/_impl/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import json as json_utils
import mimetypes
import re
from collections import defaultdict
from pathlib import Path
from types import SimpleNamespace
Expand Down Expand Up @@ -46,12 +47,19 @@
)
from playwright._impl._connection import (
ChannelOwner,
Connection,
from_channel,
from_nullable_channel,
)
from playwright._impl._errors import Error
from playwright._impl._event_context_manager import EventContextManagerImpl
from playwright._impl._helper import async_readfile, locals_to_params
from playwright._impl._helper import (
URLMatcher,
WebSocketRouteHandlerCallback,
async_readfile,
locals_to_params,
)
from playwright._impl._str_utils import escape_regex_flags
from playwright._impl._waiter import Waiter

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -548,6 +556,214 @@ async def _race_with_page_close(self, future: Coroutine) -> None:
await asyncio.gather(fut, return_exceptions=True)


class ServerWebSocketRoute:
    """Server-facing side of an intercepted WebSocket.

    This is a thin wrapper around the owning ``WebSocketRoute``: callbacks
    registered here are stored on the owner as its server-side handlers, and
    outgoing traffic is sent over the owner's channel.
    """

    def __init__(self, ws: "WebSocketRoute"):
        self._ws = ws

    def _send_in_background(self, request: Coroutine) -> None:
        """Schedule ``request`` as a task and ignore any error it raises.

        The error has to be caught *inside* the task: a ``try``/``except``
        around ``asyncio.create_task`` only guards the scheduling call, so a
        failing send would otherwise be reported later as "Task exception was
        never retrieved".

        Raises:
            RuntimeError: if there is no running event loop to schedule on.
        """

        async def _guarded() -> None:
            try:
                await request
            except Exception:
                # Best-effort delivery: a failed send is deliberately ignored.
                pass

        asyncio.create_task(_guarded())

    def on_message(self, handler: Callable[[Union[str, bytes]], Any]) -> None:
        """Register the callback invoked for each message coming from the server."""
        self._ws._on_server_message = handler

    def on_close(self, handler: Callable[[Optional[int], Optional[str]], Any]) -> None:
        """Register the callback invoked with ``(code, reason)`` when the server closes."""
        self._ws._on_server_close = handler

    def connect_to_server(self) -> None:
        """Always raises: connecting is only possible from the page-side route."""
        raise NotImplementedError(
            "connectToServer must be called on the page-side WebSocketRoute"
        )

    @property
    def url(self) -> str:
        """URL of the intercepted WebSocket, read from the owner's initializer."""
        return self._ws._initializer["url"]

    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Close the server side of the connection without waiting for the result.

        Args:
            code: optional WebSocket close code.
            reason: optional close reason.
        """
        # "closeServer" (not "close") is the command used elsewhere in this
        # module to close the server side; "wasClean" mirrors what the
        # page-side close sends with "closePage".
        self._send_in_background(
            self._ws._channel.send(
                "closeServer",
                {
                    "code": code,
                    "reason": reason,
                    "wasClean": True,
                },
            )
        )

    def send(self, message: Union[str, bytes]) -> None:
        """Send ``message`` to the server without waiting for the result.

        Text is sent as-is; binary payloads are base64-encoded for transport
        and flagged with ``isBase64``.
        """
        if isinstance(message, str):
            payload = {"message": message, "isBase64": False}
        else:
            payload = {
                "message": base64.b64encode(message).decode(),
                "isBase64": True,
            }
        self._send_in_background(self._ws._channel.send("sendToServer", payload))


class WebSocketRoute(ChannelOwner):
    """Page-facing handle for an intercepted WebSocket connection.

    The channel events "messageFromPage", "messageFromServer", "closePage"
    and "closeServer" are delivered to the callbacks registered through
    ``on_message`` / ``on_close`` on this object (page side) or on the object
    returned by ``connect_to_server`` (server side). When no callback is
    registered for an event, it is forwarded to the opposite side instead;
    page messages are only forwarded once ``connect_to_server`` was called.
    """

    def __init__(
        self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
    ) -> None:
        super().__init__(parent, type, guid, initializer)
        # User callbacks; None means "not registered, use default forwarding".
        self._on_page_message: Optional[Callable[[Union[str, bytes]], Any]] = None
        self._on_page_close: Optional[
            Callable[[Optional[int], Optional[str]], Any]
        ] = None
        self._on_server_message: Optional[Callable[[Union[str, bytes]], Any]] = None
        self._on_server_close: Optional[
            Callable[[Optional[int], Optional[str]], Any]
        ] = None
        self._server = ServerWebSocketRoute(self)
        # Flipped to True by connect_to_server().
        self._connected = False

        self._channel.on("messageFromPage", self._channel_message_from_page)
        self._channel.on("messageFromServer", self._channel_message_from_server)
        self._channel.on("closePage", self._channel_close_page)
        self._channel.on("closeServer", self._channel_close_server)

    def _send_in_background(self, request: Coroutine) -> None:
        """Schedule ``request`` as a task and ignore any error it raises.

        The error has to be caught *inside* the task: a ``try``/``except``
        around ``asyncio.create_task`` only guards the scheduling call, so a
        failing send would otherwise be reported later as "Task exception was
        never retrieved".

        Raises:
            RuntimeError: if there is no running event loop to schedule on.
        """

        async def _guarded() -> None:
            try:
                await request
            except Exception:
                # Best-effort delivery: a failed send is deliberately ignored.
                pass

        asyncio.create_task(_guarded())

    @staticmethod
    def _decode_message(event: Dict) -> Union[str, bytes]:
        """Return the event payload, base64-decoding it when flagged as binary."""
        if event["isBase64"]:
            return base64.b64decode(event["message"])
        return event["message"]

    def _channel_message_from_page(self, event: Dict) -> None:
        """Hand a page message to the page callback, else forward it to the server."""
        if self._on_page_message:
            self._on_page_message(self._decode_message(event))
        elif self._connected:
            # Only forward when a server connection was requested.
            self._send_in_background(self._channel.send("sendToServer", event))

    def _channel_message_from_server(self, event: Dict) -> None:
        """Hand a server message to the server callback, else forward it to the page."""
        if self._on_server_message:
            self._on_server_message(self._decode_message(event))
        else:
            self._send_in_background(self._channel.send("sendToPage", event))

    def _channel_close_page(self, event: Dict) -> None:
        """Report a page-side close to its callback, else close the server side."""
        if self._on_page_close:
            self._on_page_close(event["code"], event["reason"])
        else:
            self._send_in_background(self._channel.send("closeServer", event))

    def _channel_close_server(self, event: Dict) -> None:
        """Report a server-side close to its callback, else close the page side."""
        if self._on_server_close:
            self._on_server_close(event["code"], event["reason"])
        else:
            self._send_in_background(self._channel.send("closePage", event))

    @property
    def url(self) -> str:
        """URL of the intercepted WebSocket, read from the initializer."""
        return self._initializer["url"]

    async def close(
        self, code: Optional[int] = None, reason: Optional[str] = None
    ) -> None:
        """Close the page side of the connection; send errors are ignored.

        Args:
            code: optional WebSocket close code.
            reason: optional close reason.
        """
        try:
            await self._channel.send(
                "closePage", {"code": code, "reason": reason, "wasClean": True}
            )
        except Exception:
            # Best-effort close. ``Exception`` (rather than a bare except)
            # lets task cancellation propagate to the caller.
            pass

    def connect_to_server(self) -> "WebSocketRoute":
        """Request a connection to the real server and return its server-side handle.

        Raises:
            Error: if this route was already connected.
        """
        if self._connected:
            raise Error("Already connected to the server")
        self._connected = True
        self._send_in_background(self._channel.send("connect"))
        return cast("WebSocketRoute", self._server)

    def send(self, message: Union[str, bytes]) -> None:
        """Send ``message`` to the page without waiting for the result.

        Text is sent as-is; binary payloads are base64-encoded for transport
        and flagged with ``isBase64``.
        """
        if isinstance(message, str):
            payload = {"message": message, "isBase64": False}
        else:
            payload = {
                "message": base64.b64encode(message).decode(),
                "isBase64": True,
            }
        self._send_in_background(self._channel.send("sendToPage", payload))

    def on_message(self, handler: Callable[[Union[str, bytes]], Any]) -> None:
        """Register the callback invoked for each message coming from the page."""
        self._on_page_message = handler

    def on_close(self, handler: Callable[[Optional[int], Optional[str]], Any]) -> None:
        """Register the callback invoked with ``(code, reason)`` when the page closes."""
        self._on_page_close = handler

    async def _after_handle(self) -> None:
        """Finish setup after the user's route handler has run."""
        if self._connected:
            return
        # Ensure that websocket is "open" and can send messages without an actual server connection.
        await self._channel.send("ensureOpened")


class WebSocketRouteHandler:
    """Pairs a URL matcher with the user callback for matching WebSockets."""

    def __init__(
        self, matcher: "URLMatcher", handler: "WebSocketRouteHandlerCallback"
    ):
        self.matcher = matcher
        self.handler = handler

    @staticmethod
    def prepare_interception_patterns(
        handlers: List["WebSocketRouteHandler"],
    ) -> List[dict]:
        """Turn the handlers' matchers into interception pattern dicts.

        String matchers become ``{"glob": ...}`` entries and regex matchers
        become ``{"regexSource": ..., "regexFlags": ...}`` entries. Any other
        matcher (e.g. a predicate) cannot be expressed as a pattern, so the
        result collapses to a single catch-all glob.
        """
        patterns = []
        for handler in handlers:
            matcher = handler.matcher
            if isinstance(matcher.match, str):
                patterns.append({"glob": matcher.match})
            elif isinstance(matcher._regex_obj, re.Pattern):
                patterns.append(
                    {
                        "regexSource": matcher._regex_obj.pattern,
                        "regexFlags": escape_regex_flags(matcher._regex_obj),
                    }
                )
            else:
                # One inexpressible matcher forces interception of every URL.
                return [{"glob": "**/*"}]
        return patterns

    def matches(self, ws_url: str) -> bool:
        """Return whether this handler's matcher accepts ``ws_url``."""
        return self.matcher.matches(ws_url)

    async def handle(self, websocket_route: "WebSocketRoute") -> None:
        """Run the user callback for ``websocket_route``, then finish setup.

        The callback may be a plain function or a coroutine function; when it
        returns a coroutine or future, that is awaited before
        ``_after_handle`` runs.
        """
        outcome = self.handler(websocket_route)
        if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
            await outcome
        await websocket_route._after_handle()


class Response(ChannelOwner):
def __init__(
self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
Expand Down
10 changes: 9 additions & 1 deletion playwright/_impl/_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from playwright._impl._frame import Frame
from playwright._impl._js_handle import JSHandle
from playwright._impl._local_utils import LocalUtils
from playwright._impl._network import Request, Response, Route, WebSocket
from playwright._impl._network import (
Request,
Response,
Route,
WebSocket,
WebSocketRoute,
)
from playwright._impl._page import BindingCall, Page, Worker
from playwright._impl._playwright import Playwright
from playwright._impl._selectors import SelectorsOwner
Expand Down Expand Up @@ -88,6 +94,8 @@ def create_remote_object(
return Tracing(parent, type, guid, initializer)
if type == "WebSocket":
return WebSocket(parent, type, guid, initializer)
if type == "WebSocketRoute":
return WebSocketRoute(parent, type, guid, initializer)
if type == "Worker":
return Worker(parent, type, guid, initializer)
if type == "WritableStream":
Expand Down
Loading

0 comments on commit 3ed5b5c

Please sign in to comment.