Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: collect stale handles from the server side #2111

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions playwright/_impl/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@


class Channel(AsyncIOEventEmitter):
def __init__(self, connection: "Connection", guid: str) -> None:
def __init__(self, connection: "Connection", object: "ChannelOwner") -> None:
super().__init__()
self._connection: Connection = connection
self._guid = guid
self._object: Optional[ChannelOwner] = None
self._connection = connection
self._guid = object._guid
self._object = object

async def send(self, method: str, params: Dict = None) -> Any:
return await self._connection.wrap_api_call(
Expand All @@ -71,7 +71,7 @@ def send_no_reply(self, method: str, params: Dict = None) -> None:
# No reply messages are used to e.g. waitForEventInfo(after).
self._connection.wrap_api_call_sync(
lambda: self._connection._send_message_to_server(
self._guid, method, {} if params is None else params, True
self._object, method, {} if params is None else params, True
)
)

Expand All @@ -80,7 +80,9 @@ async def inner_send(
) -> Any:
if params is None:
params = {}
callback = self._connection._send_message_to_server(self._guid, method, params)
callback = self._connection._send_message_to_server(
self._object, method, params
)
if self._connection._error:
error = self._connection._error
self._connection._error = None
Expand Down Expand Up @@ -121,33 +123,34 @@ def __init__(
self._loop: asyncio.AbstractEventLoop = parent._loop
self._dispatcher_fiber: Any = parent._dispatcher_fiber
self._type = type
self._guid = guid
self._guid: str = guid
self._connection: Connection = (
parent._connection if isinstance(parent, ChannelOwner) else parent
)
self._parent: Optional[ChannelOwner] = (
parent if isinstance(parent, ChannelOwner) else None
)
self._objects: Dict[str, "ChannelOwner"] = {}
self._channel: Channel = Channel(self._connection, guid)
self._channel._object = self
self._channel: Channel = Channel(self._connection, self)
self._initializer = initializer
self._was_collected = False

self._connection._objects[guid] = self
if self._parent:
self._parent._objects[guid] = self

self._event_to_subscription_mapping: Dict[str, str] = {}

def _dispose(self) -> None:
def _dispose(self, reason: Optional[str]) -> None:
# Clean up from parent and connection.
if self._parent:
del self._parent._objects[self._guid]
del self._connection._objects[self._guid]
self._was_collected = reason == "gc"

# Dispose all children.
for object in list(self._objects.values()):
object._dispose()
object._dispose(reason)
self._objects.clear()

def _adopt(self, child: "ChannelOwner") -> None:
Expand Down Expand Up @@ -308,10 +311,14 @@ def set_in_tracing(self, is_tracing: bool) -> None:
self._tracing_count -= 1

def _send_message_to_server(
self, guid: str, method: str, params: Dict, no_reply: bool = False
self, object: ChannelOwner, method: str, params: Dict, no_reply: bool = False
) -> ProtocolCallback:
if self._closed_error_message:
raise Error(self._closed_error_message)
if object._was_collected:
raise Error(
"The object has been collected to prevent unbounded heap growth."
)
self._last_id += 1
id = self._last_id
callback = ProtocolCallback(self._loop)
Expand All @@ -335,7 +342,7 @@ def _send_message_to_server(
)
message = {
"id": id,
"guid": guid,
"guid": object._guid,
"method": method,
"params": self._replace_channels_with_guids(params),
"metadata": {
Expand All @@ -345,7 +352,7 @@ def _send_message_to_server(
"internal": not stack_trace_information["apiName"],
},
}
if self._tracing_count > 0 and frames and guid != "localUtils":
if self._tracing_count > 0 and frames and object._guid != "localUtils":
self.local_utils.add_stack_to_tracing_no_reply(id, frames)

self._transport.send(message)
Expand Down Expand Up @@ -401,7 +408,8 @@ def dispatch(self, msg: ParsedMessagePayload) -> None:
return

if method == "__dispose__":
self._objects[guid]._dispose()
assert isinstance(params, dict)
self._objects[guid]._dispose(cast(Optional[str], params.get("reason")))
return
object = self._objects[guid]
should_replace_guids_with_channels = "jsonPipe@" not in guid
Expand Down
19 changes: 18 additions & 1 deletion tests/async/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest

from playwright.async_api import async_playwright
from playwright.async_api import Page, async_playwright

from ..server import Server

Expand Down Expand Up @@ -67,3 +67,20 @@ async def test_cancel_pending_protocol_call_on_playwright_stop(server: Server) -
with pytest.raises(Exception) as exc_info:
await pending_task
assert "Connection closed" in str(exc_info.value)


async def test_should_collect_stale_handles(page: Page, server: Server) -> None:
page.on("request", lambda: None)
response = await page.goto(server.PREFIX + "/title.html")
for i in range(1000):
await page.evaluate(
"""async () => {
const response = await fetch('/');
await response.text();
}"""
)
with pytest.raises(Exception) as exc_info:
await response.all_headers()
assert "The object has been collected to prevent unbounded heap growth." in str(
exc_info.value
)
18 changes: 18 additions & 0 deletions tests/sync/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,21 @@ def test_call_sync_method_after_playwright_close_with_own_loop(
p.start()
p.join()
assert p.exitcode == 0


def test_should_collect_stale_handles(page: Page, server: Server) -> None:
page.on("request", lambda request: None)
response = page.goto(server.PREFIX + "/title.html")
assert response
for i in range(1000):
page.evaluate(
"""async () => {
const response = await fetch('/');
await response.text();
}"""
)
with pytest.raises(Exception) as exc_info:
response.all_headers()
assert "The object has been collected to prevent unbounded heap growth." in str(
exc_info.value
)
Loading