From 6f129e0b9aa25ab5121299e49af19a702e88e87a Mon Sep 17 00:00:00 2001 From: Max Schmitt Date: Thu, 12 Oct 2023 09:22:09 +0200 Subject: [PATCH] chore: collect stale handles from the server side --- playwright/_impl/_connection.py | 35 ++++++++++++++++++++------------- playwright/_impl/_page.py | 2 +- tests/async/test_asyncio.py | 19 +++++++++++++++++- tests/sync/test_sync.py | 18 +++++++++++++++++ 4 files changed, 58 insertions(+), 16 deletions(-) diff --git a/playwright/_impl/_connection.py b/playwright/_impl/_connection.py index e11612fcf3..61a678fdd9 100644 --- a/playwright/_impl/_connection.py +++ b/playwright/_impl/_connection.py @@ -51,11 +51,10 @@ class Channel(AsyncIOEventEmitter): - def __init__(self, connection: "Connection", guid: str) -> None: + def __init__(self, connection: "Connection", object: "ChannelOwner") -> None: super().__init__() - self._connection: Connection = connection - self._guid = guid - self._object: Optional[ChannelOwner] = None + self._object = object + self._connection = connection async def send(self, method: str, params: Dict = None) -> Any: return await self._connection.wrap_api_call( @@ -71,7 +70,7 @@ def send_no_reply(self, method: str, params: Dict = None) -> None: # No reply messages are used to e.g. waitForEventInfo(after). self._connection.wrap_api_call_sync( lambda: self._connection._send_message_to_server( - self._guid, method, {} if params is None else params, True + self._object, method, {} if params is None else params, True ) ) @@ -80,7 +79,9 @@ async def inner_send( ) -> Any: if params is None: params = {} - callback = self._connection._send_message_to_server(self._guid, method, params) + callback = self._connection._send_message_to_server( + self._object, method, params + ) if self._connection._error: error = self._connection._error self._connection._error = None @@ -129,9 +130,9 @@ def __init__( parent if isinstance(parent, ChannelOwner) else None ) self._objects: Dict[str, "ChannelOwner"] = {} - self._channel: Channel = Channel(self._connection, guid) - self._channel._object = self + self._channel = Channel(self._connection, self) self._initializer = initializer + self._was_collected = False self._connection._objects[guid] = self if self._parent: @@ -139,15 +140,16 @@ def __init__( self._event_to_subscription_mapping: Dict[str, str] = {} - def _dispose(self) -> None: + def _dispose(self, reason: Optional[str]) -> None: # Clean up from parent and connection. if self._parent: del self._parent._objects[self._guid] del self._connection._objects[self._guid] + self._was_collected = reason == "gc" # Dispose all children. for object in list(self._objects.values()): - object._dispose() + object._dispose(reason) self._objects.clear() def _adopt(self, child: "ChannelOwner") -> None: @@ -308,10 +310,14 @@ def set_in_tracing(self, is_tracing: bool) -> None: self._tracing_count -= 1 def _send_message_to_server( - self, guid: str, method: str, params: Dict, no_reply: bool = False + self, object: ChannelOwner, method: str, params: Dict, no_reply: bool = False ) -> ProtocolCallback: if self._closed_error_message: raise Error(self._closed_error_message) + if object._was_collected: + raise Error( + "The object has been collected to prevent unbounded heap growth." + ) self._last_id += 1 id = self._last_id callback = ProtocolCallback(self._loop) @@ -335,7 +341,7 @@ def _send_message_to_server( ) message = { "id": id, - "guid": guid, + "guid": object._guid, "method": method, "params": self._replace_channels_with_guids(params), "metadata": { @@ -345,7 +351,7 @@ def _send_message_to_server( "internal": not stack_trace_information["apiName"], }, } - if self._tracing_count > 0 and frames and guid != "localUtils": + if self._tracing_count > 0 and frames and object._guid != "localUtils": self.local_utils.add_stack_to_tracing_no_reply(id, frames) self._transport.send(message) @@ -401,7 +407,8 @@ def dispatch(self, msg: ParsedMessagePayload) -> None: return if method == "__dispose__": - self._objects[guid]._dispose() + assert params + self._objects[guid]._dispose(cast(Optional[str], params.get("reason"))) return object = self._objects[guid] should_replace_guids_with_channels = "jsonPipe@" not in guid diff --git a/playwright/_impl/_page.py b/playwright/_impl/_page.py index be2538689f..f591834017 100644 --- a/playwright/_impl/_page.py +++ b/playwright/_impl/_page.py @@ -141,7 +141,7 @@ def __init__( self._main_frame._page = self self._frames = [self._main_frame] self._viewport_size: Optional[ViewportSize] = initializer.get("viewportSize") - self._is_closed = False + self._is_closed: bool = False self._workers: List["Worker"] = [] self._bindings: Dict[str, Any] = {} self._routes: List[RouteHandler] = [] diff --git a/tests/async/test_asyncio.py b/tests/async/test_asyncio.py index 4d6174d1ba..6bdf2456b7 100644 --- a/tests/async/test_asyncio.py +++ b/tests/async/test_asyncio.py @@ -17,7 +17,7 @@ import pytest -from playwright.async_api import async_playwright +from playwright.async_api import Page, async_playwright from ..server import Server @@ -67,3 +67,20 @@ async def test_cancel_pending_protocol_call_on_playwright_stop(server: Server) - with pytest.raises(Exception) as exc_info: await pending_task assert "Connection closed" in str(exc_info.value) + + +async def test_should_collect_stale_handles(page: Page, server: Server) -> None: + page.on("request", lambda: None) + response = await page.goto(server.PREFIX + "/title.html") + for i in range(1000): + await page.evaluate( + """async () => { + const response = await fetch('/'); + await response.text(); + }""" + ) + with pytest.raises(Exception) as exc_info: + await response.all_headers() + assert "The object has been collected to prevent unbounded heap growth." in str( + exc_info.value + ) diff --git a/tests/sync/test_sync.py b/tests/sync/test_sync.py index 11f6aab08c..af148afe3c 100644 --- a/tests/sync/test_sync.py +++ b/tests/sync/test_sync.py @@ -333,3 +333,21 @@ def test_call_sync_method_after_playwright_close_with_own_loop( p.start() p.join() assert p.exitcode == 0 + + +def test_should_collect_stale_handles(page: Page, server: Server) -> None: + page.on("request", lambda request: None) + response = page.goto(server.PREFIX + "/title.html") + assert response + for i in range(1000): + page.evaluate( + """async () => { + const response = await fetch('/'); + await response.text(); + }""" + ) + with pytest.raises(Exception) as exc_info: + response.all_headers() + assert "The object has been collected to prevent unbounded heap growth." in str( + exc_info.value + )