Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add 429 rate limting from SaaS for v0.10 #436

Merged
merged 5 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.10.5] - 2023-09-01

- Add logic to retry network calls if the core returns status 429

## [0.10.4] - 2022-08-30
## Features:
- Add support for User ID Mapping using `create_user_id_mapping`, `get_user_id_mapping`, `delete_user_id_mapping`, `update_or_delete_user_id_mapping` functions
Expand All @@ -20,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.10.2] - 2022-07-14
### Bug fix
- Make `user_context` optional in userroles recipe syncio functions.
- Make `user_context` optional in userroles recipe syncio functions.

## [0.10.1] - 2022-07-11

Expand Down Expand Up @@ -555,4 +559,4 @@ init(
- Middleware, error handlers and verify session for each framework.
- Created a wrapper for async to sync for supporting older version of python web frameworks.
- Base tests for each framework.
- New requirements in the setup file.
- New requirements in the setup file.
8 changes: 1 addition & 7 deletions addDevTag
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
#!/bin/bash

# check if we need to merge master into this branch------------
if [[ $(git log origin/master ^HEAD) ]]; then
echo "You need to merge master into this branch. Exiting"
exit 1
fi

# get version------------
version=`cat setup.py | grep -e 'version='`
while IFS='"' read -ra ADDR; do
Expand Down Expand Up @@ -86,4 +80,4 @@ fi

git tag dev-v$version $commit_hash

git push --tags
git push --tags
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.10.4",
version="0.10.5",
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
3 changes: 2 additions & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
SUPPORTED_CDI_VERSIONS = ["2.9", "2.10", "2.11", "2.12", "2.13", "2.14", "2.15"]
VERSION = "0.10.4"
VERSION = "0.10.5"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand All @@ -25,3 +25,4 @@
FDI_KEY_HEADER = "fdi-version"
API_VERSION = "/apiversion"
API_VERSION_HEADER = "cdi-version"
RATE_LIMIT_STATUS_CODE = 429
44 changes: 35 additions & 9 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# under the License.
from __future__ import annotations

import asyncio

from json import JSONDecodeError
from os import environ
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional

from httpx import AsyncClient, ConnectTimeout, NetworkError, Response

Expand All @@ -25,6 +27,7 @@
API_VERSION_HEADER,
RID_KEY_HEADER,
SUPPORTED_CDI_VERSIONS,
RATE_LIMIT_STATUS_CODE,
)
from .normalised_url_path import NormalisedURLPath

Expand All @@ -42,7 +45,7 @@ class Querier:
__init_called = False
__hosts: List[Host] = []
__api_key: Union[None, str] = None
__api_version = None
api_version = None
__last_tried_index: int = 0
__hosts_alive_for_testing: Set[str] = set()

Expand All @@ -69,8 +72,8 @@ def get_hosts_alive_for_testing():
return Querier.__hosts_alive_for_testing

async def get_api_version(self):
if Querier.__api_version is not None:
return Querier.__api_version
if Querier.api_version is not None:
return Querier.api_version

ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
Expand All @@ -96,8 +99,8 @@ async def f(url: str) -> Response:
"to find the right versions"
)

Querier.__api_version = api_version
return Querier.__api_version
Querier.api_version = api_version
return Querier.api_version

@staticmethod
def get_instance(rid_to_core: Union[str, None] = None):
Expand All @@ -113,7 +116,7 @@ def init(hosts: List[Host], api_key: Union[str, None] = None):
Querier.__init_called = True
Querier.__hosts = hosts
Querier.__api_key = api_key
Querier.__api_version = None
Querier.api_version = None
Querier.__last_tried_index = 0
Querier.__hosts_alive_for_testing = set()

Expand Down Expand Up @@ -196,6 +199,7 @@ async def __send_request_helper(
method: str,
http_function: Callable[[str], Awaitable[Response]],
no_of_tries: int,
retry_info_map: Optional[Dict[str, int]] = None,
) -> Any:
if no_of_tries == 0:
raise_general_exception("No SuperTokens core available to query")
Expand All @@ -212,6 +216,14 @@ async def __send_request_helper(
Querier.__last_tried_index %= len(self.__hosts)
url = current_host + path.get_as_string_dangerous()

max_retries = 5

if retry_info_map is None:
retry_info_map = {}

if retry_info_map.get(url) is None:
retry_info_map[url] = max_retries

ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
)
Expand All @@ -221,6 +233,20 @@ async def __send_request_helper(
):
Querier.__hosts_alive_for_testing.add(current_host)

if response.status_code == RATE_LIMIT_STATUS_CODE:
retries_left = retry_info_map[url]

if retries_left > 0:
retry_info_map[url] = retries_left - 1

