From 9de6993b7836dae6aebb82a4b93f6ede12205349 Mon Sep 17 00:00:00 2001 From: Max Schmitt Date: Fri, 22 Mar 2024 13:50:40 +0100 Subject: [PATCH] fix: inherit expect context managers from contextlib --- playwright/_impl/_async_base.py | 11 ++++++----- playwright/_impl/_sync_base.py | 10 ++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/playwright/_impl/_async_base.py b/playwright/_impl/_async_base.py index e497232d8..e9544b733 100644 --- a/playwright/_impl/_async_base.py +++ b/playwright/_impl/_async_base.py @@ -13,8 +13,9 @@ # limitations under the License. import asyncio +from contextlib import AbstractAsyncContextManager from types import TracebackType -from typing import Any, Callable, Generic, Type, TypeVar +from typing import Any, Callable, Generic, Optional, Type, TypeVar from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper @@ -40,7 +41,7 @@ def is_done(self) -> bool: return self._future.done() -class AsyncEventContextManager(Generic[T]): +class AsyncEventContextManager(Generic[T], AbstractAsyncContextManager): def __init__(self, future: "asyncio.Future[T]") -> None: self._event = AsyncEventInfo[T](future) @@ -49,9 +50,9 @@ async def __aenter__(self) -> AsyncEventInfo[T]: async def __aexit__( self, - exc_type: Type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: if exc_val: self._event._cancel() diff --git a/playwright/_impl/_sync_base.py b/playwright/_impl/_sync_base.py index 50a2c647e..f07b947b2 100644 --- a/playwright/_impl/_sync_base.py +++ b/playwright/_impl/_sync_base.py @@ -15,6 +15,7 @@ import asyncio import inspect import traceback +from contextlib import AbstractContextManager from types import TracebackType from typing import ( Any, @@ -22,6 +23,7 @@ Coroutine, Generator, Generic, + Optional, Type, TypeVar, Union, @@ -64,7 +66,7 @@ def is_done(self) -> bool: return self._future.done() -class EventContextManager(Generic[T]): +class EventContextManager(Generic[T], AbstractContextManager): def __init__(self, sync_base: "SyncBase", future: "asyncio.Future[T]") -> None: self._event = EventInfo[T](sync_base, future) @@ -73,9 +75,9 @@ def __enter__(self) -> EventInfo[T]: def __exit__( self, - exc_type: Type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: if exc_val: self._event._cancel()