Skip to content

Commit

Permalink
test: move SyncBase._gather into tests (#2316)
Browse files Browse the repository at this point in the history
  • Loading branch information
mxschmitt authored Feb 22, 2024
1 parent c4ffd45 commit 3e44464
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 38 deletions.
34 changes: 0 additions & 34 deletions playwright/_impl/_sync_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
Any,
Callable,
Coroutine,
Dict,
Generator,
Generic,
List,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -133,38 +131,6 @@ def remove_listener(self, event: Any, f: Any) -> None:
"""Removes the function ``f`` from ``event``."""
self._impl_obj.remove_listener(event, self._wrap_handler(f))

def _gather(self, *actions: Callable) -> List[Any]:
g_self = greenlet.getcurrent()
results: Dict[Callable, Any] = {}
exceptions: List[Exception] = []

def action_wrapper(action: Callable) -> Callable:
def body() -> Any:
try:
results[action] = action()
except Exception as e:
results[action] = e
exceptions.append(e)
g_self.switch()

return body

async def task() -> None:
for action in actions:
g = greenlet.greenlet(action_wrapper(action))
g.switch()

self._loop.create_task(task())

while len(results) < len(actions):
self._dispatcher_fiber.switch()

asyncio._set_running_loop(self._loop)
if exceptions:
raise exceptions[0]

return list(map(lambda action: results[action], actions))


class SyncContextManager(SyncBase):
def __enter__(self: Self) -> Self:
Expand Down
40 changes: 39 additions & 1 deletion tests/sync/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.


from typing import Dict, Generator
import asyncio
from typing import Any, Callable, Dict, Generator, List

import pytest
from greenlet import greenlet

from playwright.sync_api import (
Browser,
Expand Down Expand Up @@ -83,3 +85,39 @@ def page(context: BrowserContext) -> Generator[Page, None, None]:
@pytest.fixture(scope="session")
def selectors(playwright: Playwright) -> Selectors:
return playwright.selectors


@pytest.fixture(scope="session")
def sync_gather(playwright: Playwright) -> Generator[Callable, None, None]:
def _sync_gather_impl(*actions: Callable) -> List[Any]:
g_self = greenlet.getcurrent()
results: Dict[Callable, Any] = {}
exceptions: List[Exception] = []

def action_wrapper(action: Callable) -> Callable:
def body() -> Any:
try:
results[action] = action()
except Exception as e:
results[action] = e
exceptions.append(e)
g_self.switch()

return body

async def task() -> None:
for action in actions:
g = greenlet(action_wrapper(action))
g.switch()

asyncio.create_task(task())

while len(results) < len(actions):
playwright._dispatcher_fiber.switch()

if exceptions:
raise exceptions[0]

return list(map(lambda action: results[action], actions))

yield _sync_gather_impl
8 changes: 5 additions & 3 deletions tests/sync/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import multiprocessing
import os
from typing import Any, Dict
from typing import Any, Callable, Dict

import pytest

Expand Down Expand Up @@ -266,10 +266,12 @@ def test_sync_set_default_timeout(page: Page) -> None:
assert "Timeout 1ms exceeded." in exc.value.message


def test_close_should_reject_all_promises(context: BrowserContext) -> None:
def test_close_should_reject_all_promises(
context: BrowserContext, sync_gather: Callable
) -> None:
new_page = context.new_page()
with pytest.raises(Error) as exc_info:
new_page._gather(
sync_gather(
lambda: new_page.evaluate("() => new Promise(r => {})"),
lambda: new_page.close(),
)
Expand Down

0 comments on commit 3e44464

Please sign in to comment.