Skip to content

Commit

Permalink
fix bug where async context of recursive dependencies would be exited…
Browse files Browse the repository at this point in the history
… before the beginning of the function (#35)

* fix bug where async context of recursive dependencies would be exited before the beginning of the function

* bump version to 0.8.1
  • Loading branch information
ldruschk authored May 14, 2023
1 parent 7af0069 commit 282eaac
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
49 changes: 25 additions & 24 deletions enochecker3/enochecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import sys
import traceback
from contextlib import AsyncExitStack, asynccontextmanager
from contextlib import AsyncExitStack
from inspect import Parameter, isawaitable, signature
from types import TracebackType
from typing import (
Expand Down Expand Up @@ -101,8 +101,8 @@ async def __aexit__(

async def get(self, t: type, name: str = "") -> Any:
injector = self.checker.resolve_injector(name, t)
args = await self._exit_stack.enter_async_context(
self.checker._inject_dependencies(self.task, injector, None)
args = await self.checker._inject_dependencies(
self.task, injector, self._exit_stack
)
res = injector(*args)
if isawaitable(res):
Expand Down Expand Up @@ -243,13 +243,13 @@ def resolve_injector(self, name: str, t: type) -> Callable[..., Any]:
return self._dependency_injections[generic_key]
return self._dependency_injections[key]

@asynccontextmanager
async def _inject_dependencies(
self,
task: BaseCheckerTaskMessage,
f: Callable[..., Any],
stack: AsyncExitStack,
dependencies: Optional[Set[Callable[..., Any]]] = None,
) -> AsyncIterator[Any]:
) -> List[Any]:
dependencies = dependencies or set()

sig = signature(f)
Expand All @@ -271,23 +271,22 @@ async def _inject_dependencies(
f"Detected circular dependency in {f} with injected type {v.annotation}"
)
else:
async with self._inject_dependencies(
task, injector, dependencies.union([injector])
) as args_:
arg = injector(*args_)
if isawaitable(arg):
arg = await arg
args.append(arg)

async with AsyncExitStack() as stack:
# new_args contains the return values of __(a)enter__, which would be the "x" in "(async) with ... as x:"
new_args = []
for arg in args:
if not hasattr(arg, "__enter__") and not hasattr(arg, "__aenter__"):
new_args.append(arg)
continue
new_args.append(await stack.enter_async_context(arg))
yield new_args
args_ = await self._inject_dependencies(
task, injector, stack, dependencies.union([injector])
)
arg = injector(*args_)
if isawaitable(arg):
arg = await arg
args.append(arg)

# new_args contains the return values of __(a)enter__, which would be the "x" in "(async) with ... as x:"
new_args = []
for arg in args:
if not hasattr(arg, "__enter__") and not hasattr(arg, "__aenter__"):
new_args.append(arg)
continue
new_args.append(await stack.enter_async_context(arg))
return new_args

async def _call_method_raw(self, task: BaseCheckerTaskMessage) -> Optional[str]:
variant_id = task.variant_id
Expand All @@ -299,8 +298,10 @@ async def _call_method_raw(self, task: BaseCheckerTaskMessage) -> Optional[str]:
f"Variant_id {variant_id} not defined for method {method}"
)

async with self._inject_dependencies(task, f) as args:
return await f(*args)
async with AsyncExitStack() as stack:
args = await self._inject_dependencies(task, f, stack)
res = await f(*args)
return res

async def _call_method(self, task: BaseCheckerTaskMessage) -> Optional[str]:
try:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setuptools.setup(
name="enochecker3",
version="0.8.0",
version="0.8.1",
author="ldruschk",
author_email="[email protected]",
description="FastAPI based library for building async python checkers for the EnoEngine A/D CTF Framework",
Expand Down

0 comments on commit 282eaac

Please sign in to comment.