Skip to content

Commit

Permalink
Merge pull request #440 from supertokens/fix/async-lib-not-found-err
Browse files Browse the repository at this point in the history
fix: Async lib not found error
  • Loading branch information
rishabhpoddar authored Sep 6, 2023
2 parents dbca4da + d76e8fb commit 7c74be5
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 35 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

- Retry Querier request on `AsyncLibraryNotFoundError`

## [0.14.10] - 2023-09-31

- Uses nest_asyncio patch in event loop - sync to async
- Uses `nest_asyncio` patch in event loop - sync to async

## [0.14.9] - 2023-09-28

Expand Down
11 changes: 6 additions & 5 deletions supertokens_python/async_to_sync_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@
_T = TypeVar("_T")


def check_event_loop():
def create_or_get_event_loop() -> asyncio.AbstractEventLoop:
try:
asyncio.get_event_loop()
except RuntimeError as ex:
return asyncio.get_event_loop()
except Exception as ex:
if "There is no current event loop in thread" in str(ex):
loop = asyncio.new_event_loop()
nest_asyncio.apply(loop) # type: ignore
asyncio.set_event_loop(loop)
return loop
raise ex


def sync(co: Coroutine[Any, Any, _T]) -> _T:
check_event_loop()
loop = asyncio.get_event_loop()
loop = create_or_get_event_loop()
return loop.run_until_complete(co)
87 changes: 61 additions & 26 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from .exceptions import raise_general_exception
from .process_state import AllowedProcessStates, ProcessState
from .utils import find_max_version, is_4xx_error, is_5xx_error
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
from sniffio import AsyncLibraryNotFoundError


class Querier:
Expand Down Expand Up @@ -71,6 +73,35 @@ def get_hosts_alive_for_testing():
raise_general_exception("calling testing function in non testing env")
return Querier.__hosts_alive_for_testing

async def api_request(
self,
url: str,
method: str,
attempts_remaining: int,
*args: Any,
**kwargs: Any,
) -> Response:
if attempts_remaining == 0:
raise_general_exception("Retry request failed")

try:
async with AsyncClient() as client:
if method == "GET":
return await client.get(url, *args, **kwargs) # type: ignore
if method == "POST":
return await client.post(url, *args, **kwargs) # type: ignore
if method == "PUT":
return await client.put(url, *args, **kwargs) # type: ignore
if method == "DELETE":
return await client.delete(url, *args, **kwargs) # type: ignore
raise Exception("Shouldn't come here")
except AsyncLibraryNotFoundError:
# Retry
loop = create_or_get_event_loop()
return loop.run_until_complete(
self.api_request(url, method, attempts_remaining - 1, *args, **kwargs)
)

async def get_api_version(self):
if Querier.api_version is not None:
return Querier.api_version
Expand All @@ -79,12 +110,11 @@ async def get_api_version(self):
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
)

async def f(url: str) -> Response:
async def f(url: str, method: str) -> Response:
headers = {}
if Querier.__api_key is not None:
headers = {API_KEY_HEADER: Querier.__api_key}
async with AsyncClient() as client:
return await client.get(url, headers=headers) # type:ignore
return await self.api_request(url, method, 2, headers=headers)

response = await self.__send_request_helper(
NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts)
Expand Down Expand Up @@ -134,13 +164,14 @@ async def send_get_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.get( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "GET", f, len(self.__hosts))

Expand All @@ -163,9 +194,14 @@ async def send_post_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.post(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
json=data,
)

return await self.__send_request_helper(path, "POST", f, len(self.__hosts))

Expand All @@ -175,13 +211,14 @@ async def send_delete_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.delete( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts))

Expand All @@ -194,9 +231,8 @@ async def send_put_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.put(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(url, method, 2, headers=headers, json=data)

return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))

Expand All @@ -223,7 +259,7 @@ async def __send_request_helper(
self,
path: NormalisedURLPath,
method: str,
http_function: Callable[[str], Awaitable[Response]],
http_function: Callable[[str, str], Awaitable[Response]],
no_of_tries: int,
retry_info_map: Optional[Dict[str, int]] = None,
) -> Any:
Expand Down Expand Up @@ -253,7 +289,7 @@ async def __send_request_helper(
ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
)
response = await http_function(url)
response = await http_function(url, method)
if ("SUPERTOKENS_ENV" in environ) and (
environ["SUPERTOKENS_ENV"] == "testing"
):
Expand Down Expand Up @@ -289,7 +325,6 @@ async def __send_request_helper(
return response.json()
except JSONDecodeError:
return response.text

except (ConnectionError, NetworkError, ConnectTimeout) as _:
return await self.__send_request_helper(
path, method, http_function, no_of_tries - 1, retry_info_map
Expand Down
5 changes: 2 additions & 3 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from httpx import HTTPStatusError, Response
from tldextract import extract # type: ignore

from supertokens_python.async_to_sync_wrapper import check_event_loop
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
from supertokens_python.framework.django.framework import DjangoFramework
from supertokens_python.framework.fastapi.framework import FastapiFramework
from supertokens_python.framework.flask.framework import FlaskFramework
Expand Down Expand Up @@ -212,8 +212,7 @@ def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]):
if real_mode == "wsgi":
asyncio.run(func())
else:
check_event_loop()
loop = asyncio.get_event_loop()
loop = create_or_get_event_loop()
loop.create_task(func())


Expand Down

0 comments on commit 7c74be5

Please sign in to comment.