Skip to content

Commit

Permalink
fix: inherit expect context managers from contextlib (#2370)
Browse files Browse the repository at this point in the history
  • Loading branch information
mxschmitt authored Mar 22, 2024
1 parent 09f529a commit 5a4779e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
11 changes: 6 additions & 5 deletions playwright/_impl/_async_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.

import asyncio
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import Any, Callable, Generic, Type, TypeVar
from typing import Any, Callable, Generic, Optional, Type, TypeVar

from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper

Expand All @@ -40,7 +41,7 @@ def is_done(self) -> bool:
return self._future.done()


class AsyncEventContextManager(Generic[T]):
class AsyncEventContextManager(Generic[T], AbstractAsyncContextManager):
def __init__(self, future: "asyncio.Future[T]") -> None:
self._event = AsyncEventInfo[T](future)

Expand All @@ -49,9 +50,9 @@ async def __aenter__(self) -> AsyncEventInfo[T]:

async def __aexit__(
self,
exc_type: Type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc_val:
self._event._cancel()
Expand Down
10 changes: 6 additions & 4 deletions playwright/_impl/_sync_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
import asyncio
import inspect
import traceback
from contextlib import AbstractContextManager
from types import TracebackType
from typing import (
Any,
Callable,
Coroutine,
Generator,
Generic,
Optional,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -64,7 +66,7 @@ def is_done(self) -> bool:
return self._future.done()


class EventContextManager(Generic[T]):
class EventContextManager(Generic[T], AbstractContextManager):
def __init__(self, sync_base: "SyncBase", future: "asyncio.Future[T]") -> None:
self._event = EventInfo[T](sync_base, future)

Expand All @@ -73,9 +75,9 @@ def __enter__(self) -> EventInfo[T]:

def __exit__(
self,
exc_type: Type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc_val:
self._event._cancel()
Expand Down

0 comments on commit 5a4779e

Please sign in to comment.