diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 2d60c50a..f9749221 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -5,7 +5,7 @@ parse_jwt_without_signature_verification, ) from supertokens_python.types import RecipeUserId -from utils import deserialize_validator +from utils import deserialize_validator, get_max_version from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session @@ -19,7 +19,18 @@ def create_new_session_without_request_response(): # type: ignore return jsonify({"status": "MISSING_DATA_ERROR"}) tenant_id = data.get("tenantId", "public") - user_id = data["userId"] + from supertokens_python import convert_to_recipe_user_id + + fdi_version = request.headers.get("fdi-version") + assert fdi_version is not None + if get_max_version("1.17", fdi_version) == "1.17" or ( + get_max_version("2.0", fdi_version) == fdi_version + and get_max_version("3.0", fdi_version) != fdi_version + ): + # fdi_version <= "1.17" or (fdi_version >= "2.0" and fdi_version < "3.0") + recipe_user_id = convert_to_recipe_user_id(data["userId"]) + else: + recipe_user_id = convert_to_recipe_user_id(data["recipeUserId"]) access_token_payload = data.get("accessTokenPayload", {}) session_data_in_database = data.get("sessionDataInDatabase", {}) disable_anti_csrf = data.get("disableAntiCsrf") @@ -27,7 +38,7 @@ def create_new_session_without_request_response(): # type: ignore session_container = session.create_new_session_without_request_response( tenant_id, - RecipeUserId(user_id), + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index 3ccc84a4..073d006f 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -39,3 +39,20 @@ def toSnakeCase(camel_case: str) -> str: else: result += char return result + + +def get_max_version(v1: str, v2: str) -> str: + v1_split = v1.split(".") + v2_split = v2.split(".") + max_loop = min(len(v1_split), len(v2_split)) + + for i in range(max_loop): + if int(v1_split[i]) > int(v2_split[i]): + return v1 + if int(v2_split[i]) > int(v1_split[i]): + return v2 + + if len(v1_split) > len(v2_split): + return v1 + + return v2