diff --git a/CHANGELOG.md b/CHANGELOG.md index 70a26f233..ef48052e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.10.5] - 2023-09-01 + +- Add logic to retry network calls if the core returns status 429 + ## [0.10.4] - 2022-08-30 ## Features: - Add support for User ID Mapping using `create_user_id_mapping`, `get_user_id_mapping`, `delete_user_id_mapping`, `update_or_delete_user_id_mapping` functions @@ -20,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.10.2] - 2022-07-14 ### Bug fix -- Make `user_context` optional in userroles recipe syncio functions. +- Make `user_context` optional in userroles recipe syncio functions. ## [0.10.1] - 2022-07-11 @@ -555,4 +559,4 @@ init( - Middleware, error handlers and verify session for each framework. - Created a wrapper for async to sync for supporting older version of python web frameworks. - Base tests for each framework. -- New requirements in the setup file. +- New requirements in the setup file. diff --git a/addDevTag b/addDevTag index 1fd4f670e..2a0dd0df2 100755 --- a/addDevTag +++ b/addDevTag @@ -1,11 +1,5 @@ #!/bin/bash -# check if we need to merge master into this branch------------ -if [[ $(git log origin/master ^HEAD) ]]; then - echo "You need to merge master into this branch. Exiting" - exit 1 -fi - # get version------------ version=`cat setup.py | grep -e 'version='` while IFS='"' read -ra ADDR; do @@ -86,4 +80,4 @@ fi git tag dev-v$version $commit_hash -git push --tags \ No newline at end of file +git push --tags diff --git a/setup.py b/setup.py index 8472a3ac7..a81829a18 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.10.4", + version="0.10.5", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 965e6d1c2..99251cfd4 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. SUPPORTED_CDI_VERSIONS = ["2.9", "2.10", "2.11", "2.12", "2.13", "2.14", "2.15"] -VERSION = "0.10.4" +VERSION = "0.10.5" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" @@ -25,3 +25,4 @@ FDI_KEY_HEADER = "fdi-version" API_VERSION = "/apiversion" API_VERSION_HEADER = "cdi-version" +RATE_LIMIT_STATUS_CODE = 429 diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 0c8deff63..2653d7499 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -13,9 +13,11 @@ # under the License. from __future__ import annotations +import asyncio + from json import JSONDecodeError from os import environ -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from httpx import AsyncClient, ConnectTimeout, NetworkError, Response @@ -25,6 +27,7 @@ API_VERSION_HEADER, RID_KEY_HEADER, SUPPORTED_CDI_VERSIONS, + RATE_LIMIT_STATUS_CODE, ) from .normalised_url_path import NormalisedURLPath @@ -42,7 +45,7 @@ class Querier: __init_called = False __hosts: List[Host] = [] __api_key: Union[None, str] = None - __api_version = None + api_version = None __last_tried_index: int = 0 __hosts_alive_for_testing: Set[str] = set() @@ -69,8 +72,8 @@ def get_hosts_alive_for_testing(): return Querier.__hosts_alive_for_testing async def get_api_version(self): - if Querier.__api_version is not None: - return Querier.__api_version + if Querier.api_version is not None: + return Querier.api_version ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION @@ -96,8 +99,8 @@ async def f(url: str) -> Response: "to find the right versions" ) - Querier.__api_version = api_version - return Querier.__api_version + Querier.api_version = api_version + return Querier.api_version @staticmethod def get_instance(rid_to_core: Union[str, None] = None): @@ -113,7 +116,7 @@ def init(hosts: List[Host], api_key: Union[str, None] = None): Querier.__init_called = True Querier.__hosts = hosts Querier.__api_key = api_key - Querier.__api_version = None + Querier.api_version = None Querier.__last_tried_index = 0 Querier.__hosts_alive_for_testing = set() @@ -196,6 +199,7 @@ async def __send_request_helper( method: str, http_function: Callable[[str], Awaitable[Response]], no_of_tries: int, + retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: if no_of_tries == 0: raise_general_exception("No SuperTokens core available to query") @@ -212,6 +216,14 @@ async def __send_request_helper( Querier.__last_tried_index %= len(self.__hosts) url = current_host + path.get_as_string_dangerous() + max_retries = 5 + + if retry_info_map is None: + retry_info_map = {} + + if retry_info_map.get(url) is None: + retry_info_map[url] = max_retries + ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) @@ -221,6 +233,20 @@ async def __send_request_helper( ): Querier.__hosts_alive_for_testing.add(current_host) + if response.status_code == RATE_LIMIT_STATUS_CODE: + retries_left = retry_info_map[url] + + if retries_left > 0: + retry_info_map[url] = retries_left - 1 + + attempts_made = max_retries - retries_left + delay = (10 + attempts_made * 250) / 1000 + + await asyncio.sleep(delay) + return await self.__send_request_helper( + path, method, http_function, no_of_tries, retry_info_map + ) + if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore raise_general_exception( "SuperTokens core threw an error for a " @@ -238,9 +264,9 @@ async def __send_request_helper( except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout): + except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( - path, method, http_function, no_of_tries - 1 + path, method, http_function, no_of_tries - 1, retry_info_map ) except Exception as e: raise_general_exception(e) diff --git a/tests/test_querier.py b/tests/test_querier.py new file mode 100644 index 000000000..8c0a32848 --- /dev/null +++ b/tests/test_querier.py @@ -0,0 +1,144 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from pytest import mark +from supertokens_python.recipe import ( + session, + emailpassword, +) +import asyncio +import respx +import httpx +from supertokens_python import init, SupertokensConfig +from supertokens_python.querier import Querier, NormalisedURLPath + +from tests.utils import get_st_init_args +from tests.utils import ( + setup_function, + teardown_function, + start_st, +) + +_ = setup_function +_ = teardown_function + +pytestmark = mark.asyncio +respx_mock = respx.MockRouter + + +async def test_network_call_is_retried_as_expected(): + # Test that network call is retried properly + # Test that rate limiting errors are thrown back to the user + args = get_st_init_args( + [ + session.init(), + emailpassword.init(), + ] + ) + args["supertokens_config"] = SupertokensConfig("http://localhost:6789") + init(**args) # type: ignore + start_st() + + Querier.api_version = "3.0" + q = Querier.get_instance() + + api2_call_count = 0 + + def api2_side_effect(_: httpx.Request): + nonlocal api2_call_count + api2_call_count += 1 + + if api2_call_count == 3: + return httpx.Response(200) + + return httpx.Response(429, json={}) + + with respx_mock() as mocker: + api1 = mocker.get("http://localhost:6789/api1").mock( + httpx.Response(429, json={"status": "RATE_ERROR"}) + ) + api2 = mocker.get("http://localhost:6789/api2").mock( + side_effect=api2_side_effect + ) + api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200)) + + try: + await q.send_get_request(NormalisedURLPath("/api1"), {}) + except Exception as e: + if "with status code: 429" in str( + e + ) and 'message: {"status": "RATE_ERROR"}' in str(e): + pass + else: + raise e + + await q.send_get_request(NormalisedURLPath("/api2"), {}) + await q.send_get_request(NormalisedURLPath("/api3"), {}) + + # 1 initial request + 5 retries + assert api1.call_count == 6 + # 2 403 and 1 200 + assert api2.call_count == 3 + # 200 in the first attempt + assert api3.call_count == 1 + + +async def test_parallel_calls_have_independent_counters(): + args = get_st_init_args( + [ + session.init(), + emailpassword.init(), + ] + ) + init(**args) # type: ignore + start_st() + + Querier.api_version = "3.0" + q = Querier.get_instance() + + call_count1 = 0 + call_count2 = 0 + + def api_side_effect(r: httpx.Request): + nonlocal call_count1, call_count2 + + id_ = int(r.url.params.get("id")) + if id_ == 1: + call_count1 += 1 + elif id_ == 2: + call_count2 += 1 + + return httpx.Response(429, json={}) + + with respx_mock() as mocker: + api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect) + + async def call_api(id_: int): + try: + await q.send_get_request(NormalisedURLPath("/api"), {"id": id_}) + except Exception as e: + if "with status code: 429" in str(e): + pass + else: + raise e + + _ = await asyncio.gather( + call_api(1), + call_api(2), + ) + + # 1 initial request + 5 retries + assert call_count1 == 6 + assert call_count2 == 6 + + assert api.call_count == 12 diff --git a/tests/utils.py b/tests/utils.py index 815544570..7aebc1087 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,7 +24,7 @@ from requests.models import Response from yaml import FullLoader, dump, load -from supertokens_python import Supertokens +from supertokens_python import SupertokensConfig, Supertokens, InputAppInfo from supertokens_python.process_state import ProcessState from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailverification import EmailVerificationRecipe @@ -359,3 +359,31 @@ def email_verify_token_request( finally: if use_server: environ["SUPERTOKENS_ENV"] = "testing" + + +def setup_function(_: Any) -> None: + reset() + clean_st() + setup_st() + + +def teardown_function(_: Any) -> None: + reset() + clean_st() + + +st_init_common_args = { + "supertokens_config": SupertokensConfig("http://localhost:3567"), + "app_info": InputAppInfo( + app_name="ST", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + "framework": "fastapi", + "mode": "asgi", +} + + +def get_st_init_args(recipe_list: List[Any]) -> Dict[str, Any]: + return {**st_init_common_args, "recipe_list": recipe_list}