diff --git a/enochecker3/enochecker.py b/enochecker3/enochecker.py index 3eaec4f..a42ad14 100644 --- a/enochecker3/enochecker.py +++ b/enochecker3/enochecker.py @@ -4,7 +4,7 @@ import os import sys import traceback -from contextlib import AsyncExitStack, asynccontextmanager +from contextlib import AsyncExitStack from inspect import Parameter, isawaitable, signature from types import TracebackType from typing import ( @@ -101,8 +101,8 @@ async def __aexit__( async def get(self, t: type, name: str = "") -> Any: injector = self.checker.resolve_injector(name, t) - args = await self._exit_stack.enter_async_context( - self.checker._inject_dependencies(self.task, injector, None) + args = await self.checker._inject_dependencies( + self.task, injector, self._exit_stack ) res = injector(*args) if isawaitable(res): @@ -243,13 +243,13 @@ def resolve_injector(self, name: str, t: type) -> Callable[..., Any]: return self._dependency_injections[generic_key] return self._dependency_injections[key] - @asynccontextmanager async def _inject_dependencies( self, task: BaseCheckerTaskMessage, f: Callable[..., Any], + stack: AsyncExitStack, dependencies: Optional[Set[Callable[..., Any]]] = None, - ) -> AsyncIterator[Any]: + ) -> List[Any]: dependencies = dependencies or set() sig = signature(f) @@ -271,23 +271,22 @@ async def _inject_dependencies( f"Detected circular dependency in {f} with injected type {v.annotation}" ) else: - async with self._inject_dependencies( - task, injector, dependencies.union([injector]) - ) as args_: - arg = injector(*args_) - if isawaitable(arg): - arg = await arg - args.append(arg) - - async with AsyncExitStack() as stack: - # new_args contains the return values of __(a)enter__, which would be the "x" in "(async) with ... as x:" - new_args = [] - for arg in args: - if not hasattr(arg, "__enter__") and not hasattr(arg, "__aenter__"): - new_args.append(arg) - continue - new_args.append(await stack.enter_async_context(arg)) - yield new_args + args_ = await self._inject_dependencies( + task, injector, stack, dependencies.union([injector]) + ) + arg = injector(*args_) + if isawaitable(arg): + arg = await arg + args.append(arg) + + # new_args contains the return values of __(a)enter__, which would be the "x" in "(async) with ... as x:" + new_args = [] + for arg in args: + if not hasattr(arg, "__enter__") and not hasattr(arg, "__aenter__"): + new_args.append(arg) + continue + new_args.append(await stack.enter_async_context(arg)) + return new_args async def _call_method_raw(self, task: BaseCheckerTaskMessage) -> Optional[str]: variant_id = task.variant_id @@ -299,8 +298,10 @@ async def _call_method_raw(self, task: BaseCheckerTaskMessage) -> Optional[str]: f"Variant_id {variant_id} not defined for method {method}" ) - async with self._inject_dependencies(task, f) as args: - return await f(*args) + async with AsyncExitStack() as stack: + args = await self._inject_dependencies(task, f, stack) + res = await f(*args) + return res async def _call_method(self, task: BaseCheckerTaskMessage) -> Optional[str]: try: diff --git a/setup.py b/setup.py index 538d7ed..4d075d0 100755 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setuptools.setup( name="enochecker3", - version="0.8.0", + version="0.8.1", author="ldruschk", author_email="ldruschk@posteo.de", description="FastAPI based library for building async python checkers for the EnoEngine A/D CTF Framework",