Skip to content

Commit

Permalink
chore: collect stale handles from the server side
Browse files Browse the repository at this point in the history
  • Loading branch information
mxschmitt committed Oct 12, 2023
1 parent 2886f00 commit 6f129e0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 16 deletions.
35 changes: 21 additions & 14 deletions playwright/_impl/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@


class Channel(AsyncIOEventEmitter):
def __init__(self, connection: "Connection", guid: str) -> None:
def __init__(self, connection: "Connection", object: "ChannelOwner") -> None:
super().__init__()
self._connection: Connection = connection
self._guid = guid
self._object: Optional[ChannelOwner] = None
self._object = object
self._connection = connection

async def send(self, method: str, params: Dict = None) -> Any:
return await self._connection.wrap_api_call(
Expand All @@ -71,7 +70,7 @@ def send_no_reply(self, method: str, params: Dict = None) -> None:
# No reply messages are used to e.g. waitForEventInfo(after).
self._connection.wrap_api_call_sync(
lambda: self._connection._send_message_to_server(
self._guid, method, {} if params is None else params, True
self._object, method, {} if params is None else params, True
)
)

Expand All @@ -80,7 +79,9 @@ async def inner_send(
) -> Any:
if params is None:
params = {}
callback = self._connection._send_message_to_server(self._guid, method, params)
callback = self._connection._send_message_to_server(
self._object, method, params
)
if self._connection._error:
error = self._connection._error
self._connection._error = None
Expand Down Expand Up @@ -129,25 +130,26 @@ def __init__(
parent if isinstance(parent, ChannelOwner) else None
)
self._objects: Dict[str, "ChannelOwner"] = {}
self._channel: Channel = Channel(self._connection, guid)
self._channel._object = self
self._channel = Channel(self._connection, self)
self._initializer = initializer
self._was_collected = False

self._connection._objects[guid] = self
if self._parent:
self._parent._objects[guid] = self

self._event_to_subscription_mapping: Dict[str, str] = {}

def _dispose(self) -> None:
def _dispose(self, reason: Optional[str]) -> None:
# Clean up from parent and connection.
if self._parent:
del self._parent._objects[self._guid]
del self._connection._objects[self._guid]
self._was_collected = reason == "gc"

# Dispose all children.
for object in list(self._objects.values()):
object._dispose()
object._dispose(reason)
self._objects.clear()

def _adopt(self, child: "ChannelOwner") -> None:
Expand Down Expand Up @@ -308,10 +310,14 @@ def set_in_tracing(self, is_tracing: bool) -> None:
self._tracing_count -= 1

def _send_message_to_server(
self, guid: str, method: str, params: Dict, no_reply: bool = False
self, object: ChannelOwner, method: str, params: Dict, no_reply: bool = False
) -> ProtocolCallback:
if self._closed_error_message:
raise Error(self._closed_error_message)
if object._was_collected:
raise Error(
"The object has been collected to prevent unbounded heap growth."
)
self._last_id += 1
id = self._last_id
callback = ProtocolCallback(self._loop)
Expand All @@ -335,7 +341,7 @@ def _send_message_to_server(
)
message = {
"id": id,
"guid": guid,
"guid": object._guid,
"method": method,
"params": self._replace_channels_with_guids(params),
"metadata": {
Expand All @@ -345,7 +351,7 @@ def _send_message_to_server(
"internal": not stack_trace_information["apiName"],
},
}
if self._tracing_count > 0 and frames and guid != "localUtils":
if self._tracing_count > 0 and frames and object._guid != "localUtils":
self.local_utils.add_stack_to_tracing_no_reply(id, frames)

self._transport.send(message)
Expand Down Expand Up @@ -401,7 +407,8 @@ def dispatch(self, msg: ParsedMessagePayload) -> None:
return

if method == "__dispose__":
self._objects[guid]._dispose()
assert params
self._objects[guid]._dispose(cast(Optional[str], params.get("reason")))
return
object = self._objects[guid]
should_replace_guids_with_channels = "jsonPipe@" not in guid
Expand Down
2 changes: 1 addition & 1 deletion playwright/_impl/_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
self._main_frame._page = self
self._frames = [self._main_frame]
self._viewport_size: Optional[ViewportSize] = initializer.get("viewportSize")
self._is_closed = False
self._is_closed: bool = False
self._workers: List["Worker"] = []
self._bindings: Dict[str, Any] = {}
self._routes: List[RouteHandler] = []
Expand Down
19 changes: 18 additions & 1 deletion tests/async/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest

from playwright.async_api import async_playwright
from playwright.async_api import Page, async_playwright

from ..server import Server

Expand Down Expand Up @@ -67,3 +67,20 @@ async def test_cancel_pending_protocol_call_on_playwright_stop(server: Server) -
with pytest.raises(Exception) as exc_info:
await pending_task
assert "Connection closed" in str(exc_info.value)


async def test_should_collect_stale_handles(page: Page, server: Server) -> None:
page.on("request", lambda: None)
response = await page.goto(server.PREFIX + "/title.html")
for i in range(1000):
await page.evaluate(
"""async () => {
const response = await fetch('/');
await response.text();
}"""
)
with pytest.raises(Exception) as exc_info:
await response.all_headers()
assert "The object has been collected to prevent unbounded heap growth." in str(
exc_info.value
)
18 changes: 18 additions & 0 deletions tests/sync/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,21 @@ def test_call_sync_method_after_playwright_close_with_own_loop(
p.start()
p.join()
assert p.exitcode == 0


def test_should_collect_stale_handles(page: Page, server: Server) -> None:
page.on("request", lambda request: None)
response = page.goto(server.PREFIX + "/title.html")
assert response
for i in range(1000):
page.evaluate(
"""async () => {
const response = await fetch('/');
await response.text();
}"""
)
with pytest.raises(Exception) as exc_info:
response.all_headers()
assert "The object has been collected to prevent unbounded heap growth." in str(
exc_info.value
)

0 comments on commit 6f129e0

Please sign in to comment.