attempts_made = max_retries - retries_left
delay = (10 + attempts_made * 250) / 1000

await asyncio.sleep(delay)
return await self.__send_request_helper(
path, method, http_function, no_of_tries, retry_info_map
)

if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
raise_general_exception(
"SuperTokens core threw an error for a "
Expand All @@ -238,9 +264,9 @@ async def __send_request_helper(
except JSONDecodeError:
return response.text

except (ConnectionError, NetworkError, ConnectTimeout):
except (ConnectionError, NetworkError, ConnectTimeout) as _:
return await self.__send_request_helper(
path, method, http_function, no_of_tries - 1
path, method, http_function, no_of_tries - 1, retry_info_map
)
except Exception as e:
raise_general_exception(e)
144 changes: 144 additions & 0 deletions tests/test_querier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from pytest import mark
from supertokens_python.recipe import (
session,
emailpassword,
)
import asyncio
import respx
import httpx
from supertokens_python import init, SupertokensConfig
from supertokens_python.querier import Querier, NormalisedURLPath

from tests.utils import get_st_init_args
from tests.utils import (
setup_function,
teardown_function,
start_st,
)

_ = setup_function
_ = teardown_function

pytestmark = mark.asyncio
respx_mock = respx.MockRouter


async def test_network_call_is_retried_as_expected():
# Test that network call is retried properly
# Test that rate limiting errors are thrown back to the user
args = get_st_init_args(
[
session.init(),
emailpassword.init(),
]
)
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
init(**args) # type: ignore
start_st()

Querier.api_version = "3.0"
q = Querier.get_instance()

api2_call_count = 0

def api2_side_effect(_: httpx.Request):
nonlocal api2_call_count
api2_call_count += 1

if api2_call_count == 3:
return httpx.Response(200)

return httpx.Response(429, json={})

with respx_mock() as mocker:
api1 = mocker.get("http://localhost:6789/api1").mock(
httpx.Response(429, json={"status": "RATE_ERROR"})
)
api2 = mocker.get("http://localhost:6789/api2").mock(
side_effect=api2_side_effect
)
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))

try:
await q.send_get_request(NormalisedURLPath("/api1"), {})
except Exception as e:
if "with status code: 429" in str(
e
) and 'message: {"status": "RATE_ERROR"}' in str(e):
pass
else:
raise e

await q.send_get_request(NormalisedURLPath("/api2"), {})
await q.send_get_request(NormalisedURLPath("/api3"), {})

# 1 initial request + 5 retries
assert api1.call_count == 6
# 2 403 and 1 200
assert api2.call_count == 3
# 200 in the first attempt
assert api3.call_count == 1


async def test_parallel_calls_have_independent_counters():
args = get_st_init_args(
[
session.init(),
emailpassword.init(),
]
)
init(**args) # type: ignore
start_st()

Querier.api_version = "3.0"
q = Querier.get_instance()

call_count1 = 0
call_count2 = 0

def api_side_effect(r: httpx.Request):
nonlocal call_count1, call_count2

id_ = int(r.url.params.get("id"))
if id_ == 1:
call_count1 += 1
elif id_ == 2:
call_count2 += 1

return httpx.Response(429, json={})

with respx_mock() as mocker:
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)

async def call_api(id_: int):
try:
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
except Exception as e:
if "with status code: 429" in str(e):
pass
else:
raise e

_ = await asyncio.gather(
call_api(1),
call_api(2),
)

# 1 initial request + 5 retries
assert call_count1 == 6
assert call_count2 == 6

assert api.call_count == 12
30 changes: 29 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from requests.models import Response
from yaml import FullLoader, dump, load

from supertokens_python import Supertokens
from supertokens_python import SupertokensConfig, Supertokens, InputAppInfo
from supertokens_python.process_state import ProcessState
from supertokens_python.recipe.emailpassword import EmailPasswordRecipe
from supertokens_python.recipe.emailverification import EmailVerificationRecipe
Expand Down Expand Up @@ -359,3 +359,31 @@ def email_verify_token_request(
finally:
if use_server:
environ["SUPERTOKENS_ENV"] = "testing"


def setup_function(_: Any) -> None:
reset()
clean_st()
setup_st()


def teardown_function(_: Any) -> None:
reset()
clean_st()


st_init_common_args = {
"supertokens_config": SupertokensConfig("http://localhost:3567"),
"app_info": InputAppInfo(
app_name="ST",
api_domain="http://api.supertokens.io",
website_domain="http://supertokens.io",
api_base_path="/auth",
),
"framework": "fastapi",
"mode": "asgi",
}


def get_st_init_args(recipe_list: List[Any]) -> Dict[str, Any]:
return {**st_init_common_args, "recipe_list": recipe_list}
Loading