diff --git a/playwright/_impl/_sync_base.py b/playwright/_impl/_sync_base.py index 0eefbba13..50a2c647e 100644 --- a/playwright/_impl/_sync_base.py +++ b/playwright/_impl/_sync_base.py @@ -20,10 +20,8 @@ Any, Callable, Coroutine, - Dict, Generator, Generic, - List, Type, TypeVar, Union, @@ -133,38 +131,6 @@ def remove_listener(self, event: Any, f: Any) -> None: """Removes the function ``f`` from ``event``.""" self._impl_obj.remove_listener(event, self._wrap_handler(f)) - def _gather(self, *actions: Callable) -> List[Any]: - g_self = greenlet.getcurrent() - results: Dict[Callable, Any] = {} - exceptions: List[Exception] = [] - - def action_wrapper(action: Callable) -> Callable: - def body() -> Any: - try: - results[action] = action() - except Exception as e: - results[action] = e - exceptions.append(e) - g_self.switch() - - return body - - async def task() -> None: - for action in actions: - g = greenlet.greenlet(action_wrapper(action)) - g.switch() - - self._loop.create_task(task()) - - while len(results) < len(actions): - self._dispatcher_fiber.switch() - - asyncio._set_running_loop(self._loop) - if exceptions: - raise exceptions[0] - - return list(map(lambda action: results[action], actions)) - class SyncContextManager(SyncBase): def __enter__(self: Self) -> Self: diff --git a/tests/sync/conftest.py b/tests/sync/conftest.py index 075cccde2..68221b216 100644 --- a/tests/sync/conftest.py +++ b/tests/sync/conftest.py @@ -13,9 +13,11 @@ # limitations under the License. -from typing import Dict, Generator +import asyncio +from typing import Any, Callable, Dict, Generator, List import pytest +from greenlet import greenlet from playwright.sync_api import ( Browser, @@ -83,3 +85,39 @@ def page(context: BrowserContext) -> Generator[Page, None, None]: @pytest.fixture(scope="session") def selectors(playwright: Playwright) -> Selectors: return playwright.selectors + + +@pytest.fixture(scope="session") +def sync_gather(playwright: Playwright) -> Generator[Callable, None, None]: + def _sync_gather_impl(*actions: Callable) -> List[Any]: + g_self = greenlet.getcurrent() + results: Dict[Callable, Any] = {} + exceptions: List[Exception] = [] + + def action_wrapper(action: Callable) -> Callable: + def body() -> Any: + try: + results[action] = action() + except Exception as e: + results[action] = e + exceptions.append(e) + g_self.switch() + + return body + + async def task() -> None: + for action in actions: + g = greenlet(action_wrapper(action)) + g.switch() + + asyncio.create_task(task()) + + while len(results) < len(actions): + playwright._dispatcher_fiber.switch() + + if exceptions: + raise exceptions[0] + + return list(map(lambda action: results[action], actions)) + + yield _sync_gather_impl diff --git a/tests/sync/test_sync.py b/tests/sync/test_sync.py index fbd94b932..64eace1e9 100644 --- a/tests/sync/test_sync.py +++ b/tests/sync/test_sync.py @@ -14,7 +14,7 @@ import multiprocessing import os -from typing import Any, Dict +from typing import Any, Callable, Dict import pytest @@ -266,10 +266,12 @@ def test_sync_set_default_timeout(page: Page) -> None: assert "Timeout 1ms exceeded." in exc.value.message -def test_close_should_reject_all_promises(context: BrowserContext) -> None: +def test_close_should_reject_all_promises( + context: BrowserContext, sync_gather: Callable +) -> None: new_page = context.new_page() with pytest.raises(Error) as exc_info: - new_page._gather( + sync_gather( lambda: new_page.evaluate("() => new Promise(r => {})"), lambda: new_page.close(), )