diff --git a/ruff.toml b/ruff.toml index 3b563fd09..7dba2fc3f 100644 --- a/ruff.toml +++ b/ruff.toml @@ -3,6 +3,10 @@ line-length = 120 target-version = "py310" [lint] -select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA"] +select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN"] -ignore = ["E501"] +# ANN101 and ANN102 are depracated +ignore = ["E501", "ANN1"] + +[lint.flake8-annotations] +allow-star-arg-any = true diff --git a/src/eduid/common/clients/amapi_client/amapi_client.py b/src/eduid/common/clients/amapi_client/amapi_client.py index bcf6a1b6f..e6d93db71 100644 --- a/src/eduid/common/clients/amapi_client/amapi_client.py +++ b/src/eduid/common/clients/amapi_client/amapi_client.py @@ -9,6 +9,7 @@ __author__ = "masv" from eduid.common.models.amapi_user import ( + UserBaseRequest, UserUpdateEmailRequest, UserUpdateLanguageRequest, UserUpdateMetaCleanedRequest, @@ -20,14 +21,14 @@ class AMAPIClient(GNAPClient): - def __init__(self, amapi_url: str, auth_data=GNAPClientAuthData, verify_tls: bool = True, **kwargs): + def __init__(self, amapi_url: str, auth_data: GNAPClientAuthData, verify_tls: bool = True, **kwargs: Any) -> None: super().__init__(auth_data=auth_data, verify=verify_tls, **kwargs) self.amapi_url = amapi_url def _users_base_url(self) -> str: return urlappend(self.amapi_url, "users") - def _put(self, base_path: str, user: str, endpoint: str, body: Any) -> httpx.Response: + def _put(self, base_path: str, user: str, endpoint: str, body: UserBaseRequest) -> httpx.Response: return self.put(url=urlappend(base_path, f"{user}/{endpoint}"), content=body.json()) def update_user_email(self, user: str, body: UserUpdateEmailRequest) -> UserUpdateResponse: diff --git a/src/eduid/common/clients/amapi_client/testing.py b/src/eduid/common/clients/amapi_client/testing.py index d6b911dad..2ea93dd52 100644 --- a/src/eduid/common/clients/amapi_client/testing.py +++ b/src/eduid/common/clients/amapi_client/testing.py @@ -4,7 +4,7 @@ class MockedAMAPIMixin(MockedSyncAuthAPIMixin): - def start_mock_amapi(self, access_token_value: str | None = None): + def start_mock_amapi(self, access_token_value: str | None = None) -> None: self.start_mock_auth_api(access_token_value=access_token_value) self.mocked_users = respx.mock(base_url="http://localhost", assert_all_called=False) diff --git a/src/eduid/common/clients/gnap_client/async_client.py b/src/eduid/common/clients/gnap_client/async_client.py index 24ba8c425..aff57c0c5 100644 --- a/src/eduid/common/clients/gnap_client/async_client.py +++ b/src/eduid/common/clients/gnap_client/async_client.py @@ -1,4 +1,5 @@ import logging +from typing import Any import httpx @@ -11,7 +12,7 @@ class AsyncGNAPClient(httpx.AsyncClient, GNAPBearerTokenMixin): - def __init__(self, auth_data: GNAPClientAuthData, **kwargs): + def __init__(self, auth_data: GNAPClientAuthData, **kwargs: Any) -> None: if "event_hooks" not in kwargs: kwargs["event_hooks"] = {"response": [self.raise_on_4xx_5xx], "request": [self._add_authz_header]} diff --git a/src/eduid/common/clients/gnap_client/sync_client.py b/src/eduid/common/clients/gnap_client/sync_client.py index 79ff1cc1a..dbb7b13a5 100644 --- a/src/eduid/common/clients/gnap_client/sync_client.py +++ b/src/eduid/common/clients/gnap_client/sync_client.py @@ -1,4 +1,5 @@ import logging +from typing import Any import httpx @@ -11,7 +12,7 @@ class GNAPClient(httpx.Client, GNAPBearerTokenMixin): - def __init__(self, auth_data: GNAPClientAuthData, **kwargs): + def __init__(self, auth_data: GNAPClientAuthData, **kwargs: Any) -> None: if "event_hooks" not in kwargs: kwargs["event_hooks"] = {"response": [self.raise_on_4xx_5xx], "request": [self._add_authz_header]} diff --git a/src/eduid/common/clients/gnap_client/testing.py b/src/eduid/common/clients/gnap_client/testing.py index 62cf913ce..ae5e7fc99 100644 --- a/src/eduid/common/clients/gnap_client/testing.py +++ b/src/eduid/common/clients/gnap_client/testing.py @@ -7,7 +7,7 @@ class MockedSyncAuthAPIMixin: - def start_mock_auth_api(self, access_token_value: str | None = None): + def start_mock_auth_api(self, access_token_value: str | None = None) -> None: if access_token_value is None: access_token_value = "mock_jwt" self.mocked_auth_api = respx.mock(base_url="http://localhost/auth", assert_all_called=False) diff --git a/src/eduid/common/clients/scim_client/scim_client.py b/src/eduid/common/clients/scim_client/scim_client.py index d66eaa2a7..0a25e450c 100644 --- a/src/eduid/common/clients/scim_client/scim_client.py +++ b/src/eduid/common/clients/scim_client/scim_client.py @@ -1,4 +1,5 @@ import logging +from typing import Any from uuid import UUID import httpx @@ -20,7 +21,7 @@ class SCIMError(Exception): class SCIMClient(GNAPClient): - def __init__(self, scim_api_url: str, auth_data: GNAPClientAuthData, **kwargs): + def __init__(self, scim_api_url: str, auth_data: GNAPClientAuthData, **kwargs: Any) -> None: super().__init__(auth_data=auth_data, **kwargs) self.event_hooks["request"].append(self._add_accept_header) self.scim_api_url = scim_api_url diff --git a/src/eduid/common/clients/scim_client/testing.py b/src/eduid/common/clients/scim_client/testing.py index 9a0af7155..5c8ff6e93 100644 --- a/src/eduid/common/clients/scim_client/testing.py +++ b/src/eduid/common/clients/scim_client/testing.py @@ -79,7 +79,7 @@ class MockedScimAPIMixin(MockedSyncAuthAPIMixin): put_user_response = post_user_response - def start_mocked_scim_api(self): + def start_mocked_scim_api(self) -> None: self.start_mock_auth_api() self.mocked_scim_api = respx.mock(base_url="http://localhost/scim", assert_all_called=False) diff --git a/src/eduid/common/config/exceptions.py b/src/eduid/common/config/exceptions.py index 5c7d7d097..dd6fe74ef 100644 --- a/src/eduid/common/config/exceptions.py +++ b/src/eduid/common/config/exceptions.py @@ -1,7 +1,7 @@ class BadConfiguration(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: Exception.__init__(self) self.value = message - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/src/eduid/common/config/parsers/decorators.py b/src/eduid/common/config/parsers/decorators.py index ce686111f..1aa99e85f 100644 --- a/src/eduid/common/config/parsers/decorators.py +++ b/src/eduid/common/config/parsers/decorators.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Mapping +from collections.abc import Callable, Mapping from functools import wraps from string import Template from typing import Any @@ -11,9 +11,9 @@ from eduid.common.config.parsers.exceptions import SecretKeyException -def decrypt(f): +def decrypt(f: Callable) -> Callable: @wraps(f) - def decrypt_decorator(*args, **kwargs): + def decrypt_decorator(*args: Any, **kwargs: Any) -> Mapping[str, Any]: config_dict = f(*args, **kwargs) decrypted_config_dict = decrypt_config(config_dict) return decrypted_config_dict @@ -83,9 +83,9 @@ def decrypt_config(config_dict: Mapping[str, Any]) -> Mapping[str, Any]: return new_config_dict -def interpolate(f): +def interpolate(f: Callable) -> Callable: @wraps(f) - def interpolation_decorator(*args, **kwargs): + def interpolation_decorator(*args: Any, **kwargs: Any) -> dict[str, Any]: config_dict = f(*args, **kwargs) interpolated_config_dict = interpolate_config(config_dict) for key in list(interpolated_config_dict.keys()): diff --git a/src/eduid/common/config/parsers/exceptions.py b/src/eduid/common/config/parsers/exceptions.py index eaecc6f93..a4f2c4e7a 100644 --- a/src/eduid/common/config/parsers/exceptions.py +++ b/src/eduid/common/config/parsers/exceptions.py @@ -2,11 +2,11 @@ class ParserException(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: Exception.__init__(self) self.value = message - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/src/eduid/common/config/parsers/yaml_parser.py b/src/eduid/common/config/parsers/yaml_parser.py index de2a4768d..32c591319 100644 --- a/src/eduid/common/config/parsers/yaml_parser.py +++ b/src/eduid/common/config/parsers/yaml_parser.py @@ -7,7 +7,7 @@ class YamlConfigParser(BaseConfigParser): - def __init__(self, path: Path): + def __init__(self, path: Path) -> None: self.path = path @interpolate diff --git a/src/eduid/common/config/tests/test_config_parser.py b/src/eduid/common/config/tests/test_config_parser.py index ae52ecb98..e943dcf6a 100644 --- a/src/eduid/common/config/tests/test_config_parser.py +++ b/src/eduid/common/config/tests/test_config_parser.py @@ -8,10 +8,10 @@ class TestInitConfig(unittest.TestCase): - def tearDown(self): + def tearDown(self) -> None: os.environ.clear() - def test_YamlConfigParser(self): + def test_YamlConfigParser(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/test/ns/" os.environ["EDUID_CONFIG_YAML"] = "/config.yaml" parser = _choose_parser() diff --git a/src/eduid/common/config/tests/test_yaml_parser.py b/src/eduid/common/config/tests/test_yaml_parser.py index 7df9545bb..b6bd4297b 100644 --- a/src/eduid/common/config/tests/test_yaml_parser.py +++ b/src/eduid/common/config/tests/test_yaml_parser.py @@ -25,10 +25,10 @@ class TestInitConfig(unittest.TestCase): def setUp(self) -> None: self.data_dir = PurePath(__file__).with_name("data") - def tearDown(self): + def tearDown(self) -> None: os.environ.clear() - def test_YamlConfig(self): + def test_YamlConfig(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/app_one" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -49,7 +49,7 @@ def test_YamlConfig(self): assert config_two.number == 10 assert config_two.only_default == 19 - def test_YamlConfig_interpolation(self): + def test_YamlConfig_interpolation(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_interpolation" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -58,7 +58,7 @@ def test_YamlConfig_interpolation(self): assert config.number == 3 assert config.foo == "hi world" - def test_YamlConfig_missing_value(self): + def test_YamlConfig_missing_value(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_missing_value" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -75,7 +75,7 @@ def test_YamlConfig_missing_value(self): } ], f"Wrong error message: {exc_info.value.errors()}" - def test_YamlConfig_wrong_type(self): + def test_YamlConfig_wrong_type(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_wrong_type" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -92,7 +92,7 @@ def test_YamlConfig_wrong_type(self): } ], f"Wrong error message: {exc_info.value.errors()}" - def test_YamlConfig_unknown_data(self): + def test_YamlConfig_unknown_data(self) -> None: """Unknown data should not be rejected because it is an operational nightmare""" os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_unknown_data" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" @@ -102,7 +102,7 @@ def test_YamlConfig_unknown_data(self): assert config.number == 0xFF assert config.foo == "bar" - def test_YamlConfig_mixed_case_keys(self): + def test_YamlConfig_mixed_case_keys(self) -> None: """For legacy reasons, all keys should be lowercased""" os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_mixed_case_keys" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" diff --git a/src/eduid/common/decorators.py b/src/eduid/common/decorators.py index 2c0e62e66..a2c135bce 100644 --- a/src/eduid/common/decorators.py +++ b/src/eduid/common/decorators.py @@ -1,10 +1,12 @@ import inspect import warnings +from collections.abc import Callable from functools import wraps +from typing import Any # https://stackoverflow.com/questions/2536307/how-do-i-deprecate-python-functions/40301488#40301488 -def deprecated(reason): +def deprecated(reason: str | type | Callable) -> Callable: """ This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted @@ -20,14 +22,14 @@ def deprecated(reason): # def old_function(x, y): # pass - def decorator(func1): + def decorator(func1: Callable) -> Callable: if inspect.isclass(func1): fmt1 = "Call to deprecated class {name} ({reason})." else: fmt1 = "Call to deprecated function {name} ({reason})." @wraps(func1) - def new_func1(*args, **kwargs): + def new_func1(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 warnings.simplefilter("always", DeprecationWarning) warnings.warn( fmt1.format(name=func1.__name__, reason=reason), category=DeprecationWarning, stacklevel=2 @@ -56,7 +58,7 @@ def new_func1(*args, **kwargs): fmt2 = "Call to deprecated function {name}." @wraps(func2) - def new_func2(*args, **kwargs): + def new_func2(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 warnings.simplefilter("always", DeprecationWarning) warnings.warn(fmt2.format(name=func2.__name__), category=DeprecationWarning, stacklevel=2) warnings.simplefilter("default", DeprecationWarning) diff --git a/src/eduid/common/fastapi/context_request.py b/src/eduid/common/fastapi/context_request.py index c4953b007..ff32ddb93 100644 --- a/src/eduid/common/fastapi/context_request.py +++ b/src/eduid/common/fastapi/context_request.py @@ -1,5 +1,6 @@ from collections.abc import Callable from dataclasses import asdict, dataclass +from typing import Any from fastapi import Request, Response from fastapi.routing import APIRoute @@ -7,17 +8,17 @@ @dataclass class Context: - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return asdict(self) class ContextRequest(Request): - def __init__(self, context_class: type[Context], *args, **kwargs): + def __init__(self, context_class: type[Context], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.contextClass = context_class @property - def context(self): + def context(self) -> Context: try: return self.state.context except AttributeError: @@ -26,7 +27,7 @@ def context(self): return self.context @context.setter - def context(self, context: Context): + def context(self, context: Context) -> None: self.state.context = context diff --git a/src/eduid/common/fastapi/exceptions.py b/src/eduid/common/fastapi/exceptions.py index e150698a9..ddfac88d1 100644 --- a/src/eduid/common/fastapi/exceptions.py +++ b/src/eduid/common/fastapi/exceptions.py @@ -2,6 +2,7 @@ import logging import uuid +from typing import Any from fastapi import Request, status from fastapi.exception_handlers import http_exception_handler @@ -56,7 +57,7 @@ def __init__( self, status_code: int, detail: str | None = None, - ): + ) -> None: self._error_detail = ErrorDetail(detail=detail, status=status_code) self._extra_headers: dict | None = None @@ -69,33 +70,33 @@ def extra_headers(self) -> dict | None: return self._extra_headers @extra_headers.setter - def extra_headers(self, headers: dict): + def extra_headers(self, headers: dict) -> None: self._extra_headers = headers class BadRequest(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_400_BAD_REQUEST, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Bad Request" class Unauthorized(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Unauthorized request" class NotFound(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_404_NOT_FOUND, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Resource not found" class MethodNotAllowedMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, **kwargs) if not self.error_detail.detail: allowed_methods = kwargs.get("allowed_methods") @@ -103,7 +104,7 @@ def __init__(self, **kwargs): class UnsupportedMediaTypeMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Request was made with an unsupported media type" diff --git a/src/eduid/common/fastapi/utils.py b/src/eduid/common/fastapi/utils.py index 820878ab8..659291e47 100644 --- a/src/eduid/common/fastapi/utils.py +++ b/src/eduid/common/fastapi/utils.py @@ -26,7 +26,7 @@ class FailCountItem: exit_at: datetime | None = None count: int = 0 - def __str__(self): + def __str__(self) -> str: return f"(first_failure: {self.first_failure.isoformat()}, fail count: {self.count})" @@ -48,7 +48,7 @@ def reset_failure_info(req: ContextRequest, key: str) -> None: req.app.context.logger.info(f"Check {key} back to normal. Resetting info {info}") -def check_restart(key, restart: int, terminate: int) -> bool: +def check_restart(key: str, restart: int, terminate: int) -> bool: res = False # default to no restart info = FAILURE_INFO.get(key) if not info: diff --git a/src/eduid/common/logging.py b/src/eduid/common/logging.py index 84948ac5b..923f46ef2 100644 --- a/src/eduid/common/logging.py +++ b/src/eduid/common/logging.py @@ -26,11 +26,11 @@ # Default to RFC3339/ISO 8601 with tz class EduidFormatter(logging.Formatter): - def __init__(self, relative_time: bool = False, fmt=None): + def __init__(self, relative_time: bool = False, fmt: str | None = None) -> None: super().__init__(fmt=fmt, style="{") self._relative_time = relative_time - def formatTime(self, record: logging.LogRecord, datefmt=None) -> str: + def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str: if self._relative_time: # Relative time makes much more sense than absolute time when running tests for example _seconds = record.relativeCreated / 1000 @@ -52,7 +52,7 @@ def formatTime(self, record: logging.LogRecord, datefmt=None) -> str: class AppFilter(logging.Filter): """Add `system_hostname`, `hostname` and `app_name` to records being logged.""" - def __init__(self, app_name): + def __init__(self, app_name: str) -> None: super().__init__() self.app_name = app_name # TODO: I guess it could be argued that these should be put in the LocalContext and not evaluated at runtime. @@ -89,7 +89,7 @@ class UserFilter(logging.Filter): This allows us to debug-log certain users in production, without having debug logging enabled for everyone. """ - def __init__(self, debug_eppns: Sequence[str]): + def __init__(self, debug_eppns: Sequence[str]) -> None: super().__init__() self.debug_eppns = debug_eppns @@ -117,7 +117,7 @@ def filter(self, record: logging.LogRecord) -> bool: class RequireDebugTrue(logging.Filter): """A filter to discard all debug log records if the Flask app.debug is not True. Generally not used.""" - def __init__(self, app_debug: bool): + def __init__(self, app_debug: bool) -> None: super().__init__() self.app_debug = app_debug @@ -128,7 +128,7 @@ def filter(self, record: logging.LogRecord) -> bool: class RequireDebugFalse(logging.Filter): """A filter to discard all debug log records if the Flask app.debug is not False. Generally not used.""" - def __init__(self, app_debug: bool): + def __init__(self, app_debug: bool) -> None: super().__init__() self.app_debug = app_debug @@ -139,7 +139,7 @@ def filter(self, record: logging.LogRecord) -> bool: def merge_config(base_config: dict[str, Any], new_config: dict[str, Any]) -> dict[str, Any]: """Recursively merge two dictConfig dicts.""" - def merge(node, key, value): + def merge(node: dict[str, Any], key: str, value: object) -> None: if isinstance(value, dict): for item in value: if key in node: diff --git a/src/eduid/common/misc/encoders.py b/src/eduid/common/misc/encoders.py index 3a366eba6..318124f19 100644 --- a/src/eduid/common/misc/encoders.py +++ b/src/eduid/common/misc/encoders.py @@ -11,7 +11,7 @@ class EduidJSONEncoder(json.JSONEncoder): # TODO: This enables us to serialise NameIDs into the stored sessions, # but we don't seem to de-serialise them on load - def default(self, o: Any) -> str | Any: + def default(self, o: Any) -> str | Any: # noqa: ANN401 if isinstance(o, datetime): return o.isoformat() if isinstance(o, timedelta): diff --git a/src/eduid/common/misc/tests/test_timeutil.py b/src/eduid/common/misc/tests/test_timeutil.py index 017c4fa9c..d395394df 100644 --- a/src/eduid/common/misc/tests/test_timeutil.py +++ b/src/eduid/common/misc/tests/test_timeutil.py @@ -4,7 +4,7 @@ class TimeUtilTests(unittest.TestCase): - def test_utc_now(self): + def test_utc_now(self) -> None: t1 = utc_now() t2 = utc_now() assert t2 > t1 diff --git a/src/eduid/common/models/bearer_token.py b/src/eduid/common/models/bearer_token.py index 206186bfa..50e20dcd8 100644 --- a/src/eduid/common/models/bearer_token.py +++ b/src/eduid/common/models/bearer_token.py @@ -53,7 +53,7 @@ class AuthnBearerToken(BaseModel): saml_eppn: str | None = None saml_unique_id: str | None = None - def __str__(self): + def __str__(self) -> str: return f"<{self.__class__.__name__}: scopes={self.scopes}, requested_access={self.requested_access}>" @field_validator("version") @@ -65,7 +65,7 @@ def validate_version(cls, v: int) -> int: @model_validator(mode="before") @classmethod - def set_scopes_from_saml_data(cls, values: dict[str, Any]): + def set_scopes_from_saml_data(cls, values: dict[str, Any]) -> dict[str, Any]: # Get scope from saml identifier if the auth source is interaction and set it as scopes if values.get("auth_source") == AuthSource.INTERACTION.value: values["scopes"] = cls._get_scope_from_saml_data(values=values) diff --git a/src/eduid/common/models/generic.py b/src/eduid/common/models/generic.py index 431e79803..18d8a5d5b 100644 --- a/src/eduid/common/models/generic.py +++ b/src/eduid/common/models/generic.py @@ -16,7 +16,7 @@ # https://docs.pydantic.dev/2.6/concepts/types/#handling-third-party-types class ObjectIdPydanticAnnotation: @classmethod - def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: # noqa: ANN401 """ We return a pydantic_core.CoreSchema that behaves in the following ways: @@ -58,7 +58,7 @@ def __get_pydantic_json_schema__( class JWKPydanticAnnotation: @classmethod - def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: # noqa: ANN401 """ We return a pydantic_core.CoreSchema that behaves in the following ways: diff --git a/src/eduid/common/models/jose_models.py b/src/eduid/common/models/jose_models.py index 7d2ecd67a..9e323e161 100644 --- a/src/eduid/common/models/jose_models.py +++ b/src/eduid/common/models/jose_models.py @@ -1,5 +1,6 @@ import datetime from enum import Enum +from typing import Any from pydantic import AnyUrl, BaseModel, Field @@ -115,7 +116,7 @@ class RegisteredClaims(BaseModel): iat: datetime.datetime = Field(default_factory=utc_now) # Issued At jti: str | None = None # JWT ID - def to_rfc7519(self): + def to_rfc7519(self) -> dict[str, Any]: d = self.dict(exclude_none=True) if self.exp: d["exp"] = int((self.iat + self.exp).timestamp()) diff --git a/src/eduid/common/models/scim_base.py b/src/eduid/common/models/scim_base.py index 243dcc6be..c3951038d 100644 --- a/src/eduid/common/models/scim_base.py +++ b/src/eduid/common/models/scim_base.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Annotated, Any +from typing import Annotated, Any, TypeVar from uuid import UUID from bson import ObjectId @@ -98,21 +98,24 @@ class EduidBaseModel(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True, populate_by_name=True) +TSubResource = TypeVar("TSubResource", bound="SubResource") + + class SubResource(EduidBaseModel): value: UUID ref: str = Field(alias="$ref") display: str @property - def is_user(self): - return self.ref and "/Users/" in self.ref + def is_user(self) -> bool: + return self.ref is not None and "/Users/" in self.ref @property - def is_group(self): - return self.ref and "/Groups/" in self.ref + def is_group(self) -> bool: + return self.ref is not None and "/Groups/" in self.ref @classmethod - def from_mapping(cls, data): + def from_mapping(cls: type[TSubResource], data: object) -> TSubResource: return cls.model_validate(data) diff --git a/src/eduid/common/rpc/am_relay.py b/src/eduid/common/rpc/am_relay.py index 6e7f6da24..927ff9a23 100644 --- a/src/eduid/common/rpc/am_relay.py +++ b/src/eduid/common/rpc/am_relay.py @@ -16,7 +16,7 @@ class AmRelay: This is the interface to the RPC task to save users to the central userdb. """ - def __init__(self, config: AmConfigMixin): + def __init__(self, config: AmConfigMixin) -> None: """ :param config: celery config :param relay_for: Name of application to relay for diff --git a/src/eduid/common/rpc/lookup_mobile_relay.py b/src/eduid/common/rpc/lookup_mobile_relay.py index 302162d36..b9f25eeee 100644 --- a/src/eduid/common/rpc/lookup_mobile_relay.py +++ b/src/eduid/common/rpc/lookup_mobile_relay.py @@ -1,3 +1,5 @@ +from typing import Any + import eduid.workers.lookup_mobile __author__ = "mathiashedstrom" @@ -7,7 +9,7 @@ class LookupMobileRelay: - def __init__(self, config: CeleryConfigMixin): + def __init__(self, config: CeleryConfigMixin) -> None: self.app_name = config.app_name eduid.workers.lookup_mobile.init_app(config.celery) # these have to be imported _after_ eduid.workers.lookup_mobile.init_app() @@ -26,7 +28,7 @@ def find_nin_by_mobile(self, mobile_number: str) -> str | None: raise LookupMobileTaskFailed(f"find_nin_by_mobile task failed: {e}") @deprecated("This task seems unused") - def find_mobiles_by_nin(self, nin: str): + def find_mobiles_by_nin(self, nin: str) -> Any: # noqa: ANN401 try: result = self._find_mobiles_by_NIN.delay(nin) result = result.get(timeout=10) # Lower timeout than standard gunicorn worker timeout (25) diff --git a/src/eduid/common/rpc/mail_relay.py b/src/eduid/common/rpc/mail_relay.py index 0dd84de58..fe34282d2 100644 --- a/src/eduid/common/rpc/mail_relay.py +++ b/src/eduid/common/rpc/mail_relay.py @@ -14,7 +14,7 @@ class MailRelay: This is the interface to the RPC task to send e-mail. """ - def __init__(self, config: MailConfigMixin): + def __init__(self, config: MailConfigMixin) -> None: self.app_name = config.app_name self.mail_from = config.mail_default_from eduid.workers.msg.init_app(config.celery) diff --git a/src/eduid/common/rpc/msg_relay.py b/src/eduid/common/rpc/msg_relay.py index 9a24b21a0..7c2a0ed12 100644 --- a/src/eduid/common/rpc/msg_relay.py +++ b/src/eduid/common/rpc/msg_relay.py @@ -142,7 +142,7 @@ class MsgRelay: This is the interface to the RPC task to fetch data from NAVET, and to send SMSs. """ - def __init__(self, config: MsgConfigMixin): + def __init__(self, config: MsgConfigMixin) -> None: self.app_name = config.app_name self.conf = config eduid.workers.msg.init_app(config.celery) diff --git a/src/eduid/common/rpc/tests/test_msg_relay.py b/src/eduid/common/rpc/tests/test_msg_relay.py index 65868273a..70bbd9ac3 100644 --- a/src/eduid/common/rpc/tests/test_msg_relay.py +++ b/src/eduid/common/rpc/tests/test_msg_relay.py @@ -33,7 +33,7 @@ def _fix_relations_to(relative_nin: str, relations: Mapping[str, Any]) -> list[d return result @patch("eduid.workers.msg.tasks.get_all_navet_data.apply_async") - def test_get_all_navet_data(self, mock_get_all_navet_data: MagicMock): + def test_get_all_navet_data(self, mock_get_all_navet_data: MagicMock) -> None: mock_conf = {"get.return_value": self.message_sender.get_devel_all_navet_data()} ret = Mock(**mock_conf) mock_get_all_navet_data.return_value = ret @@ -41,7 +41,7 @@ def test_get_all_navet_data(self, mock_get_all_navet_data: MagicMock): assert res == NavetData(**self.message_sender.get_devel_all_navet_data()) @patch("eduid.workers.msg.tasks.get_all_navet_data.apply_async") - def test_get_all_navet_data_deceased(self, mock_get_all_navet_data: MagicMock): + def test_get_all_navet_data_deceased(self, mock_get_all_navet_data: MagicMock) -> None: mock_conf = {"get.return_value": self.message_sender.get_devel_all_navet_data(identity_number="189001019802")} ret = Mock(**mock_conf) mock_get_all_navet_data.return_value = ret @@ -50,7 +50,7 @@ def test_get_all_navet_data_deceased(self, mock_get_all_navet_data: MagicMock): assert res == NavetData(**self.message_sender.get_devel_all_navet_data(identity_number="189001019802")) @patch("eduid.workers.msg.tasks.get_all_navet_data.apply_async") - def test_get_all_navet_data_none_response(self, mock_get_all_navet_data: MagicMock): + def test_get_all_navet_data_none_response(self, mock_get_all_navet_data: MagicMock) -> None: mock_conf = {"get.return_value": None} ret = Mock(**mock_conf) mock_get_all_navet_data.return_value = ret @@ -58,7 +58,7 @@ def test_get_all_navet_data_none_response(self, mock_get_all_navet_data: MagicMo self.msg_relay.get_all_navet_data(nin="190102031234") @patch("eduid.workers.msg.tasks.get_postal_address.apply_async") - def test_get_postal_address(self, mock_get_postal_address: MagicMock): + def test_get_postal_address(self, mock_get_postal_address: MagicMock) -> None: mock_conf = {"get.return_value": self.message_sender.get_devel_postal_address()} ret = Mock(**mock_conf) mock_get_postal_address.return_value = ret @@ -66,7 +66,7 @@ def test_get_postal_address(self, mock_get_postal_address: MagicMock): assert res == FullPostalAddress(**self.message_sender.get_devel_postal_address()) @patch("eduid.workers.msg.tasks.get_postal_address.apply_async") - def test_get_postal_address_none_response(self, mock_get_postal_address: MagicMock): + def test_get_postal_address_none_response(self, mock_get_postal_address: MagicMock) -> None: mock_conf = {"get.return_value": None} ret = Mock(**mock_conf) mock_get_postal_address.return_value = ret @@ -74,7 +74,7 @@ def test_get_postal_address_none_response(self, mock_get_postal_address: MagicMo self.msg_relay.get_postal_address(nin="190102031234") @patch("eduid.workers.msg.tasks.get_relations_to.apply_async") - def test_get_relations_to(self, mock_get_relations: MagicMock): + def test_get_relations_to(self, mock_get_relations: MagicMock) -> None: relations_to = self._fix_relations_to( relative_nin="194004048989", relations=self.message_sender.get_devel_relations() ) @@ -86,7 +86,7 @@ def test_get_relations_to(self, mock_get_relations: MagicMock): assert res == [RelationType(item) for item in relations_to] @patch("eduid.workers.msg.tasks.get_relations_to.apply_async") - def test_get_relations_to_empty_response(self, mock_get_relations: MagicMock): + def test_get_relations_to_empty_response(self, mock_get_relations: MagicMock) -> None: mock_conf: dict[str, Any] = {"get.return_value": []} ret = Mock(**mock_conf) mock_get_relations.return_value = ret diff --git a/src/eduid/common/stats/__init__.py b/src/eduid/common/stats/__init__.py index 25e0cfa4e..987581171 100644 --- a/src/eduid/common/stats/__init__.py +++ b/src/eduid/common/stats/__init__.py @@ -14,6 +14,7 @@ __author__ = "ft" from abc import ABC, abstractmethod +from logging import Logger from eduid.common.config.base import StatsConfigMixin @@ -23,7 +24,7 @@ class AppStats(ABC): def count(self, name: str, value: int = 1) -> None: pass - def gauge(self, name: str, value: int, rate=1, delta=False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False) -> None: pass @@ -35,7 +36,7 @@ class NoOpStats(AppStats): configured allows us to not check if current_app.stats is set everywhere. """ - def __init__(self, logger=None, prefix=None): + def __init__(self, logger: Logger | None = None, prefix: str | None = None) -> None: self.logger = logger self.prefix = prefix @@ -45,7 +46,7 @@ def count(self, name: str, value: int = 1) -> None: name = f"{self.prefix!s}.{name!s}" self.logger.info(f"No-op stats count: {name!r} {value!r}") - def gauge(self, name: str, value: int, rate=1, delta=False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False) -> None: if self.logger: if self.prefix: name = f"{self.prefix!s}.{name!s}" @@ -53,7 +54,7 @@ def gauge(self, name: str, value: int, rate=1, delta=False): class Statsd(AppStats): - def __init__(self, host, port, prefix=None): + def __init__(self, host: str, port: int, prefix: str | None = None) -> None: import statsd self.client = statsd.StatsClient(host, port, prefix=prefix) @@ -64,7 +65,7 @@ def count(self, name: str, value: int = 1) -> None: # for .count self.client.incr(f"{name}.count", count=value) - def gauge(self, name: str, value: int, rate=1, delta=False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False) -> None: self.client.gauge(f"{name}.gauge", value=value, rate=rate, delta=delta) diff --git a/src/eduid/common/testing_base.py b/src/eduid/common/testing_base.py index 9566edaed..562a49ec5 100644 --- a/src/eduid/common/testing_base.py +++ b/src/eduid/common/testing_base.py @@ -4,13 +4,14 @@ import logging.config import os import uuid +from collections.abc import Iterable from datetime import datetime, timedelta, timezone from enum import Enum from typing import Any, TypeVar from bson import ObjectId -from eduid.userdb.testing import MongoTestCase +from eduid.userdb.testing import MongoTestCase, SetupConfig logger = logging.getLogger(__name__) @@ -18,14 +19,14 @@ class CommonTestCase(MongoTestCase): """Base Test case for eduID webapps and workers""" - def setUp(self, *args: Any, **kwargs: Any) -> None: + def setUp(self, config: SetupConfig | None = None) -> None: """ set up tests """ if "EDUID_CONFIG_YAML" not in os.environ: os.environ["EDUID_CONFIG_YAML"] = "YAML_CONFIG_NOT_USED" - super().setUp(*args, **kwargs) + super().setUp(config=config) SomeData = TypeVar("SomeData") @@ -37,7 +38,7 @@ def normalised_data( """Utility function for normalising data before comparisons in test cases.""" class NormaliseEncoder(json.JSONEncoder): - def default(self, o: Any) -> str | Any: + def default(self, o: object) -> Iterable: if isinstance(o, datetime): if replace_datetime is not None: return replace_datetime @@ -63,10 +64,10 @@ def default(self, o: Any) -> str | Any: return repr(o) class NormaliseDecoder(json.JSONDecoder): - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(object_hook=self.object_hook, *args, **kwargs) - def object_hook(self, o: Any) -> dict[str, Any]: + def object_hook(self, o: dict) -> dict[str, Any]: """ Decode any keys ending in _ts to datetime objects. diff --git a/src/eduid/graphdb/db.py b/src/eduid/graphdb/db.py index bb5fbf868..277c79ab2 100644 --- a/src/eduid/graphdb/db.py +++ b/src/eduid/graphdb/db.py @@ -13,7 +13,7 @@ class Neo4jDB: """Simple wrapper to allow us to define the api""" - def __init__(self, db_uri: str, config: Mapping[str, Any] | None = None): + def __init__(self, db_uri: str, config: Mapping[str, Any] | None = None) -> None: if not db_uri: raise ValueError("db_uri not supplied") @@ -69,14 +69,14 @@ def sanitized_uri(self) -> str: def driver(self) -> Driver: return self._driver - def close(self): + def close(self) -> None: self.driver.close() class BaseGraphDB(ABC): """Base class for common db operations""" - def __init__(self, db_uri: str, config: dict[str, Any] | None = None): + def __init__(self, db_uri: str, config: dict[str, Any] | None = None) -> None: self._db_uri = db_uri self._db = Neo4jDB(db_uri=self._db_uri, config=config) self.db_setup() @@ -85,7 +85,7 @@ def __repr__(self) -> str: return f"" @property - def db(self): + def db(self) -> Neo4jDB: return self._db def db_setup(self) -> None: diff --git a/src/eduid/graphdb/groupdb/db.py b/src/eduid/graphdb/groupdb/db.py index c38cf69a3..8ffa4a96b 100644 --- a/src/eduid/graphdb/groupdb/db.py +++ b/src/eduid/graphdb/groupdb/db.py @@ -33,11 +33,11 @@ class Role(enum.Enum): class GroupDB(BaseGraphDB): - def __init__(self, db_uri: str, scope: str, config: dict[str, Any] | None = None): + def __init__(self, db_uri: str, scope: str, config: dict[str, Any] | None = None) -> None: super().__init__(db_uri=db_uri, config=config) self._scope = scope - def db_setup(self): + def db_setup(self) -> None: with self.db.driver.session(default_access_mode=WRITE_ACCESS) as session: # new index creation syntax in neo4j >=5.0 statements = [ @@ -73,7 +73,7 @@ def db_setup(self): logger.info(f"{self} setup done.") @property - def scope(self): + def scope(self) -> str: return self._scope def _create_or_update_group(self, tx: Transaction, group: Group) -> Group: @@ -109,16 +109,28 @@ def _add_or_update_users_and_groups( for user_member in group.member_users: res = self._add_user_to_group(tx, group=group, member=user_member, role=Role.MEMBER) - members.add(User.from_mapping(res.data())) + if res: + members.add(User.from_mapping(res.data())) + else: + logger.info(f"User {user_member.identifier} not added to group {group.identifier}.") for group_member in group.member_groups: res = self._add_group_to_group(tx, group=group, member=group_member, role=Role.MEMBER) - members.add(self._load_group(res.data())) + if res: + members.add(self._load_group(res.data())) + else: + logger.info(f"Group {group_member.identifier} not added to group {group.identifier}.") for user_owner in group.owner_users: res = self._add_user_to_group(tx, group=group, member=user_owner, role=Role.OWNER) - owners.add(User.from_mapping(res.data())) + if res: + owners.add(User.from_mapping(res.data())) + else: + logger.info(f"User {user_owner.identifier} not added to group {group.identifier}.") for group_owner in group.owner_groups: res = self._add_group_to_group(tx, group=group, member=group_owner, role=Role.OWNER) - owners.add(self._load_group(res.data())) + if res: + owners.add(self._load_group(res.data())) + else: + logger.info(f"Group {group_owner.identifier} not added to group {group.identifier}.") return members, owners def _remove_missing_users_and_groups(self, tx: Transaction, group: Group, role: Role) -> None: @@ -140,7 +152,7 @@ def _remove_missing_users_and_groups(self, tx: Transaction, group: Group, role: elif Label.USER.value in record["labels"]: self._remove_user_from_group(tx, group=group, user_identifier=record["identifier"], role=role) - def _remove_group_from_group(self, tx: Transaction, group: Group, group_identifier: str, role: Role): + def _remove_group_from_group(self, tx: Transaction, group: Group, group_identifier: str, role: Role) -> None: q = f""" MATCH (:Group {{scope: $scope, identifier: $identifier}})<-[r:{role.value}]-(:Group {{scope: $scope, identifier: $group_identifier}}) @@ -153,7 +165,7 @@ def _remove_group_from_group(self, tx: Transaction, group: Group, group_identifi group_identifier=group_identifier, ) - def _remove_user_from_group(self, tx: Transaction, group: Group, user_identifier: str, role: Role): + def _remove_user_from_group(self, tx: Transaction, group: Group, user_identifier: str, role: Role) -> None: q = f""" MATCH (:Group {{scope: $scope, identifier: $identifier}})<-[r:{role.value}]-(:User {{identifier: $user_identifier}}) @@ -161,7 +173,7 @@ def _remove_user_from_group(self, tx: Transaction, group: Group, user_identifier """ tx.run(q, scope=self.scope, identifier=group.identifier, user_identifier=user_identifier) - def _add_group_to_group(self, tx, group: Group, member: Group, role: Role) -> Record: + def _add_group_to_group(self, tx: Transaction, group: Group, member: Group, role: Role) -> Record | None: q = f""" MATCH (g:Group {{scope: $scope, identifier: $group_identifier}}) MERGE (m:Group {{scope: $scope, identifier: $identifier}}) @@ -187,7 +199,7 @@ def _add_group_to_group(self, tx, group: Group, member: Group, role: Role) -> Re display_name=member.display_name, ).single() - def _add_user_to_group(self, tx, group: Group, member: User, role: Role) -> Record: + def _add_user_to_group(self, tx: Transaction, group: Group, member: User, role: Role) -> Record | None: q = f""" MATCH (g:Group {{scope: $scope, identifier: $group_identifier}}) MERGE (m:User {{identifier: $identifier}}) @@ -276,7 +288,7 @@ def remove_group(self, identifier: str) -> None: with self.db.driver.session(default_access_mode=WRITE_ACCESS) as session: session.run(q, scope=self.scope, identifier=identifier) - def get_groups_by_property(self, key: str, value: str, skip=0, limit=100): + def get_groups_by_property(self, key: str, value: str, skip: int = 0, limit: int = 100) -> list[Group]: res: list[Group] = [] q = f""" MATCH (g: Group {{scope: $scope}}) @@ -303,7 +315,7 @@ def get_groups(self, skip: int = 0, limit: int = 100) -> list[Group]: ] return res - def _get_groups_for_role(self, label: Label, identifier: str, role: Role): + def _get_groups_for_role(self, label: Label, identifier: str, role: Role) -> list[Group]: res: list[Group] = [] if label == Label.GROUP: entity_match = "(e:Group {scope: $scope, identifier: $identifier})" @@ -365,7 +377,9 @@ def group_exists(self, identifier: str) -> bool: RETURN count(*) as exists LIMIT 1 """ with self.db.driver.session(default_access_mode=READ_ACCESS) as session: - ret = session.run(q, scope=self.scope, identifier=identifier).single()["exists"] + single_value = session.run(q, scope=self.scope, identifier=identifier).single() + assert single_value is not None # please mypy + ret = single_value["exists"] return bool(ret) def save(self, group: Group) -> Group: diff --git a/src/eduid/graphdb/testing.py b/src/eduid/graphdb/testing.py index 06586f641..5293af42f 100644 --- a/src/eduid/graphdb/testing.py +++ b/src/eduid/graphdb/testing.py @@ -39,7 +39,7 @@ class Neo4jTemporaryInstance(EduidTemporaryInstance): DEFAULT_USERNAME = "neo4j" DEFAULT_PASSWORD = "testingtesting" - def __init__(self, max_retry_seconds: int = 60, neo4j_version: str = NEO4J_VERSION): + def __init__(self, max_retry_seconds: int = 60, neo4j_version: str = NEO4J_VERSION) -> None: self._http_port = random.randint(40000, 43000) self._https_port = random.randint(44000, 46000) self._bolt_port = random.randint(47000, 50000) @@ -89,22 +89,22 @@ def conn(self) -> Neo4jDB: return self._conn @property - def host(self): + def host(self) -> str: return self._host @property - def http_port(self): + def http_port(self) -> int: return self._http_port @property - def https_port(self): + def https_port(self) -> int: return self._https_port @property - def bolt_port(self): + def bolt_port(self) -> int: return self._bolt_port - def purge_db(self): + def purge_db(self) -> None: q = """ MATCH (n) DETACH DELETE n @@ -135,5 +135,5 @@ def setUpClass(cls) -> None: cls.neo4j_instance = Neo4jTemporaryInstance.get_instance(max_retry_seconds=60) cls.neo4jdb = cls.neo4j_instance.conn - def tearDown(self): + def tearDown(self) -> None: self.neo4j_instance.purge_db() diff --git a/src/eduid/graphdb/tests/test_db.py b/src/eduid/graphdb/tests/test_db.py index 11f4669b0..fc98c4a8d 100644 --- a/src/eduid/graphdb/tests/test_db.py +++ b/src/eduid/graphdb/tests/test_db.py @@ -1,3 +1,5 @@ +from typing import Any + from neo4j import basic_auth from eduid.graphdb.db import BaseGraphDB @@ -7,25 +9,27 @@ class TestNeo4jDB(Neo4jTestCase): - def test_create_db(self): + def test_create_db(self) -> None: with self.neo4jdb.driver.session() as session: session.run("CREATE (n:Test $props)", props={"name": "test node", "testing": True}) with self.neo4jdb.driver.session() as session: result = session.run("MATCH (n {name: $name})RETURN n.testing", name="test node") - self.assertTrue(result.single().value()) + single = result.single() + assert single is not None + self.assertTrue(single.value()) class TestBaseGraphDB(Neo4jTestCase): class TestDB(BaseGraphDB): - def __init__(self, db_uri, config=None): + def __init__(self, db_uri: str, config: dict[str, Any] | None = None) -> None: super().__init__(db_uri, config=config) - def db_setup(self): + def db_setup(self) -> None: with self._db.driver.session() as session: session.run("CREATE CONSTRAINT ON (n:Test) ASSERT n.name IS UNIQUE") session.run("CREATE INDEX FOR (n:Test) ON (n.testing)") - def test_base_db(self): + def test_base_db(self) -> None: db_uri = self.neo4jdb.db_uri config = {"encrypted": False, "auth": basic_auth("neo4j", "testingtesting")} @@ -34,4 +38,6 @@ def test_base_db(self): session.run("CREATE (n:Test $props)", props={"name": "test node", "testing": True}) with test_db._db.driver.session() as session: result = session.run("MATCH (n {name: $name})RETURN n.testing", name="test node") - self.assertTrue(result.single().value()) + single = result.single() + assert single is not None + self.assertTrue(single.value()) diff --git a/src/eduid/graphdb/tests/test_group.py b/src/eduid/graphdb/tests/test_group.py index 2fa36bf17..13c65bb00 100644 --- a/src/eduid/graphdb/tests/test_group.py +++ b/src/eduid/graphdb/tests/test_group.py @@ -1,51 +1,66 @@ +from typing import TypedDict from unittest import TestCase +from typing_extensions import NotRequired + from eduid.graphdb.groupdb import Group, User __author__ = "lundberg" +class GroupData(TypedDict): + identifier: str + display_name: str + members: NotRequired[set[Group | User]] + owners: NotRequired[set[Group | User]] + + +class UserData(TypedDict): + identifier: str + display_name: str + + class TestGroup(TestCase): def setUp(self) -> None: - self.group1: dict[str, str | list] = { + self.group1: GroupData = { "identifier": "test1", "display_name": "Test Group 1", } - self.group2: dict[str, str | list] = { + self.group2: GroupData = { "identifier": "test2", "display_name": "Test Group 2", } - self.user1: dict[str, str] = {"identifier": "user1", "display_name": "Test Testsson"} - self.user2: dict[str, str] = {"identifier": "user2", "display_name": "Namn Namnsson"} + self.user1: UserData = {"identifier": "user1", "display_name": "Test Testsson"} + self.user2: UserData = {"identifier": "user2", "display_name": "Namn Namnsson"} - def test_init_group(self): + def test_init_group(self) -> None: group = Group(**self.group1) assert self.group1["identifier"] == group.identifier assert self.group1["display_name"] == group.display_name - def test_init_group_with_members(self): + def test_init_group_with_members(self) -> None: user = User(**self.user1) - self.group1["members"] = [user] + self.group1["members"] = set([user]) group = Group(**self.group1) assert user in group.members assert user in group.member_users assert 0 == len(group.member_groups) - group.members.append(Group(**self.group2)) + group.members.add(Group(**self.group2)) assert 1 == len(group.member_groups) - group.members.append(User(**self.user2)) + group.members.add(User(**self.user2)) assert 3 == len(group.members) - def test_init_group_with_owner_and_member(self): + def test_init_group_with_owner_and_member(self) -> None: user = User(**self.user1) owner = User(**self.user2) - self.group1["members"] = [user] - self.group1["owners"] = [owner] + self.group1["members"] = set([user]) + self.group1["owners"] = set([owner]) group = Group(**self.group1) assert user in group.members assert user in group.member_users assert owner in group.owners - def test_get_users_and_groups(self): + def test_get_users_and_groups(self) -> None: member1 = User(**self.user1) member2 = User(**self.user2) member3 = Group(**self.group2) @@ -53,8 +68,8 @@ def test_get_users_and_groups(self): owner2 = User(**self.user2) owner3 = Group(**self.group2) - self.group1["members"] = [member1, member2, member3] - self.group1["owners"] = [owner1, owner2, owner3] + self.group1["members"] = set([member1, member2, member3]) + self.group1["owners"] = set([owner1, owner2, owner3]) group = Group(**self.group1) assert owner2 == group.get_owner_user(identifier=owner2.identifier) diff --git a/src/eduid/graphdb/tests/test_groupdb.py b/src/eduid/graphdb/tests/test_groupdb.py index 1fc713d21..83e82c800 100644 --- a/src/eduid/graphdb/tests/test_groupdb.py +++ b/src/eduid/graphdb/tests/test_groupdb.py @@ -29,7 +29,7 @@ def setUp(self) -> None: self.user2: dict[str, str] = {"identifier": "user2", "display_name": "Namn Namnsson"} @staticmethod - def _assert_group(expected: Group, testing: Group, modified=False): + def _assert_group(expected: Group, testing: Group, modified: bool = False) -> None: assert expected.identifier == testing.identifier assert expected.display_name == testing.display_name assert testing.created_ts is not None @@ -40,20 +40,21 @@ def _assert_group(expected: Group, testing: Group, modified=False): assert testing.modified_ts is None @staticmethod - def _assert_user(expected: User, testing: User): + def _assert_user(expected: User, testing: User) -> None: assert expected.identifier == testing.identifier assert expected.display_name == testing.display_name - def test_create_group(self): + def test_create_group(self) -> None: group = Group.from_mapping(self.group1) post_save_group = self.group_db.save(group) assert 1 == self.group_db.db.count_nodes(label="Group") self._assert_group(group, post_save_group) get_group = self.group_db.get_group(identifier="test1") + assert isinstance(get_group, Group) self._assert_group(group, get_group) - def test_update_group(self): + def test_update_group(self) -> None: group = Group.from_mapping(self.group1) post_save_group = self.group_db.save(group) assert 1 == self.group_db.db.count_nodes(label="Group") @@ -65,7 +66,7 @@ def test_update_group(self): assert 1 == self.group_db.db.count_nodes(label="Group") self._assert_group(group, post_save_group2, modified=True) - def test_get_group_by_property(self): + def test_get_group_by_property(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) @@ -73,25 +74,25 @@ def test_get_group_by_property(self): assert 1 == len(post_get_group) self._assert_group(group, post_get_group[0]) - def test_get_non_existing_group(self): + def test_get_non_existing_group(self) -> None: group = self.group_db.get_group(identifier="test1") self.assertIsNone(group) - def test_group_exists(self): + def test_group_exists(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) self.assertTrue(self.group_db.group_exists(identifier=group.identifier)) self.assertFalse(self.group_db.group_exists(identifier="wrong-identifier")) - def test_get_groups(self): + def test_get_groups(self) -> None: self.group_db.save(Group.from_mapping(self.group1)) self.group_db.save(Group.from_mapping(self.group2)) groups = self.group_db.get_groups() assert 2 == len(groups) - def test_save_with_wrong_group_version(self): + def test_save_with_wrong_group_version(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) group = replace(group, display_name="Another display name") @@ -100,7 +101,7 @@ def test_save_with_wrong_group_version(self): self.group_db.save(group) assert 1 == self.group_db.db.count_nodes(label="Group") - def test_create_group_with_user_member(self): + def test_create_group_with_user_member(self) -> None: group = Group.from_mapping(self.group1) user = User.from_mapping(self.user1) group.members.add(user) @@ -111,10 +112,11 @@ def test_create_group_with_user_member(self): assert 1 == self.group_db.db.count_nodes(label="User") post_save_group = self.group_db.get_group(identifier="test1") + assert post_save_group is not None post_save_user = post_save_group.member_users[0] self._assert_user(user, post_save_user) - def test_create_group_with_group_member(self): + def test_create_group_with_group_member(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -124,10 +126,11 @@ def test_create_group_with_group_member(self): assert 2 == self.group_db.db.count_nodes(label="Group") post_save_group = self.group_db.get_group(identifier="test1") + assert post_save_group is not None post_save_member_group = post_save_group.member_groups[0] self._assert_group(member_group, post_save_member_group) - def test_create_group_with_group_member_and_user_owner(self): + def test_create_group_with_group_member_and_user_owner(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -143,6 +146,7 @@ def test_create_group_with_group_member_and_user_owner(self): assert 2 == self.group_db.db.count_nodes(label="Group") post_save_group = self.group_db.get_group(identifier="test1") + assert post_save_group is not None post_save_member_group = post_save_group.member_groups[0] self._assert_group(member_group, post_save_member_group) @@ -150,9 +154,10 @@ def test_create_group_with_group_member_and_user_owner(self): self._assert_user(member_user, post_save_user) post_save_owner = post_save_group.owners.pop() + assert isinstance(post_save_owner, User) self._assert_user(owner, post_save_owner) - def test_remove_group(self): + def test_remove_group(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) assert self.group_db.group_exists(group.identifier) is True @@ -160,7 +165,7 @@ def test_remove_group(self): self.group_db.remove_group(group.identifier) assert self.group_db.group_exists(group.identifier) is False - def test_get_groups_for_user_member(self): + def test_get_groups_for_user_member(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -180,11 +185,17 @@ def test_get_groups_for_user_member(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(group.owners) - self._assert_group(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, User) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, User) + self._assert_user(group_owners, groups_0_owners) assert 1 == len(groups[0].members) - self._assert_user(member_user, groups[0].members.pop()) + groups_0_members = groups[0].members.pop() + assert isinstance(groups_0_members, User) + self._assert_user(member_user, groups_0_members) - def test_get_groups_for_user_member_2(self): + def test_get_groups_for_user_member_2(self) -> None: group1 = Group.from_mapping(self.group1) group2 = Group.from_mapping(self.group2) member_user = User.from_mapping(self.user1) @@ -202,7 +213,7 @@ def test_get_groups_for_user_member_2(self): assert sorted([group1.display_name, group2.display_name]) == sorted([x.display_name for x in groups]) assert groups[0].created_ts is not None - def test_get_groups_for_group_member(self): + def test_get_groups_for_group_member(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -222,11 +233,17 @@ def test_get_groups_for_group_member(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(group.owners) - self._assert_user(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, User) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, User) + self._assert_user(group_owners, groups_0_owners) assert 1 == len(groups[0].members) - self._assert_group(member_group, groups[0].members.pop()) + groups_0_members = groups[0].members.pop() + assert isinstance(groups_0_members, Group) + self._assert_group(member_group, groups_0_members) - def test_get_groups_for_user_owner(self): + def test_get_groups_for_user_owner(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -246,10 +263,14 @@ def test_get_groups_for_user_owner(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(groups[0].owners) - self._assert_group(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, User) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, User) + self._assert_user(group_owners, groups_0_owners) assert 2 == len(groups[0].members) - def test_get_groups_for_user_owner_2(self): + def test_get_groups_for_user_owner_2(self) -> None: group1 = Group.from_mapping(self.group1) group2 = Group.from_mapping(self.group2) owner_user = User.from_mapping(self.user1) @@ -272,7 +293,7 @@ def test_get_groups_for_user_owner_2(self): assert 1 == len(groups[0].members) assert 1 == len(groups[1].members) - def test_get_groups_for_group_owner(self): + def test_get_groups_for_group_owner(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -292,10 +313,14 @@ def test_get_groups_for_group_owner(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(groups[0].owners) - self._assert_group(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, Group) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, Group) + self._assert_group(group_owners, groups_0_owners) assert 2 == len(groups[0].members) - def test_get_groups_and_users_by_role(self): + def test_get_groups_and_users_by_role(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -324,7 +349,7 @@ def test_get_groups_and_users_by_role(self): assert owner_user.identifier in [owner.identifier for owner in owners] assert 2 == len(owners) - def test_remove_user_from_group(self): + def test_remove_user_from_group(self) -> None: group = Group.from_mapping(self.group1) member_user1 = User.from_mapping(self.user1) member_user2 = User.from_mapping(self.user2) @@ -347,10 +372,11 @@ def test_remove_user_from_group(self): assert post_remove_group.has_member(member_user2.identifier) is True get_group = self.group_db.get_group(identifier="test1") + assert get_group is not None assert get_group.has_member(member_user1.identifier) is False assert get_group.has_member(member_user2.identifier) is True - def test_remove_group_from_group(self): + def test_remove_group_from_group(self) -> None: group = Group.from_mapping(self.group1) member_user1 = User.from_mapping(self.user1) member_group1 = Group.from_mapping(self.group2) @@ -373,5 +399,6 @@ def test_remove_group_from_group(self): assert post_remove_group.has_member(member_user1.identifier) is True get_group = self.group_db.get_group(identifier=group.identifier) + assert get_group is not None assert get_group.has_member(member_group1.identifier) is False assert get_group.has_member(member_user1.identifier) is True diff --git a/src/eduid/maccapi/app.py b/src/eduid/maccapi/app.py index d42e52e4c..f6ee78fff 100644 --- a/src/eduid/maccapi/app.py +++ b/src/eduid/maccapi/app.py @@ -12,7 +12,9 @@ class MAccAPI(FastAPI): - def __init__(self, name: str = "maccapi", test_config: dict | None = None, vccs_client: VCCSClient | None = None): + def __init__( + self, name: str = "maccapi", test_config: dict | None = None, vccs_client: VCCSClient | None = None + ) -> None: self.config = load_config(typ=MAccApiConfig, app_name=name, ns="api", test_config=test_config) super().__init__(root_path=self.config.application_root) self.context = Context(config=self.config, vccs_client=vccs_client) diff --git a/src/eduid/maccapi/config.py b/src/eduid/maccapi/config.py index 908fff0c9..e4f612c8c 100644 --- a/src/eduid/maccapi/config.py +++ b/src/eduid/maccapi/config.py @@ -24,7 +24,7 @@ class MAccApiConfig(AuthnBearerTokenConfig, LoggingConfigMixin, StatsConfigMixin @field_validator("application_root") @classmethod - def application_root_must_not_end_with_slash(cls, v: str): + def application_root_must_not_end_with_slash(cls, v: str) -> str: if v.endswith("/"): logger.warning(f"application_root should not end with slash ({v})") v = removesuffix(v, "/") diff --git a/src/eduid/maccapi/context.py b/src/eduid/maccapi/context.py index 82a1fc925..093cf1d72 100644 --- a/src/eduid/maccapi/context.py +++ b/src/eduid/maccapi/context.py @@ -10,7 +10,7 @@ class Context: - def __init__(self, config: MAccApiConfig, vccs_client: VCCSClient | None = None): + def __init__(self, config: MAccApiConfig, vccs_client: VCCSClient | None = None) -> None: self.name = config.app_name self.config = config diff --git a/src/eduid/maccapi/helpers.py b/src/eduid/maccapi/helpers.py index db955ce9c..de95e291b 100644 --- a/src/eduid/maccapi/helpers.py +++ b/src/eduid/maccapi/helpers.py @@ -20,7 +20,7 @@ class UnableToAddPassword(Exception): pass -def list_users(context: Context, data_owner: str): +def list_users(context: Context, data_owner: str) -> list[ManagedAccount]: managed_accounts: list[ManagedAccount] = context.db.get_users(data_owner=data_owner) context.logger.info(f"Listing {managed_accounts.__len__()} users") return managed_accounts @@ -44,7 +44,7 @@ def add_password(context: Context, managed_account: ManagedAccount, password: st return True -def revoke_passwords(context: Context, managed_account: ManagedAccount, reason: str): +def revoke_passwords(context: Context, managed_account: ManagedAccount, reason: str) -> bool: vccs = context.vccs_client revoke_factors = [] @@ -66,7 +66,7 @@ def revoke_passwords(context: Context, managed_account: ManagedAccount, reason: return True -def save_and_sync_user(context: Context, managed_account: ManagedAccount): +def save_and_sync_user(context: Context, managed_account: ManagedAccount) -> None: context.logger.debug(f"Saving and syncing user {managed_account}") result = context.db.save(managed_account) context.logger.debug(f"Saved user {managed_account} with result {result}") @@ -123,7 +123,7 @@ def deactivate_user(context: Context, eppn: str, data_owner: str) -> ManagedAcco return managed_account -def replace_password(context: Context, eppn: str, new_password: str): +def replace_password(context: Context, eppn: str, new_password: str) -> None: managed_account: ManagedAccount = context.db.get_user_by_eppn(eppn) if managed_account is None: raise UserDoesNotExist(f"User {eppn} not found") @@ -144,7 +144,7 @@ def get_user(context: Context, eppn: str, data_owner: str) -> ManagedAccount: return managed_account -def add_api_event(context: Context, eppn: str, action: str, action_by: str, data_owner: str): +def add_api_event(context: Context, eppn: str, action: str, action_by: str, data_owner: str) -> None: expiration: datetime = utc_now() + timedelta(days=context.config.log_retention_days) log_element = ManagedAccountLogElement( eppn=eppn, diff --git a/src/eduid/maccapi/middleware.py b/src/eduid/maccapi/middleware.py index 4595b38c7..e6ac2f38b 100644 --- a/src/eduid/maccapi/middleware.py +++ b/src/eduid/maccapi/middleware.py @@ -7,6 +7,7 @@ from jwcrypto.common import JWException from pydantic import ValidationError from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.types import ASGIApp from eduid.common.fastapi.context_request import ContextRequestMixin from eduid.common.models.bearer_token import AuthnBearerToken, RequestedAccessDenied @@ -22,7 +23,7 @@ def return_error_response(status_code: int, detail: str) -> JSONResponse: class AuthenticationMiddleware(BaseHTTPMiddleware, ContextRequestMixin): - def __init__(self, app, context: Context): + def __init__(self, app: ASGIApp, context: Context) -> None: super().__init__(app) self.context = context @@ -90,6 +91,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - self.context.logger.error(f"Data owner {repr(data_owner)} not configured") return return_error_response(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unknown data_owner") + assert isinstance(request.context, MaccAPIContext) # please mypy request.context.data_owner = data_owner request.context.manager_eppn = token.saml_eppn diff --git a/src/eduid/maccapi/routers/status.py b/src/eduid/maccapi/routers/status.py index 2b49c3580..4d01cfd2c 100644 --- a/src/eduid/maccapi/routers/status.py +++ b/src/eduid/maccapi/routers/status.py @@ -24,7 +24,7 @@ class StatusResponse(BaseModel): reason: str -def check_mongo(request: ContextRequest): +def check_mongo(request: ContextRequest) -> bool | None: db = request.app.context.db try: db.is_healthy() diff --git a/src/eduid/maccapi/routers/users.py b/src/eduid/maccapi/routers/users.py index 5642a9dff..bd72c220f 100644 --- a/src/eduid/maccapi/routers/users.py +++ b/src/eduid/maccapi/routers/users.py @@ -2,7 +2,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.utils import generate_password -from eduid.maccapi.context_request import MaccAPIRoute +from eduid.maccapi.context_request import MaccAPIContext, MaccAPIRoute from eduid.maccapi.helpers import ( UnableToAddPassword, add_api_event, @@ -35,9 +35,11 @@ async def get_users(request: ContextRequest) -> UserListResponse: return all users that the calling user has access to in current context """ - manages_accounts = list_users(context=request.app.context, data_owner=request.context.data_owner) + assert isinstance(request.context, MaccAPIContext) # please mypy + assert request.context.data_owner is not None # please mypy + managed_accounts = list_users(context=request.app.context, data_owner=request.context.data_owner) - users = [ApiUser(eppn=user.eppn, given_name=user.given_name, surname=user.surname) for user in manages_accounts] + users = [ApiUser(eppn=user.eppn, given_name=user.given_name, surname=user.surname) for user in managed_accounts] response = UserListResponse(status="success", scope=request.app.context.config.default_eppn_scope, users=users) @@ -56,6 +58,9 @@ async def add_user( password = generate_password() presentable_password = make_presentable_password(password) + assert isinstance(request.context, MaccAPIContext) # please mypy + assert request.context.data_owner is not None # please mypy + assert request.context.manager_eppn is not None # please mypy managed_account: ManagedAccount = create_and_sync_user( context=request.app.context, data_owner=request.context.data_owner, @@ -101,6 +106,9 @@ async def remove_user( request.app.context.logger.debug(f"remove_user: {remove_request}") try: + assert isinstance(request.context, MaccAPIContext) # please mypy + assert request.context.data_owner is not None # please mypy + assert request.context.manager_eppn is not None # please mypy managed_account: ManagedAccount = deactivate_user( context=request.app.context, eppn=remove_request.eppn, data_owner=request.context.data_owner ) @@ -145,6 +153,9 @@ async def reset_password( new_password = generate_password() presentable_password = make_presentable_password(new_password) try: + assert isinstance(request.context, MaccAPIContext) # please mypy + assert request.context.data_owner is not None # please mypy + assert request.context.manager_eppn is not None # please mypy managed_account = get_user(context=request.app.context, eppn=eppn, data_owner=request.context.data_owner) replace_password(context=request.app.context, eppn=eppn, new_password=new_password) diff --git a/src/eduid/maccapi/testing.py b/src/eduid/maccapi/testing.py index 69471014f..6e9deb5d4 100644 --- a/src/eduid/maccapi/testing.py +++ b/src/eduid/maccapi/testing.py @@ -23,7 +23,7 @@ class BaseDBTestCase(unittest.TestCase): mongo_uri: str @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.mongodb_instance = MongoTemporaryInstance.get_instance() cls.mongo_uri = cls.mongodb_instance.uri diff --git a/src/eduid/maccapi/tests/test_maccapi.py b/src/eduid/maccapi/tests/test_maccapi.py index 79c769938..5f8322bff 100644 --- a/src/eduid/maccapi/tests/test_maccapi.py +++ b/src/eduid/maccapi/tests/test_maccapi.py @@ -32,7 +32,7 @@ def _make_bearer_token(self, claims: Mapping[str, Any]) -> str: def _is_presentable_format(self, password: str) -> bool: return len(password) == 14 and password[4] == " " and password[9] == " " - def test_create_user(self): + def test_create_user(self) -> None: domain = "eduid.se" claims = { "saml_eppn": "test@eduid.se", @@ -60,7 +60,7 @@ def test_create_user(self): assert payload["user"]["eppn"] is not None assert payload["user"]["password"] is not None - def test_create_multiple_users(self): + def test_create_multiple_users(self) -> None: claims = { "saml_eppn": "test@eduid.se", "version": 1, @@ -95,7 +95,7 @@ def test_create_multiple_users(self): assert payload["status"] == "success" assert len(payload["users"]) == 2 - def test_remove_user(self): + def test_remove_user(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers @@ -117,7 +117,7 @@ def test_remove_user(self): assert payload["user"]["given_name"] == self.user1["given_name"] assert payload["user"]["surname"] == self.user1["surname"] - def test_reset_password(self): + def test_reset_password(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers @@ -145,7 +145,7 @@ def test_reset_password(self): new_password = payload["user"]["password"] assert self._is_presentable_format(new_password) - def test_remove_error(self): + def test_remove_error(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers @@ -154,7 +154,7 @@ def test_remove_error(self): response = self.client.post(url="/Users/remove", json={"eppn": "made_up"}, headers=headers) assert response.status_code == 422 - def test_reset_error(self): + def test_reset_error(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers diff --git a/src/eduid/queue/db/client.py b/src/eduid/queue/db/client.py index 82034a7dc..40fd3e64f 100644 --- a/src/eduid/queue/db/client.py +++ b/src/eduid/queue/db/client.py @@ -1,5 +1,6 @@ import logging from dataclasses import replace +from typing import Any from bson import ObjectId @@ -15,7 +16,7 @@ class QueuePayloadMixin: - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.handlers: dict[str, type[Payload]] = dict() def register_handler(self, payload: type[Payload]) -> None: @@ -34,7 +35,7 @@ def _load_payload(self, item: QueueItem) -> Payload: class QueueDB(BaseDB, QueuePayloadMixin): - def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_queue"): + def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_queue") -> None: super().__init__(db_uri=db_uri, db_name=db_name, collection=collection) self.handlers: dict[str, type[Payload]] = dict() diff --git a/src/eduid/queue/db/message/db.py b/src/eduid/queue/db/message/db.py index ec780e2b2..b640df20e 100644 --- a/src/eduid/queue/db/message/db.py +++ b/src/eduid/queue/db/message/db.py @@ -4,10 +4,10 @@ class TestDB(QueueDB): - def __init__(self, db_uri: str, collection: str = "test"): + def __init__(self, db_uri: str, collection: str = "test") -> None: super().__init__(db_uri, collection=collection) class MessageDB(QueueDB): - def __init__(self, db_uri: str, collection: str = "message"): + def __init__(self, db_uri: str, collection: str = "message") -> None: super().__init__(db_uri, collection=collection) diff --git a/src/eduid/queue/db/message/payload.py b/src/eduid/queue/db/message/payload.py index 1cb5a4a74..892c74876 100644 --- a/src/eduid/queue/db/message/payload.py +++ b/src/eduid/queue/db/message/payload.py @@ -12,7 +12,7 @@ class EduidTestPayload(Payload): counter: int @classmethod - def from_dict(cls, data: Mapping): + def from_dict(cls, data: Mapping) -> "EduidTestPayload": return cls(**data) @@ -27,7 +27,7 @@ class EduidTestResultPayload(Payload): per_second: int @classmethod - def from_dict(cls, data: Mapping): + def from_dict(cls, data: Mapping) -> "EduidTestResultPayload": return cls(**data) @@ -38,7 +38,7 @@ class EduidSCIMAPINotification(Payload): message: str @classmethod - def from_dict(cls, data: Mapping): + def from_dict(cls, data: Mapping) -> "EduidSCIMAPINotification": data = dict(data) # Do not change caller data return cls(**data) @@ -50,7 +50,7 @@ class EmailPayload(Payload): language: str @classmethod - def from_dict(cls, data: Mapping): + def from_dict(cls, data: Mapping) -> "EmailPayload": data = dict(data) # Do not change caller data return cls(**data) diff --git a/src/eduid/queue/db/payload.py b/src/eduid/queue/db/payload.py index b26cd67f1..01508ad75 100644 --- a/src/eduid/queue/db/payload.py +++ b/src/eduid/queue/db/payload.py @@ -31,7 +31,7 @@ def to_dict(self) -> dict[str, Any]: return self.data @classmethod - def from_dict(cls, data: Mapping[str, Any]): + def from_dict(cls, data: Mapping[str, Any]) -> "RawPayload": data = dict(data) # Do not change caller data return cls(data=data) @@ -43,6 +43,6 @@ class TestPayload(Payload): version: int = 1 @classmethod - def from_dict(cls, data: Mapping[str, Any]): + def from_dict(cls, data: Mapping[str, Any]) -> "TestPayload": data = dict(data) # Do not change caller data return cls(**data) diff --git a/src/eduid/queue/db/queue_item.py b/src/eduid/queue/db/queue_item.py index 1779ee17d..81f562edf 100644 --- a/src/eduid/queue/db/queue_item.py +++ b/src/eduid/queue/db/queue_item.py @@ -24,7 +24,7 @@ class SenderInfo: node_id: str # Should be something like application@system_hostname ex. scimapi@apps-lla-3 @classmethod - def from_dict(cls, data: Mapping[str, Any]): + def from_dict(cls, data: Mapping[str, Any]) -> "SenderInfo": data = dict(data) return cls(**data) @@ -50,7 +50,7 @@ def to_dict(self) -> TUserDbDocument: return TUserDbDocument(res) @classmethod - def from_dict(cls, data: Mapping[str, Any]): + def from_dict(cls, data: Mapping[str, Any]) -> "QueueItem": data = dict(data) item_id = data.pop("_id") processed_by = data.pop("processed_by", None) diff --git a/src/eduid/queue/db/worker.py b/src/eduid/queue/db/worker.py index a4e7357e0..1ccb0db73 100644 --- a/src/eduid/queue/db/worker.py +++ b/src/eduid/queue/db/worker.py @@ -18,7 +18,7 @@ class AsyncQueueDB(AsyncBaseDB, QueuePayloadMixin): - def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_queue"): + def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_queue") -> None: super().__init__(db_uri, collection=collection, db_name=db_name) self.handlers: dict[str, type[Payload]] = dict() @@ -33,14 +33,14 @@ async def create(cls, db_uri: str, collection: str, db_name: str = "eduid_queue" await instance.setup_indexes(indexes) return instance - def parse_queue_item(self, doc: Mapping, parse_payload: bool = True): + def parse_queue_item(self, doc: Mapping, parse_payload: bool = True) -> QueueItem: item = QueueItem.from_dict(doc) if parse_payload is False: # Return the item with the generic RawPayload return item return replace(item, payload=self._load_payload(item)) - async def grab_item(self, item_id: str | ObjectId, worker_name: str, regrab=False) -> QueueItem | None: + async def grab_item(self, item_id: str | ObjectId, worker_name: str, regrab: bool = False) -> QueueItem | None: """ :param item_id: document id :param worker_name: current workers name diff --git a/src/eduid/queue/decorators.py b/src/eduid/queue/decorators.py index a2a8eafa6..d6acaf9c1 100644 --- a/src/eduid/queue/decorators.py +++ b/src/eduid/queue/decorators.py @@ -1,26 +1,31 @@ +from collections.abc import Callable from inspect import isclass +from typing import Any + +from pymongo.synchronous.collection import Collection from eduid.userdb.db import MongoDB # TODO: Refactor but keep transaction audit document structure +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.util import utc_now class TransactionAudit: enabled = False - def __init__(self, db_uri, db_name="eduid_queue", collection_name="transaction_audit"): - self._conn = None - self.db_uri = db_uri - self.db_name = db_name - self.collection_name = collection_name - self.collection = None + def __init__(self, db_uri: str, db_name: str = "eduid_queue", collection_name: str = "transaction_audit") -> None: + self._conn: MongoDB | None = None + self.db_uri: str = db_uri + self.db_name: str = db_name + self.collection_name: str = collection_name + self.collection: Collection[TUserDbDocument] | None = None - def __call__(self, f): + def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: if not self.enabled: return f - def audit(*args, **kwargs): + def audit(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 ret = f(*args, **kwargs) if not isclass(ret) and self.collection: # we can't save class objects in mongodb date = utc_now() @@ -39,15 +44,15 @@ def audit(*args, **kwargs): return audit @classmethod - def enable(cls): + def enable(cls) -> None: cls.enabled = True @classmethod - def disable(cls): + def disable(cls) -> None: cls.enabled = False @staticmethod - def _filter(func, data, *args, **kwargs): + def _filter(func: str, data: object, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 if data is False: return data if func == "_get_navet_data": diff --git a/src/eduid/queue/helpers.py b/src/eduid/queue/helpers.py index 76ad1c4c6..cfa3c04c3 100644 --- a/src/eduid/queue/helpers.py +++ b/src/eduid/queue/helpers.py @@ -18,7 +18,7 @@ class Jinja2Env: Initiates Jinja2 environment with Babel translations """ - def __init__(self): + def __init__(self) -> None: templates_dir = Path(__file__).with_name("templates") translations_dir = Path(__file__).with_name("translations") # Templates diff --git a/src/eduid/queue/testing.py b/src/eduid/queue/testing.py index 7c68c66eb..3584c46c5 100644 --- a/src/eduid/queue/testing.py +++ b/src/eduid/queue/testing.py @@ -34,7 +34,7 @@ class MongoTemporaryInstanceReplicaSet(MongoTemporaryInstance): rs_initialized = False - def __init__(self, max_retry_seconds: int): + def __init__(self, max_retry_seconds: int) -> None: super().__init__(max_retry_seconds=max_retry_seconds) @property @@ -87,12 +87,12 @@ def get_instance( return cast(MongoTemporaryInstanceReplicaSet, super().get_instance(max_retry_seconds=max_retry_seconds)) @property - def uri(self): + def uri(self) -> str: return f"mongodb://localhost:{self.port}" class SMPTDFixTemporaryInstance(EduidTemporaryInstance): - def __init__(self, max_retry_seconds: int): + def __init__(self, max_retry_seconds: int) -> None: super().__init__(max_retry_seconds=max_retry_seconds) @property @@ -138,7 +138,7 @@ def setUp(self) -> None: def tearDown(self) -> None: self.client_db._drop_whole_collection() - def _init_db(self): + def _init_db(self) -> None: db_init_try = 0 while True: try: @@ -164,7 +164,7 @@ async def asyncTearDown(self) -> None: if not task.done(): task.cancel() - async def _init_async_db(self): + async def _init_async_db(self) -> None: db_init_try = 0 while True: try: @@ -180,7 +180,7 @@ async def _init_async_db(self): continue @staticmethod - def create_queue_item(expires_at: datetime, discard_at: datetime, payload: Payload): + def create_queue_item(expires_at: datetime, discard_at: datetime, payload: Payload) -> QueueItem: sender_info = SenderInfo(hostname="localhost", node_id="test") return QueueItem( version=1, @@ -191,7 +191,7 @@ def create_queue_item(expires_at: datetime, discard_at: datetime, payload: Paylo payload=payload, ) - async def _assert_item_gets_processed(self, queue_item: QueueItem, retry: bool = False): + async def _assert_item_gets_processed(self, queue_item: QueueItem, retry: bool = False) -> None: end_time = utc_now() + timedelta(seconds=10) fetched: QueueItem | None = None while utc_now() < end_time: @@ -209,7 +209,7 @@ async def _assert_item_gets_processed(self, queue_item: QueueItem, retry: bool = class IsolatedWorkerDBMixin(MixinBase): # override run so we can mock cache of database clients - async def run(self): + async def run(self) -> None: # Init db in the correct loop # Make sure the isolated test cases get to create their own mongodb clients with patch("eduid.userdb.db.async_db.AsyncClientCache._clients", {}): diff --git a/src/eduid/queue/tests/test_client.py b/src/eduid/queue/tests/test_client.py index e1ed03d8b..1f9cdce02 100644 --- a/src/eduid/queue/tests/test_client.py +++ b/src/eduid/queue/tests/test_client.py @@ -12,7 +12,7 @@ class TestClient(TestCase): - def test_queue_item(self): + def test_queue_item(self) -> None: expires_at = utc_now() + timedelta(days=180) discard_at = expires_at + timedelta(days=7) sender_info = SenderInfo(hostname="testhost", node_id="userdb@testhost") @@ -36,7 +36,7 @@ def setUp(self) -> None: self.discard_at = self.expires_at + timedelta(days=7) self.sender_info = SenderInfo(hostname="testhost", node_id="userdb@testhost") - def _create_queue_item(self, payload: Payload): + def _create_queue_item(self, payload: Payload) -> QueueItem: return QueueItem( version=1, expires_at=self.expires_at, @@ -46,7 +46,7 @@ def _create_queue_item(self, payload: Payload): payload=payload, ) - def test_eduid_invite_mail(self): + def test_eduid_invite_mail(self) -> None: payload = EduidInviteEmail( email="mail@example.com", reference="ref_id", @@ -60,7 +60,7 @@ def test_eduid_invite_mail(self): assert normalised_data(item.to_dict()) == normalised_data(loaded_message_dict) assert normalised_data(payload.to_dict()) == normalised_data(item.payload.to_dict()) - def test_eduid_signup_mail(self): + def test_eduid_signup_mail(self) -> None: payload = EduidSignupEmail( email="mail@example.com", reference="ref_id", @@ -75,7 +75,7 @@ def test_eduid_signup_mail(self): class TestMessageDB(EduidQueueTestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.messagedb = MessageDB(self.mongo_uri) self.messagedb.register_handler(TestPayload) @@ -86,11 +86,11 @@ def setUp(self): self.discard_at = self.expires_at + timedelta(days=7) self.sender_info = SenderInfo(hostname="testhost", node_id="userdb@testhost") - def tearDown(self): + def tearDown(self) -> None: super().tearDown() self.messagedb._drop_whole_collection() - def _create_queue_item(self, payload: Payload): + def _create_queue_item(self, payload: Payload) -> QueueItem: return QueueItem( version=1, expires_at=self.expires_at, @@ -100,33 +100,36 @@ def _create_queue_item(self, payload: Payload): payload=payload, ) - def test_save_load(self): + def test_save_load(self) -> None: payload = TestPayload(message="this is a test payload") item = self._create_queue_item(payload) self.messagedb.save(item) assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert loaded_item.payload_type == payload.get_type() assert isinstance(loaded_item.payload, TestPayload) is True assert normalised_data(item.to_dict()) == normalised_data(loaded_item.to_dict()) - def test_save_load_raw_payload(self): + def test_save_load_raw_payload(self) -> None: payload = TestPayload(message="this is a test payload") item = self._create_queue_item(payload) self.messagedb.save(item) assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert loaded_item.payload_type == payload.get_type() assert isinstance(loaded_item.payload, TestPayload) is True raw_loaded_item = self.messagedb.get_item_by_id(item.item_id, parse_payload=False) + assert raw_loaded_item assert raw_loaded_item.payload_type == payload.get_type() assert isinstance(raw_loaded_item.payload, RawPayload) is True assert normalised_data(item.payload.to_dict()) == normalised_data(raw_loaded_item.payload.to_dict()) - def test_save_load_eduid_email_invite(self): + def test_save_load_eduid_email_invite(self) -> None: payload = EduidInviteEmail( email="mail@example.com", reference="ref_id", @@ -140,10 +143,11 @@ def test_save_load_eduid_email_invite(self): assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert normalised_data(item.to_dict()) == normalised_data(loaded_item.to_dict()) assert normalised_data(item.payload.to_dict()), normalised_data(loaded_item.payload.to_dict()) - def test_save_load_eduid_email_signup(self): + def test_save_load_eduid_email_signup(self) -> None: payload = EduidSignupEmail( email="mail@example.com", reference="ref_id", @@ -156,12 +160,13 @@ def test_save_load_eduid_email_signup(self): assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert normalised_data(item.to_dict()) == normalised_data(loaded_item.to_dict()) assert normalised_data(item.payload.to_dict()) == normalised_data(loaded_item.payload.to_dict()) @skip("It takes mongo a couple of seconds to actually remove the document, skip for now.") # TODO: Investigate if it is possible to force a expire check in mongodb - def test_auto_discard(self): + def test_auto_discard(self) -> None: self.discard_at = utc_now() - timedelta(seconds=-10) payload = TestPayload(message="this is a test payload") item = self._create_queue_item(payload) diff --git a/src/eduid/queue/tests/test_mail_worker.py b/src/eduid/queue/tests/test_mail_worker.py index 023d82a87..184d972a9 100644 --- a/src/eduid/queue/tests/test_mail_worker.py +++ b/src/eduid/queue/tests/test_mail_worker.py @@ -3,7 +3,7 @@ import os from datetime import timedelta from os import environ -from unittest.mock import patch +from unittest.mock import MagicMock, patch from aiosmtplib import SMTPResponse @@ -66,7 +66,7 @@ async def asyncSetUp(self) -> None: async def asyncTearDown(self) -> None: await super().asyncTearDown() - async def test_eduid_signup_mail_from_stream(self): + async def test_eduid_signup_mail_from_stream(self) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -82,7 +82,7 @@ async def test_eduid_signup_mail_from_stream(self): await self._assert_item_gets_processed(queue_item) @patch("aiosmtplib.SMTP.sendmail") - async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail): + async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail: MagicMock) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -99,7 +99,7 @@ async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_send await self._assert_item_gets_processed(queue_item) @patch("aiosmtplib.SMTP.sendmail") - async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail): + async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail: MagicMock) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -118,7 +118,7 @@ async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail): self.client_db.save(queue_item) await self._assert_item_gets_processed(queue_item, retry=True) - async def test_register_mail_translations(self): + async def test_register_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidSignupEmail( email="noone@example.com", @@ -143,7 +143,7 @@ async def test_register_mail_translations(self): assert "Subject: eduID-registrering" in msg_string assert "Du har registrerat noone@example.com som e-postadress" in msg_string - async def test_reset_password_mail_translations(self): + async def test_reset_password_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidResetPasswordEmail( email="noone@example.com", @@ -171,7 +171,7 @@ async def test_reset_password_mail_translations(self): assert "Du har bett om att byta" in msg_string assert "giltig i 2 timmar." in msg_string - async def test_verification_mail_translations(self): + async def test_verification_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidVerificationEmail( email="noone@example.com", @@ -197,7 +197,7 @@ async def test_verification_mail_translations(self): assert "Du har nyligen lagt till den" in msg_string assert "Skriv in koden nedan" in msg_string - async def test_termination_mail_translations(self): + async def test_termination_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidTerminationEmail( email="noone@example.com", diff --git a/src/eduid/queue/tests/test_worker.py b/src/eduid/queue/tests/test_worker.py index 0ec1aade5..23c2423bc 100644 --- a/src/eduid/queue/tests/test_worker.py +++ b/src/eduid/queue/tests/test_worker.py @@ -62,7 +62,7 @@ async def asyncSetUp(self) -> None: async def asyncTearDown(self) -> None: await super().asyncTearDown() - async def test_worker_item_from_stream(self): + async def test_worker_item_from_stream(self) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -74,7 +74,7 @@ async def test_worker_item_from_stream(self): self.client_db.save(queue_item) await self._assert_item_gets_processed(queue_item) - async def test_worker_expired_item(self): + async def test_worker_expired_item(self) -> None: """ Test that expired queue items are handled by the handle_expired_item method """ diff --git a/src/eduid/queue/workers/base.py b/src/eduid/queue/workers/base.py index 5b50883f7..371fd0321 100644 --- a/src/eduid/queue/workers/base.py +++ b/src/eduid/queue/workers/base.py @@ -21,13 +21,13 @@ logger = logging.getLogger(__name__) -def cancel_task(signame, task): +def cancel_task(signame: str, task: Task) -> None: logger.info(f"got signal {signame}: exit") task.cancel() class QueueWorker(ABC): - def __init__(self, config: QueueWorkerConfig, handle_payloads: Sequence[type[Payload]]): + def __init__(self, config: QueueWorkerConfig, handle_payloads: Sequence[type[Payload]]) -> None: worker_name = environ.get("WORKER_NAME", None) if worker_name is None: raise RuntimeError("Environment variable WORKER_NAME needs to be set") @@ -47,7 +47,7 @@ def add_task(tasks: set[Task], task: Task) -> set[Task]: tasks.add(task) return tasks - async def run(self): + async def run(self) -> None: # Init db in the correct loop self.db = await AsyncQueueDB.create(db_uri=self.config.mongo_uri, collection=self.config.mongo_collection) # Register payloads to handle @@ -64,7 +64,7 @@ async def run(self): logger.info(f"Running: {main_task.get_name()}") await main_task - async def run_subtasks(self): + async def run_subtasks(self) -> None: logger.info(f"Initiating event stream for: {self.db}") watch_collection_task = asyncio.create_task( self.watch_collection(), name=f"Watch collection {self.config.mongo_collection}" diff --git a/src/eduid/queue/workers/mail.py b/src/eduid/queue/workers/mail.py index d92eb9bd2..57a7ee836 100644 --- a/src/eduid/queue/workers/mail.py +++ b/src/eduid/queue/workers/mail.py @@ -31,7 +31,7 @@ class MailQueueWorker(QueueWorker): - def __init__(self, config: QueueWorkerConfig): + def __init__(self, config: QueueWorkerConfig) -> None: # Register which queue items this worker should try to grab payloads: Sequence[type[Payload]] = [ EduidInviteEmail, @@ -46,7 +46,7 @@ def __init__(self, config: QueueWorkerConfig): self._jinja2 = Jinja2Env() @property - async def smtp(self): + async def smtp(self) -> SMTP: if self._smtp is None: logger.debug(f"Creating SMTP client for {self.config.mail_host}:{self.config.mail_port}") validate_certs = self.config.mail_verify_tls @@ -293,7 +293,7 @@ def init_mail_worker(name: str = "mail_worker", test_config: Mapping[str, Any] | return MailQueueWorker(config=config) -def start_worker(): +def start_worker() -> None: worker = init_mail_worker() if worker.smtp is None: # fail fast if we can't connect to the SMTP server diff --git a/src/eduid/queue/workers/scim_event.py b/src/eduid/queue/workers/scim_event.py index 31d7fbf4c..098a060cd 100644 --- a/src/eduid/queue/workers/scim_event.py +++ b/src/eduid/queue/workers/scim_event.py @@ -19,7 +19,7 @@ class ScimEventQueueWorker(QueueWorker): - def __init__(self, config: QueueWorkerConfig): + def __init__(self, config: QueueWorkerConfig) -> None: # Register which queue items this worker should try to grab payloads = [EduidSCIMAPINotification] super().__init__(config=config, handle_payloads=payloads) @@ -62,7 +62,7 @@ def init_scim_event_worker( return ScimEventQueueWorker(config=config) -def start_worker(): +def start_worker() -> None: worker = init_scim_event_worker() exit(asyncio.run(worker.run())) diff --git a/src/eduid/queue/workers/sink.py b/src/eduid/queue/workers/sink.py index 5643b5b5e..78706b331 100644 --- a/src/eduid/queue/workers/sink.py +++ b/src/eduid/queue/workers/sink.py @@ -21,7 +21,7 @@ class SinkQueueWorker(QueueWorker): - def __init__(self, config: QueueWorkerConfig): + def __init__(self, config: QueueWorkerConfig) -> None: # Register which queue items this worker should try to grab payloads = [EduidTestPayload] super().__init__(config=config, handle_payloads=payloads) @@ -91,7 +91,7 @@ def init_sink_worker(name: str = "sink_worker", test_config: Mapping[str, Any] | return SinkQueueWorker(config=config) -def start_worker(): +def start_worker() -> None: worker = init_sink_worker() exit(asyncio.run(worker.run())) diff --git a/src/eduid/satosa/scimapi/accr.py b/src/eduid/satosa/scimapi/accr.py index 5cd93b92a..5762ba0e1 100644 --- a/src/eduid/satosa/scimapi/accr.py +++ b/src/eduid/satosa/scimapi/accr.py @@ -35,7 +35,9 @@ class request(RequestMicroService): http://id.swedenconnect.se/loa/1.0/uncertified-loa2: http://id.elegnamnden.se/loa/1.0/loa2 """ - def __init__(self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any): + def __init__( + self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any + ) -> None: self.lowest_accepted_accr_for_virtual_idp: LowestAcceptedACCRForVirtualIdpConfig | None = config.get( "lowest_accepted_accr_for_virtual_idp" ) @@ -111,7 +113,9 @@ class response(ResponseMicroService): name: accrResponse """ - def __init__(self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any): + def __init__( + self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) def process(self, context: satosa.context.Context, data: satosa.internal.InternalData) -> ProcessReturnType: diff --git a/src/eduid/satosa/scimapi/pairwiseid.py b/src/eduid/satosa/scimapi/pairwiseid.py index c0a9e4545..e9df64d85 100644 --- a/src/eduid/satosa/scimapi/pairwiseid.py +++ b/src/eduid/satosa/scimapi/pairwiseid.py @@ -36,7 +36,7 @@ def __init__( config: Mapping[str, Any], *args: Any, **kwargs: Any, - ): + ) -> None: super().__init__(*args, **kwargs) self.config = Config(**config) logger.info("Loaded pairwise-id generator") diff --git a/src/eduid/satosa/scimapi/scim_attributes.py b/src/eduid/satosa/scimapi/scim_attributes.py index 488f8319d..f766173c5 100644 --- a/src/eduid/satosa/scimapi/scim_attributes.py +++ b/src/eduid/satosa/scimapi/scim_attributes.py @@ -43,7 +43,9 @@ class ScimAttributes(ResponseMicroService): Add attributes from the scim db to the responses. """ - def __init__(self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any): + def __init__( + self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) self.config = Config(**config) diff --git a/src/eduid/satosa/scimapi/serve_static.py b/src/eduid/satosa/scimapi/serve_static.py index 6a5e2bbec..c1a3091bd 100644 --- a/src/eduid/satosa/scimapi/serve_static.py +++ b/src/eduid/satosa/scimapi/serve_static.py @@ -4,9 +4,12 @@ import logging import mimetypes +from typing import Any +from satosa.context import Context from satosa.micro_services.base import RequestMicroService from satosa.response import Response +from satosa.satosa_config import SATOSAConfig logger = logging.getLogger("satosa") @@ -27,7 +30,7 @@ class ServeStatic(RequestMicroService): logprefix = "SERVE_STATIC_SERVICE:" - def __init__(self, config, *args, **kwargs): + def __init__(self, config: SATOSAConfig, *args: Any, **kwargs: Any) -> None: """ :type config: satosa.satosa_config.SATOSAConfig :param config: The SATOSA proxy config @@ -35,7 +38,7 @@ def __init__(self, config, *args, **kwargs): super().__init__(*args, **kwargs) self.locations = config.get("locations", {}) - def register_endpoints(self): + def register_endpoints(self) -> list: url_map = [] for endpoint, path in self.locations.items(): endpoint = endpoint.strip("/") @@ -43,7 +46,7 @@ def register_endpoints(self): url_map.append([f"^{endpoint}/", self._handle]) return url_map - def _handle(self, context): + def _handle(self, context: Context) -> Response: path = context._path endpoint = path.split("/")[0] target = path[len(endpoint) + 1 :] diff --git a/src/eduid/satosa/scimapi/static_attributes.py b/src/eduid/satosa/scimapi/static_attributes.py index 19493e842..742c4e243 100644 --- a/src/eduid/satosa/scimapi/static_attributes.py +++ b/src/eduid/satosa/scimapi/static_attributes.py @@ -49,14 +49,14 @@ class AddStaticAttributesForVirtualIdp(ResponseMicroService): override existing attributes if present. """ - def __init__(self, config: Mapping[str, Any], *args: Any, **kwargs: Any): + def __init__(self, config: Mapping[str, Any], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.static_attributes: StaticAttributesConfig | None = config.get("static_attributes_for_virtual_idp") self.static_appended_attributes: StaticAppendedAttributesConfig | None = config.get( "static_appended_attributes_for_virtual_idp" ) - def _build_static(self, requester: str, vidp: str, existing_attributes: dict): + def _build_static(self, requester: str, vidp: str, existing_attributes: dict) -> dict[str, list[str]]: static_attributes: dict[str, list[str]] = dict() if self.static_attributes: diff --git a/src/eduid/satosa/scimapi/statsd.py b/src/eduid/satosa/scimapi/statsd.py index 3528acdfd..554b38bf3 100644 --- a/src/eduid/satosa/scimapi/statsd.py +++ b/src/eduid/satosa/scimapi/statsd.py @@ -29,7 +29,7 @@ class RequesterCounter(ResponseMicroService): ``` """ - def __init__(self, config: Mapping[str, Any], *args: Any, **kwargs: Any): + def __init__(self, config: Mapping[str, Any], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) statsd_config = StatsConfigMixin(**config) diff --git a/src/eduid/satosa/scimapi/stepup.py b/src/eduid/satosa/scimapi/stepup.py index 4b14fffc4..b41238e90 100644 --- a/src/eduid/satosa/scimapi/stepup.py +++ b/src/eduid/satosa/scimapi/stepup.py @@ -36,7 +36,7 @@ RequestMicroService, ResponseMicroService, ) -from satosa.response import Response +from satosa.response import Response, SeeOther from satosa.saml_util import make_saml_response try: @@ -191,7 +191,9 @@ class StepUp(ResponseMicroService): - [//acs/redirect, 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect'] """ - def __init__(self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any): + def __init__( + self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) try: @@ -467,7 +469,7 @@ def _handle_authn_response(self, context: satosa.context.Context, binding: SAMLB raise RuntimeError("Unexpected response type") return res - def _metadata_endpoint(self, context: satosa.context.Context, extra: Any) -> CallbackReturnType: + def _metadata_endpoint(self, context: satosa.context.Context, extra: object) -> CallbackReturnType: metadata_string = create_metadata_string(None, self.sp.config, 4, None, None, None, None, None).decode("utf-8") return Response(metadata_string, content="text/xml") @@ -507,7 +509,9 @@ class AuthnContext(RequestMicroService): It saves the original requested authn context class reference (accr) in the state. """ - def __init__(self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any): + def __init__( + self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) try: @@ -615,7 +619,7 @@ class StepupSAMLBackend(SAMLBackend): A SAML backend to request custom authn context class references from IdP:s with certain entity attributes. """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.mfa: MfaConfig | None = None @@ -625,7 +629,7 @@ def __init__(self, *args: Any, **kwargs: Any): raise StepUpError(f"The configuration for this plugin is not valid: {e}") self.mfa = parsed_config.mfa - def authn_request(self, context: satosa.context.Context, entity_id: str): + def authn_request(self, context: satosa.context.Context, entity_id: str) -> SeeOther | Response: logger.debug(f"Processing AuthnRequest with entity id {repr(entity_id)}") if self.mfa and AuthnContext.sp_wants_mfa(context=context): @@ -647,7 +651,9 @@ class RewriteAuthnContextClass(ResponseMicroService): 'normalisation' of the authn context class reference in our MFA configuration. """ - def __init__(self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any): + def __init__( + self, config: Mapping[str, Any], internal_attributes: dict[str, Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) self.mfa: MfaConfig | None = None diff --git a/src/eduid/scimapi/app.py b/src/eduid/scimapi/app.py index 5809e7eff..02d61cace 100644 --- a/src/eduid/scimapi/app.py +++ b/src/eduid/scimapi/app.py @@ -22,7 +22,7 @@ class ScimAPI(FastAPI): - def __init__(self, name: str = "scimapi", test_config: dict | None = None): + def __init__(self, name: str = "scimapi", test_config: dict | None = None) -> None: self.config = load_config(typ=ScimApiConfig, app_name=name, ns="api", test_config=test_config) super().__init__(root_path=self.config.application_root) self.context = Context(config=self.config) diff --git a/src/eduid/scimapi/config.py b/src/eduid/scimapi/config.py index 784af4116..64b8867d0 100644 --- a/src/eduid/scimapi/config.py +++ b/src/eduid/scimapi/config.py @@ -33,7 +33,7 @@ class ScimApiConfig(AuthnBearerTokenConfig, LoggingConfigMixin, AWSMixin): @field_validator("application_root") @classmethod - def application_root_must_not_end_with_slash(cls, v: str): + def application_root_must_not_end_with_slash(cls, v: str) -> str: if v.endswith("/"): logger.warning(f"application_root should not end with slash ({v})") v = removesuffix(v, "/") diff --git a/src/eduid/scimapi/context.py b/src/eduid/scimapi/context.py index a16b07945..0b93b91b0 100644 --- a/src/eduid/scimapi/context.py +++ b/src/eduid/scimapi/context.py @@ -2,6 +2,7 @@ import logging.config from dataclasses import dataclass, field from datetime import datetime +from typing import Any from uuid import UUID from eduid.common.config.base import DataOwnerConfig, DataOwnerName @@ -31,7 +32,7 @@ class DataOwnerDatabases: class Context: - def __init__(self, config: ScimApiConfig): + def __init__(self, config: ScimApiConfig) -> None: self.name = config.app_name self.config = config @@ -106,7 +107,7 @@ def get_invitedb(self, data_owner: DataOwnerName) -> ScimApiInviteDB | None: def get_eventdb(self, data_owner: DataOwnerName) -> ScimApiEventDB | None: return self._get_data_owner_dbs(data_owner=data_owner).eventdb - def url_for(self, *args) -> str: + def url_for(self, *args: Any) -> str: url = self.base_url for arg in args: url = urlappend(url, f"{arg}") diff --git a/src/eduid/scimapi/exceptions.py b/src/eduid/scimapi/exceptions.py index 562378f2e..f5f049151 100644 --- a/src/eduid/scimapi/exceptions.py +++ b/src/eduid/scimapi/exceptions.py @@ -2,6 +2,7 @@ import logging import uuid +from typing import Any from fastapi import Request, status from fastapi.encoders import jsonable_encoder @@ -69,7 +70,7 @@ def __init__( detail: str | None = None, schemas: list[str] | None = None, scim_type: str | None = None, - ): + ) -> None: if schemas is None: schemas = [SCIMSchema.ERROR.value] @@ -85,33 +86,33 @@ def extra_headers(self) -> dict | None: return self._extra_headers @extra_headers.setter - def extra_headers(self, headers: dict): + def extra_headers(self, headers: dict) -> None: self._extra_headers = headers class BadRequest(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_400_BAD_REQUEST, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Bad Request" class Unauthorized(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Unauthorized request" class NotFound(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_404_NOT_FOUND, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Resource not found" class MethodNotAllowedMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, **kwargs) if not self.error_detail.detail: allowed_methods = kwargs.get("allowed_methods") @@ -119,14 +120,14 @@ def __init__(self, **kwargs): class UnsupportedMediaTypeMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Request was made with an unsupported media type" class Conflict(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(status_code=status.HTTP_409_CONFLICT, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Request conflicts with the current state" diff --git a/src/eduid/scimapi/middleware.py b/src/eduid/scimapi/middleware.py index b87277abf..55a84dcb8 100644 --- a/src/eduid/scimapi/middleware.py +++ b/src/eduid/scimapi/middleware.py @@ -7,8 +7,8 @@ from jwcrypto.common import JWException from pydantic import ValidationError from starlette.datastructures import URL -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.types import Message +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.types import ASGIApp, Message from eduid.common.config.base import DataOwnerName from eduid.common.fastapi.context_request import ContextRequestMixin @@ -29,7 +29,7 @@ # Hack to be able to get request body both now and later # https://github.com/encode/starlette/issues/495#issuecomment-513138055 -async def set_body(request: Request, body: bytes): +async def set_body(request: Request, body: bytes) -> None: async def receive() -> Message: return {"type": "http.request", "body": body} @@ -43,16 +43,16 @@ async def get_body(request: Request) -> bytes: class BaseMiddleware(BaseHTTPMiddleware, ContextRequestMixin): - def __init__(self, app, context: Context): + def __init__(self, app: ASGIApp, context: Context) -> None: super().__init__(app) self.context = context - async def dispatch(self, req: Request, call_next) -> Response: + async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(req) class ScimMiddleware(BaseMiddleware): - async def dispatch(self, req: Request, call_next) -> Response: + async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: req = self.make_context_request(request=req, context_class=ScimApiContext) self.context.logger.debug(f"process_request: {req.method} {req.url.path}") resp = await call_next(req) @@ -62,7 +62,7 @@ async def dispatch(self, req: Request, call_next) -> Response: class AuthenticationMiddleware(BaseMiddleware): - def __init__(self, app, context: Context): + def __init__(self, app: ASGIApp, context: Context) -> None: super().__init__(app, context) self.no_authn_urls = self.context.config.no_authn_urls self.context.logger.debug(f"No auth allow urls: {self.no_authn_urls}") @@ -78,9 +78,11 @@ def _is_no_auth_path(self, url: URL) -> bool: return True return False - async def dispatch(self, req: Request, call_next) -> Response: + async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: req = self.make_context_request(request=req, context_class=ScimApiContext) + assert isinstance(req.context, ScimApiContext) # please mypy + if self._is_no_auth_path(req.url): return await call_next(req) diff --git a/src/eduid/scimapi/notifications.py b/src/eduid/scimapi/notifications.py index 39e77b2b7..2014b02b1 100644 --- a/src/eduid/scimapi/notifications.py +++ b/src/eduid/scimapi/notifications.py @@ -20,7 +20,7 @@ class NotificationRelay: - def __init__(self, config: ScimApiConfig): + def __init__(self, config: ScimApiConfig) -> None: self.config = config app_name = config.app_name system_hostname = environ.get("SYSTEM_HOSTNAME", "") # Underlying hosts name for containers diff --git a/src/eduid/scimapi/routers/events.py b/src/eduid/scimapi/routers/events.py index 62948a6fd..93d1a6d9b 100644 --- a/src/eduid/scimapi/routers/events.py +++ b/src/eduid/scimapi/routers/events.py @@ -5,7 +5,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import SCIMResourceType from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, ErrorDetail, NotFound from eduid.scimapi.models.event import EventCreateRequest, EventResponse from eduid.scimapi.routers.utils.events import db_event_to_response, get_scim_referenced @@ -31,6 +31,8 @@ async def on_get(req: ContextRequest, resp: Response, scim_id: str | None = None if scim_id is None: raise BadRequest(detail="Not implemented") req.app.context.logger.info(f"Fetching event {scim_id}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.eventdb is not None # please mypy db_event = req.context.eventdb.get_event_by_scim_id(scim_id) if not db_event: raise NotFound(detail="Event not found") @@ -91,6 +93,10 @@ async def on_post(req: ContextRequest, resp: Response, create_request: EventCrea _timestamp = create_request.nutid_event_v1.timestamp _expires_at = utc_now() + timedelta(days=1) + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.data_owner is not None # please mypy + assert req.context.eventdb is not None # please mypy + event = ScimApiEvent( resource=ScimApiEventResource( resource_type=create_request.nutid_event_v1.resource.resource_type, diff --git a/src/eduid/scimapi/routers/groups.py b/src/eduid/scimapi/routers/groups.py index 3bd96a365..ae928f77b 100644 --- a/src/eduid/scimapi/routers/groups.py +++ b/src/eduid/scimapi/routers/groups.py @@ -3,7 +3,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import ListResponse, SCIMResourceType, SearchRequest from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, ErrorDetail, NotFound from eduid.scimapi.models.group import GroupCreateRequest, GroupResponse, GroupUpdateRequest from eduid.scimapi.routers.utils.events import add_api_event @@ -29,6 +29,8 @@ @groups_router.get("/", response_model=ListResponse) async def on_get_all(req: ContextRequest) -> ListResponse: + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy db_groups = req.context.groupdb.get_groups() resources = [{"id": str(db_group.scim_id), "displayName": db_group.graph.display_name} for db_group in db_groups] return ListResponse(total_results=len(db_groups), resources=resources) @@ -67,6 +69,8 @@ async def on_get_one(req: ContextRequest, resp: Response, scim_id: str) -> Group """ req.app.context.logger.info(f"Fetching group {scim_id}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy db_group = req.context.groupdb.get_group_by_scim_id(scim_id) req.app.context.logger.debug(f"Found group: {db_group}") if not db_group: @@ -135,6 +139,8 @@ async def on_put( raise BadRequest(detail="Id mismatch") req.app.context.logger.info(f"Fetching group {scim_id}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy db_group = req.context.groupdb.get_group_by_scim_id(str(update_request.id)) req.app.context.logger.debug(f"Found group: {db_group}") if not db_group: @@ -145,6 +151,7 @@ async def on_put( raise BadRequest(detail="Version mismatch") # Check that members exists in their respective db + assert req.context.userdb is not None # please mypy req.app.context.logger.info("Checking if group and user members exists") for member in update_request.members: if member.is_group: @@ -162,6 +169,7 @@ async def on_put( db_group = req.context.groupdb.get_group_by_scim_id(str(updated_group.scim_id)) assert db_group # please mypy + assert req.context.data_owner is not None # please mypy if changed: add_api_event( context=req.app.context, @@ -209,12 +217,15 @@ async def on_post(req: ContextRequest, resp: Response, create_request: GroupCrea """ req.app.context.logger.info("Creating group") req.app.context.logger.debug(create_request) + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy created_group = req.context.groupdb.create_group(create_request=create_request) # Load the group from the database to ensure results are consistent with subsequent GETs. # For example, timestamps have higher resolution in created_group than after a load. db_group = req.context.groupdb.get_group_by_scim_id(str(created_group.scim_id)) assert db_group # please mypy + assert req.context.data_owner is not None # please mypy add_api_event( context=req.app.context, data_owner=req.context.data_owner, @@ -237,6 +248,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: GroupCrea ) async def on_delete(req: ContextRequest, scim_id: str) -> None: req.app.context.logger.info(f"Deleting group {scim_id}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy db_group = req.context.groupdb.get_group_by_scim_id(scim_id=scim_id) req.app.context.logger.debug(f"Found group: {db_group}") if not db_group: @@ -248,6 +261,7 @@ async def on_delete(req: ContextRequest, scim_id: str) -> None: res = req.context.groupdb.remove_group(db_group) + assert req.context.data_owner is not None # please mypy add_api_event( context=req.app.context, data_owner=req.context.data_owner, diff --git a/src/eduid/scimapi/routers/invites.py b/src/eduid/scimapi/routers/invites.py index 9f375a8cb..cc1518352 100644 --- a/src/eduid/scimapi/routers/invites.py +++ b/src/eduid/scimapi/routers/invites.py @@ -6,7 +6,7 @@ from eduid.common.models.scim_base import ListResponse, SCIMResourceType, SearchRequest from eduid.common.models.scim_invite import InviteCreateRequest, InviteResponse, InviteUpdateRequest from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, ErrorDetail, NotFound from eduid.scimapi.routers.utils.events import add_api_event from eduid.scimapi.routers.utils.invites import ( @@ -40,6 +40,8 @@ async def on_get(req: ContextRequest, resp: Response, scim_id: str | None = None if scim_id is None: raise BadRequest(detail="Not implemented") req.app.context.logger.info(f"Fetching invite {scim_id}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.invitedb is not None # please mypy db_invite = req.context.invitedb.get_invite_by_scim_id(scim_id) if not db_invite: raise NotFound(detail="Invite not found") @@ -60,6 +62,8 @@ async def on_put( raise BadRequest(detail="Id mismatch") req.app.context.logger.info(f"Updating invite {scim_id}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.invitedb is not None # please mypy db_invite = req.context.invitedb.get_invite_by_scim_id(scim_id) if not db_invite: raise NotFound(detail="Invite not found") @@ -78,6 +82,7 @@ async def on_put( db_invite = replace(db_invite, completed=update_request.nutid_invite_v1.completed) invite_changed = True + assert req.context.data_owner is not None # please mypy if invite_changed: save_invite( req=req, @@ -173,6 +178,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: InviteCre if signup_invite.send_email: send_invite_mail(req, signup_invite) + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.data_owner is not None # please mypy add_api_event( context=req.app.context, data_owner=req.context.data_owner, @@ -190,6 +197,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: InviteCre @invites_router.delete("/{scim_id}", status_code=204, responses={204: {"description": "No Content"}}) async def on_delete(req: ContextRequest, scim_id: str) -> None: req.app.context.logger.info(f"Deleting invite {scim_id}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.invitedb is not None # please mypy db_invite = req.context.invitedb.get_invite_by_scim_id(scim_id=scim_id) req.app.context.logger.debug(f"Found invite: {db_invite}") @@ -209,6 +218,7 @@ async def on_delete(req: ContextRequest, scim_id: str) -> None: # Remove scim invite res = req.context.invitedb.remove(db_invite) + assert req.context.data_owner is not None # please mypy add_api_event( context=req.app.context, data_owner=req.context.data_owner, diff --git a/src/eduid/scimapi/routers/users.py b/src/eduid/scimapi/routers/users.py index 7aac70c58..54950cd14 100644 --- a/src/eduid/scimapi/routers/users.py +++ b/src/eduid/scimapi/routers/users.py @@ -8,7 +8,7 @@ from eduid.common.models.scim_base import ListResponse, SCIMResourceType, SCIMSchema, SearchRequest from eduid.common.models.scim_user import UserCreateRequest, UserResponse, UserUpdateRequest from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, Conflict, ErrorDetail, MaxRetriesReached, NotFound from eduid.scimapi.routers.utils.events import add_api_event from eduid.scimapi.routers.utils.users import ( @@ -48,6 +48,8 @@ async def on_get(req: ContextRequest, resp: Response, scim_id: str | None = None if scim_id is None: raise BadRequest(detail="Not implemented") req.app.context.logger.info(f"Fetching user {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None db_user = req.context.userdb.get_user_by_scim_id(scim_id) if not db_user: raise NotFound(detail="User not found") @@ -64,6 +66,8 @@ async def on_put(req: ContextRequest, resp: Response, update_request: UserUpdate req.app.context.logger.debug(f"{scim_id} != {update_request.id}") raise BadRequest(detail="Id mismatch") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.userdb is not None # please mypy db_user = req.context.userdb.get_user_by_scim_id(scim_id) if not db_user: raise NotFound(detail="User not found") @@ -143,6 +147,7 @@ async def on_put(req: ContextRequest, resp: Response, update_request: UserUpdate req.app.context.logger.debug(f"Core changed: {core_changed}, nutid_changed: {nutid_changed}") if core_changed or nutid_changed: + assert req.context.data_owner is not None # please mypy save_user(req, db_user) add_api_event( context=req.app.context, @@ -238,6 +243,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: UserCreat ) save_user(req, db_user) + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.data_owner is not None # please mypy add_api_event( context=req.app.context, data_owner=req.context.data_owner, @@ -259,7 +266,9 @@ async def on_post(req: ContextRequest, resp: Response, create_request: UserCreat responses={204: {"description": "No Content"}}, ) async def on_delete(req: ContextRequest, scim_id: str) -> None: + assert isinstance(req.context, ScimApiContext) # please mypy req.app.context.logger.info(f"Deleting user {scim_id}") + assert req.context.userdb is not None # please mypy db_user = req.context.userdb.get_user_by_scim_id(scim_id=scim_id) req.app.context.logger.debug(f"Found user: {db_user}") if not db_user: @@ -279,6 +288,7 @@ async def on_delete(req: ContextRequest, scim_id: str) -> None: res = req.context.userdb.remove(db_user) + assert req.context.data_owner is not None # please mypy add_api_event( context=req.app.context, data_owner=req.context.data_owner, diff --git a/src/eduid/scimapi/routers/utils/events.py b/src/eduid/scimapi/routers/utils/events.py index 1a7ecdbc6..2cae6be22 100644 --- a/src/eduid/scimapi/routers/utils/events.py +++ b/src/eduid/scimapi/routers/utils/events.py @@ -8,6 +8,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import Meta, SCIMResourceType, SCIMSchema, WeakVersion from eduid.common.utils import make_etag, urlappend +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.models.event import EventResponse, NutidEventExtensionV1, NutidEventResource from eduid.userdb.scimapi import EventLevel, EventStatus, ScimApiEvent, ScimApiEventResource, ScimApiResourceBase @@ -20,7 +21,7 @@ __author__ = "lundberg" -def db_event_to_response(req: ContextRequest, resp: Response, db_event: ScimApiEvent): +def db_event_to_response(req: ContextRequest, resp: Response, db_event: ScimApiEvent) -> EventResponse: location = req.app.context.resource_url(SCIMResourceType.EVENT, db_event.scim_id) meta = Meta( location=location, @@ -61,6 +62,10 @@ def db_event_to_response(req: ContextRequest, resp: Response, db_event: ScimApiE def get_scim_referenced(req: ContextRequest, resource: NutidEventResource) -> ScimApiResourceBase | None: + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.userdb is not None # please mypy + assert req.context.groupdb is not None # please mypy + assert req.context.invitedb is not None # please mypy if resource.resource_type == SCIMResourceType.USER: return req.context.userdb.get_user_by_scim_id(str(resource.scim_id)) elif resource.resource_type == SCIMResourceType.GROUP: diff --git a/src/eduid/scimapi/routers/utils/groups.py b/src/eduid/scimapi/routers/utils/groups.py index 31a428ae4..480548be0 100644 --- a/src/eduid/scimapi/routers/utils/groups.py +++ b/src/eduid/scimapi/routers/utils/groups.py @@ -7,6 +7,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import Meta, SCIMResourceType, SCIMSchema from eduid.common.utils import make_etag +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.models.group import GroupMember, GroupResponse, NutidGroupExtensionV1 from eduid.scimapi.search import SearchFilter @@ -69,6 +70,10 @@ def filter_display_name( raise BadRequest(scim_type="invalidFilter", detail="Invalid displayName") req.app.context.logger.debug(f"Searching for group with display name {repr(filter.val)}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy + assert skip is not None # please mypy + assert limit is not None # please mypy groups, count = req.context.groupdb.get_groups_by_property( key="display_name", value=filter.val, skip=skip, limit=limit ) @@ -90,6 +95,8 @@ def filter_lastmodified( _parsed = datetime.fromisoformat(filter.val) except Exception: raise BadRequest(scim_type="invalidFilter", detail="Invalid datetime") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy return req.context.groupdb.get_groups_by_last_modified(operator=filter.op, value=_parsed, skip=skip, limit=limit) @@ -107,6 +114,10 @@ def filter_extensions_data( raise BadRequest(scim_type="invalidFilter", detail="Unsupported extension search key") req.app.context.logger.debug(f"Searching for groups with {filter.attr} {filter.op} {repr(filter.val)}") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy + assert skip is not None # please mypy + assert limit is not None # please mypy groups, count = req.context.groupdb.get_groups_by_property( key=filter.attr, value=filter.val, skip=skip, limit=limit ) diff --git a/src/eduid/scimapi/routers/utils/invites.py b/src/eduid/scimapi/routers/utils/invites.py index 7fc405f23..5cf39bc6f 100644 --- a/src/eduid/scimapi/routers/utils/invites.py +++ b/src/eduid/scimapi/routers/utils/invites.py @@ -14,6 +14,7 @@ from eduid.common.utils import get_short_hash, make_etag from eduid.queue.db import QueueItem, SenderInfo from eduid.queue.db.message import EduidInviteEmail +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.search import SearchFilter from eduid.scimapi.utils import get_unique_hash @@ -27,6 +28,8 @@ def create_signup_invite( req: ContextRequest, create_request: InviteCreateRequest, db_invite: ScimApiInvite ) -> SignupInvite: + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.data_owner is not None # please mypy invite_reference = SCIMReference(data_owner=req.context.data_owner, scim_id=db_invite.scim_id) if create_request.nutid_invite_v1.send_email is False: @@ -64,7 +67,9 @@ def create_signup_invite( return signup_invite -def db_invite_to_response(req: Request, resp: Response, db_invite: ScimApiInvite, signup_invite: SignupInvite): +def db_invite_to_response( + req: Request, resp: Response, db_invite: ScimApiInvite, signup_invite: SignupInvite +) -> InviteResponse: location = req.app.context.url_for("Invites", db_invite.scim_id) meta = Meta( location=location, @@ -112,11 +117,13 @@ def db_invite_to_response(req: Request, resp: Response, db_invite: ScimApiInvite return scim_invite -def create_signup_ref(req: ContextRequest, db_invite: ScimApiInvite): +def create_signup_ref(req: ContextRequest, db_invite: ScimApiInvite) -> SCIMReference: + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.data_owner is not None # please mypy return SCIMReference(data_owner=req.context.data_owner, scim_id=db_invite.scim_id) -def send_invite_mail(req: ContextRequest, signup_invite: SignupInvite): +def send_invite_mail(req: ContextRequest, signup_invite: SignupInvite) -> bool: try: email = [email.email for email in signup_invite.mail_addresses if email.primary][0] except IndexError: @@ -164,6 +171,8 @@ def save_invite( signup_invite_is_in_database: bool, ) -> None: try: + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.invitedb is not None # please mypy req.context.invitedb.save(db_invite) except DuplicateKeyError as e: assert e.details is not None # please mypy @@ -187,6 +196,8 @@ def filter_lastmodified( raise BadRequest(scim_type="invalidFilter", detail="Unsupported operator") if not isinstance(filter.val, str): raise BadRequest(scim_type="invalidFilter", detail="Invalid datetime") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.invitedb is not None # please mypy return req.context.invitedb.get_invites_by_last_modified( operator=filter.op, value=datetime.fromisoformat(filter.val), skip=skip, limit=limit ) diff --git a/src/eduid/scimapi/routers/utils/status.py b/src/eduid/scimapi/routers/utils/status.py index 804026152..3815f9f5f 100644 --- a/src/eduid/scimapi/routers/utils/status.py +++ b/src/eduid/scimapi/routers/utils/status.py @@ -4,7 +4,7 @@ __author__ = "lundberg" -def check_mongo(req: ContextRequest, default_data_owner: str): +def check_mongo(req: ContextRequest, default_data_owner: str) -> bool | None: user_db = req.app.context.get_userdb(default_data_owner) group_db = req.app.context.get_groupdb(default_data_owner) try: @@ -18,7 +18,7 @@ def check_mongo(req: ContextRequest, default_data_owner: str): return False -def check_neo4j(req: ContextRequest, default_data_owner: str): +def check_neo4j(req: ContextRequest, default_data_owner: str) -> bool | None: group_db = req.app.context.get_groupdb(default_data_owner) try: # TODO: Implement is_healthy, check if there is a better way for neo4j diff --git a/src/eduid/scimapi/routers/utils/users.py b/src/eduid/scimapi/routers/utils/users.py index c3a68422c..18e06ccec 100644 --- a/src/eduid/scimapi/routers/utils/users.py +++ b/src/eduid/scimapi/routers/utils/users.py @@ -11,6 +11,7 @@ from eduid.common.models.scim_base import Email, Meta, Name, PhoneNumber, SCIMResourceType, SCIMSchema, SearchRequest from eduid.common.models.scim_user import Group, LinkedAccount, NutidUserExtensionV1, Profile, UserResponse from eduid.common.utils import make_etag +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.routers.utils.events import add_api_event from eduid.scimapi.search import SearchFilter @@ -21,6 +22,8 @@ def get_user_groups(req: ContextRequest, db_user: ScimApiUser) -> list[Group]: """Return the groups for a user formatted as SCIM search sub-resources""" + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy user_groups = req.context.groupdb.get_groups_for_user_identifer(db_user.scim_id) groups = [] for group in user_groups: @@ -33,9 +36,13 @@ def get_user_groups(req: ContextRequest, db_user: ScimApiUser) -> list[Group]: def remove_user_from_all_groups(req: ContextRequest, db_user: ScimApiUser) -> None: """Remove a user from all groups""" # Remove user from groups + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.groupdb is not None # please mypy + assert req.context.data_owner is not None # please mypy for member_group in req.context.groupdb.get_groups_for_user_identifer(db_user.scim_id): # we need to get the full group object to get all the members group = req.context.groupdb.get_group_by_scim_id(str(member_group.scim_id)) + assert group is not None for member in group.graph.members.copy(): if member.identifier == str(db_user.scim_id): req.app.context.logger.debug( @@ -119,6 +126,8 @@ def db_user_to_response(req: ContextRequest, resp: Response, db_user: ScimApiUse def save_user(req: ContextRequest, db_user: ScimApiUser) -> None: try: + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.userdb is not None # please mypy req.context.userdb.save(db_user) except DuplicateKeyError as e: assert e.details is not None # please mypy @@ -127,7 +136,7 @@ def save_user(req: ContextRequest, db_user: ScimApiUser) -> None: raise BadRequest(detail="Duplicated key error") -def acceptable_linked_accounts(value: list[LinkedAccount], environment: EduidEnvironment): +def acceptable_linked_accounts(value: list[LinkedAccount], environment: EduidEnvironment) -> bool: """ Setting linked_accounts through SCIM with limited issuer and value. If we need to support stepup with someone other than eduID this needs to change. @@ -161,6 +170,8 @@ def filter_externalid(req: ContextRequest, search_filter: SearchFilter) -> list[ if not isinstance(search_filter.val, str): raise BadRequest(scim_type="invalidFilter", detail="Invalid externalId") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.userdb is not None # please mypy user = req.context.userdb.get_user_by_external_id(search_filter.val) if not user: @@ -176,6 +187,8 @@ def filter_lastmodified( raise BadRequest(scim_type="invalidFilter", detail="Unsupported operator") if not isinstance(search_filter.val, str): raise BadRequest(scim_type="invalidFilter", detail="Invalid datetime") + assert isinstance(req.context, ScimApiContext) # please mypy + assert req.context.userdb is not None # please mypy return req.context.userdb.get_users_by_last_modified( operator=search_filter.op, value=datetime.fromisoformat(search_filter.val), skip=skip, limit=limit ) @@ -195,6 +208,8 @@ def filter_profile_data( req.app.context.logger.debug( f"Searching for users with {search_filter.attr} {search_filter.op} {repr(search_filter.val)}" ) + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None users, count = req.context.userdb.get_user_by_profile_data( profile=profile, operator=search_filter.op, key=key, value=search_filter.val, skip=skip, limit=limit ) diff --git a/src/eduid/scimapi/test-scripts/scim-util.py b/src/eduid/scimapi/test-scripts/scim-util.py index 2e38e24f4..246f4a2b4 100755 --- a/src/eduid/scimapi/test-scripts/scim-util.py +++ b/src/eduid/scimapi/test-scripts/scim-util.py @@ -365,7 +365,7 @@ def main(args: Args) -> bool: return True -def _config_logger(args: Args, progname: str): +def _config_logger(args: Args, progname: str) -> None: # This is the root log level level = logging.INFO if args.debug: diff --git a/src/eduid/scimapi/testing.py b/src/eduid/scimapi/testing.py index 2b2e11212..ea0592f1c 100644 --- a/src/eduid/scimapi/testing.py +++ b/src/eduid/scimapi/testing.py @@ -172,7 +172,7 @@ def add_owner_to_group(self, group_identifier: str, user_identifier: str) -> Sci self.groupdb.save(group) return self.groupdb.get_group_by_scim_id(scim_id=group_identifier) - def tearDown(self): + def tearDown(self) -> None: super().tearDown() if self.userdb: self.userdb._drop_whole_collection() @@ -194,9 +194,9 @@ def _assertScimError( schemas: list[str] | None = None, status: int = 400, scim_type: str | None = None, - detail: Any | None = None, + detail: object | None = None, exclude_keys: list[str] | None = None, - ): + ) -> None: if schemas is None: schemas = [SCIMSchema.ERROR.value] self.assertEqual(schemas, json.get("schemas")) @@ -213,7 +213,7 @@ def _assertScimResponseProperties( response: Response, resource: ScimApiGroup | ScimApiUser | ScimApiInvite | ScimApiEvent, expected_schemas: list[str], - ): + ) -> None: if SCIMSchema.NUTID_USER_V1.value in response.json(): # The API can always add this extension to the parsed_response, even if it was not in the request expected_schemas += [SCIMSchema.NUTID_USER_V1.value] @@ -270,7 +270,7 @@ def _assertScimResponseProperties( ) @staticmethod - def _assertName(db_name: ScimApiName, response_name: dict[str, str]): + def _assertName(db_name: ScimApiName, response_name: dict[str, str]) -> None: name_map = [ ("family_name", "familyName"), ("given_name", "givenName"), @@ -286,7 +286,7 @@ def _assertName(db_name: ScimApiName, response_name: dict[str, str]): ), f"{first}:{db_name_dict.get(first)} != {second}:{response_name.get(second)}" @staticmethod - def _assertResponse(response: Response, status_code: int = 200): + def _assertResponse(response: Response, status_code: int = 200) -> None: _detail = None try: if response.json(): diff --git a/src/eduid/scimapi/tests/test_authn.py b/src/eduid/scimapi/tests/test_authn.py index f0f74fdbc..a1e033358 100644 --- a/src/eduid/scimapi/tests/test_authn.py +++ b/src/eduid/scimapi/tests/test_authn.py @@ -150,7 +150,7 @@ def test_requested_access_canonicalization(self) -> None: assert token.scopes == {domain} assert token.requested_access == [RequestedAccess(type=_requested_access_type, scope=domain)] - def test_invalid_requested_access_scope(self): + def test_invalid_requested_access_scope(self) -> None: # test too short domain name with pytest.raises(ValueError) as exc_info: AuthnBearerToken( @@ -160,6 +160,7 @@ def test_invalid_requested_access_scope(self): requested_access=[RequestedAccess(type=self.config.requested_access_type, scope=".se")], auth_source=AuthSource.CONFIG, ) + assert isinstance(exc_info.value, ValidationError) assert normalised_data(exc_info.value.errors(), exclude_keys=["url"]) == normalised_data( [ { @@ -172,7 +173,7 @@ def test_invalid_requested_access_scope(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_requested_access_not_for_us(self): + def test_requested_access_not_for_us(self) -> None: """Test with a 'requested_access' field with the wrong 'type' value.""" domain = "eduid.se" # test no canonization @@ -184,6 +185,7 @@ def test_requested_access_not_for_us(self): requested_access=[RequestedAccess(type="someone else", scope=domain)], auth_source=AuthSource.CONFIG, ) + assert isinstance(exc_info.value, ValidationError) assert normalised_data(exc_info.value.errors(), exclude_keys=["url"]) == normalised_data( [ { @@ -196,7 +198,7 @@ def test_requested_access_not_for_us(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_regular_token(self): + def test_regular_token(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = "eduid.se" claims = { @@ -212,7 +214,7 @@ def test_regular_token(self): assert token.auth_source == AuthSource.CONFIG assert token.requested_access == [RequestedAccess(type="scim-api", scope=ScopeName("eduid.se"))] - def test_multiple_access_requests_including_us(self): + def test_multiple_access_requests_including_us(self) -> None: """Test when requested access has multiple requests. Only keep the request for the current resource.""" domain = "eduid.se" token = AuthnBearerToken( @@ -235,7 +237,7 @@ def test_multiple_access_requests_including_us(self): RequestedAccess(type="scim-api", scope=ScopeName(domain)), ] - def test_interaction_token(self): + def test_interaction_token(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = "eduid.se" claims = { @@ -251,7 +253,7 @@ def test_interaction_token(self): assert token.auth_source == AuthSource.INTERACTION assert token.requested_access == [RequestedAccess(type="scim-api", scope=ScopeName("eduid.se"))] - def test_regular_token_with_canonisation(self): + def test_regular_token_with_canonisation(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = "eduid.se" domain_alias = "eduid.example.edu" @@ -261,7 +263,7 @@ def test_regular_token_with_canonisation(self): token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() == domain - def test_interaction_token_with_canonisation(self): + def test_interaction_token_with_canonisation(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = DataOwnerName("eduid.se") domain_alias = ScopeName("eduid.example.edu") @@ -271,7 +273,7 @@ def test_interaction_token_with_canonisation(self): token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() == domain - def test_regular_token_upper_case(self): + def test_regular_token_upper_case(self) -> None: """ Test the normal case. Login with access granted based on the single scope in the request. Scope provided in upper-case in the request. @@ -283,21 +285,21 @@ def test_regular_token_upper_case(self): assert token.scopes == {domain} assert token.get_data_owner() == domain - def test_unknown_scope(self): + def test_unknown_scope(self) -> None: """Test login with a scope that has no data owner in the configuration.""" domain = "example.org" claims = {"version": 1, "scopes": [domain], "auth_source": "config"} token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() is None - def test_interaction_token_unknown_scope(self): + def test_interaction_token_unknown_scope(self) -> None: """Test login with a scope that has no data owner in the configuration.""" domain = "example.org" claims = {"version": 1, "saml_eppn": f"eppn{domain}", "auth_source": "interaction"} token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() is None - def test_regular_token_multiple_scopes(self): + def test_regular_token_multiple_scopes(self) -> None: """Test the normal case. Login with access granted based on the scope in the request that has a data owner in configuration (one extra scope provided in the request, named 'aaa' so it is checked first - and skipped). """ @@ -432,17 +434,17 @@ def _make_bearer_token(self, claims: Mapping[str, Any]) -> str: token.make_signed_token(jwk) return token.serialize() - def test_get_user_no_authn(self): + def test_get_user_no_authn(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self._get_user_from_api(db_user) self._assertScimError(response.json(), status=401, detail="No authentication header found") - def test_get_user_bogus_token(self): + def test_get_user_bogus_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self._get_user_from_api(db_user, bearer_token="not a jws token") self._assertScimError(response.json(), status=401, detail="Bearer token error") - def test_get_user_untrusted_token(self): + def test_get_user_untrusted_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self._get_user_from_api( @@ -457,7 +459,7 @@ def test_get_user_untrusted_token(self): self._assertScimError(response.json(), status=401, detail="Bearer token error") - def test_get_user_correct_token(self): + def test_get_user_correct_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) claims = {"scopes": ["eduid.se"], "version": 1, "auth_source": "config"} @@ -472,7 +474,7 @@ def test_get_user_correct_token(self): } self._assertUserUpdateSuccess(_req, response, db_user) - def test_get_user_interaction_token(self): + def test_get_user_interaction_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) db_group = self.add_group_with_member( group_identifier=str(uuid4()), @@ -480,6 +482,7 @@ def test_get_user_interaction_token(self): user_identifier=str(db_user.scim_id), ) + assert self.groupdb is not None claims = { "saml_eppn": "eppn@eduid.se", "version": 1, @@ -505,7 +508,7 @@ def test_get_user_interaction_token(self): } self._assertUserUpdateSuccess(_req, response, db_user) - def test_get_user_data_owner_not_configured(self): + def test_get_user_data_owner_not_configured(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) claims = {"scopes": ["not_configured.se"], "version": 1, "auth_source": "config"} token = self._make_bearer_token(claims=claims) diff --git a/src/eduid/scimapi/tests/test_context.py b/src/eduid/scimapi/tests/test_context.py index 00ac60835..d88ce8f5e 100644 --- a/src/eduid/scimapi/tests/test_context.py +++ b/src/eduid/scimapi/tests/test_context.py @@ -6,12 +6,12 @@ class TestContext(ScimApiTestCase): - def test_init(self): + def test_init(self) -> None: config = load_config(typ=ScimApiConfig, app_name="scimapi", ns="api", test_config=self.test_config) ctx = Context(config=config) self.assertEqual(ctx.base_url, "http://localhost:8000") - def test_load_many_data_owners(self): + def test_load_many_data_owners(self) -> None: # Add 99 more data owners to the config for i in range(99): self.test_config["data_owners"][f"owner{i}"] = {"db_name": f"owner_{i}"} diff --git a/src/eduid/scimapi/tests/test_groupdb.py b/src/eduid/scimapi/tests/test_groupdb.py index 6a6ce1f36..0b6c8c281 100644 --- a/src/eduid/scimapi/tests/test_groupdb.py +++ b/src/eduid/scimapi/tests/test_groupdb.py @@ -23,8 +23,9 @@ def setUp(self) -> None: for i in range(9): self.add_group(uuid4(), f"Test Group-{i}") - def tearDown(self): + def tearDown(self) -> None: super().tearDown() + assert self.groupdb is not None self.groupdb._drop_whole_collection() def add_group(self, scim_id: UUID, display_name: str, extensions: GroupExtensions | None = None) -> ScimApiGroup: @@ -37,22 +38,27 @@ def add_group(self, scim_id: UUID, display_name: str, extensions: GroupExtension logger.info(f"TEST saved group {group}") return group - def test_full_search(self): + def test_full_search(self) -> None: + assert self.groupdb is not None groups = self.groupdb.get_groups() self.assertEqual(len(groups), 9) - def test_documents_and_count_first_page(self): + def test_documents_and_count_first_page(self) -> None: + assert self.groupdb is not None groups, count = self.groupdb._get_documents_and_count_by_filter(spec={}, limit=3) - [logger.info(f"Group {x}") for x in groups] + for x in groups: + logger.info(f"Group {x}") self.assertEqual(len(groups), 3) self.assertEqual(count, 9) - def test_documents_and_count_last_page(self): + def test_documents_and_count_last_page(self) -> None: + assert self.groupdb is not None groups, count = self.groupdb._get_documents_and_count_by_filter(spec={}, skip=6, limit=3) self.assertEqual(len(groups), 3) self.assertEqual(count, 9) - def test_documents_and_count_partial_last_page(self): + def test_documents_and_count_partial_last_page(self) -> None: + assert self.groupdb is not None groups, count = self.groupdb._get_documents_and_count_by_filter(spec={}, skip=8, limit=3) self.assertEqual(len(groups), 1) self.assertEqual(count, 9) diff --git a/src/eduid/scimapi/tests/test_login.py b/src/eduid/scimapi/tests/test_login.py index c5167836f..162cd76e0 100644 --- a/src/eduid/scimapi/tests/test_login.py +++ b/src/eduid/scimapi/tests/test_login.py @@ -15,12 +15,12 @@ def _get_config(self) -> dict: config["login_enabled"] = True return config - def test_get_token(self): - response = self.client.post(url="/login", data=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) + def test_get_token(self) -> None: + response = self.client.post(url="/login", content=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) self._assertResponse(response) - def test_use_token(self): - response = self.client.post(url="/login", data=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) + def test_use_token(self) -> None: + response = self.client.post(url="/login", content=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) token = response.headers.get("Authorization") headers = { "Content-Type": "application/scim+json", @@ -36,6 +36,6 @@ class TestLoginResourceNotEnabled(ScimApiTestCase): def setUp(self) -> None: super().setUp() - def test_get_token(self): - response = self.client.post(url="/login", data=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) + def test_get_token(self) -> None: + response = self.client.post(url="/login", content=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) assert response.status_code == 404 diff --git a/src/eduid/scimapi/tests/test_notifications.py b/src/eduid/scimapi/tests/test_notifications.py index fadf9d5e5..aaad751ac 100644 --- a/src/eduid/scimapi/tests/test_notifications.py +++ b/src/eduid/scimapi/tests/test_notifications.py @@ -11,7 +11,7 @@ class TestNotifications(ScimApiTestCase): - def _get_notifications(self): + def _get_notifications(self) -> list[QueueItem]: return [QueueItem.from_dict(x) for x in self.messagedb._get_all_docs()] def _get_config(self) -> dict[str, Any]: @@ -19,7 +19,7 @@ def _get_config(self) -> dict[str, Any]: config["data_owners"]["eduid.se"]["notify"] = ["https://example.org/notify"] return config - def test_create_user_notification(self): + def test_create_user_notification(self) -> None: assert len(self._get_notifications()) == 0 req = {"schemas": [SCIMSchema.CORE_20_USER.value], "externalId": "test-id-1"} @@ -28,7 +28,7 @@ def test_create_user_notification(self): assert len(self._get_notifications()) == 1 - def test_create_group_notification(self): + def test_create_group_notification(self) -> None: assert len(self._get_notifications()) == 0 req = {"schemas": [SCIMSchema.CORE_20_GROUP.value], "externalId": "test-id-1", "displayName": "Test Group"} @@ -37,7 +37,7 @@ def test_create_group_notification(self): assert len(self._get_notifications()) == 1 - def test_create_event_notification(self): + def test_create_event_notification(self) -> None: assert len(self._get_notifications()) == 0 user = self.add_user(identifier=str(uuid4()), external_id="test@example.org") diff --git a/src/eduid/scimapi/tests/test_profile.py b/src/eduid/scimapi/tests/test_profile.py index a80515756..e058eafbe 100644 --- a/src/eduid/scimapi/tests/test_profile.py +++ b/src/eduid/scimapi/tests/test_profile.py @@ -4,7 +4,7 @@ class TestProfile(TestCase): - def test_parse(self): + def test_parse(self) -> None: displayname = "Musse Pigg" data = {"profiles": {"student": {"attributes": {"displayName": displayname}}}} extension = NutidUserExtensionV1.model_validate(data) diff --git a/src/eduid/scimapi/tests/test_scimbase.py b/src/eduid/scimapi/tests/test_scimbase.py index 6ca4a62ab..68b78e8d7 100644 --- a/src/eduid/scimapi/tests/test_scimbase.py +++ b/src/eduid/scimapi/tests/test_scimbase.py @@ -34,7 +34,7 @@ def test_base_response(self) -> None: loaded_base = BaseResponse.parse_raw(base_dump) assert normalised_data(base.dict()) == normalised_data(loaded_base.dict()) - def test_hashable_subresources(self): + def test_hashable_subresources(self) -> None: a = { "$ref": "http://localhost:8000/Users/78130160-b63d-4303-99cd-73767e2a999f", "display": "Test User 1 (updated)", diff --git a/src/eduid/scimapi/tests/test_scimevent.py b/src/eduid/scimapi/tests/test_scimevent.py index 27fcbb12b..79ab1b15d 100644 --- a/src/eduid/scimapi/tests/test_scimevent.py +++ b/src/eduid/scimapi/tests/test_scimevent.py @@ -26,8 +26,9 @@ class TestEventResource(ScimApiTestCase): def setUp(self) -> None: super().setUp() - def tearDown(self): + def tearDown(self) -> None: super().tearDown() + assert self.eventdb self.eventdb._drop_whole_collection() def _create_event(self, event: dict[str, Any], expect_success: bool = True) -> EventApiResult: @@ -49,11 +50,11 @@ def _fetch_event(self, event_id: UUID) -> EventApiResult: parsed_response = EventResponse.parse_raw(response.text) return EventApiResult(event=parsed_response.nutid_event_v1, response=response, parsed_response=parsed_response) - def _assertEventUpdateSuccess(self, req: Mapping, response, event: ScimApiEvent): + def _assertEventUpdateSuccess(self, req: Mapping, response: Response, event: ScimApiEvent) -> None: """Function to validate successful responses to SCIM calls that update a event according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json()}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json()}") expected_schemas = req.get("schemas", [SCIMSchema.NUTID_EVENT_CORE_V1.value]) if ( @@ -65,7 +66,7 @@ def _assertEventUpdateSuccess(self, req: Mapping, response, event: ScimApiEvent) self._assertScimResponseProperties(response, resource=event, expected_schemas=expected_schemas) - def test_create_event(self): + def test_create_event(self) -> None: user = self.add_user(identifier=str(uuid4()), external_id="test@example.org") event = { "resource": { @@ -80,6 +81,7 @@ def test_create_event(self): result = self._create_event(event=event) # check that the create resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, scim_id=user.scim_id) assert len(events) == 1 db_event = events[0] @@ -92,9 +94,10 @@ def test_create_event(self): assert db_event.data == event["data"] # Verify what is returned in the parsed_response assert result.parsed_response.id == db_event.scim_id + assert result.request self._assertEventUpdateSuccess(req=result.request, response=result.response, event=db_event) - def test_create_and_fetch_event(self): + def test_create_and_fetch_event(self) -> None: user = self.add_user(identifier=str(uuid4()), external_id="test@example.org") event = { "resource": { @@ -109,6 +112,7 @@ def test_create_and_fetch_event(self): created = self._create_event(event=event) # check that the creation resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, scim_id=user.scim_id) assert len(events) == 1 db_event = events[0] diff --git a/src/eduid/scimapi/tests/test_scimgroup.py b/src/eduid/scimapi/tests/test_scimgroup.py index d098d303d..4da16d94a 100644 --- a/src/eduid/scimapi/tests/test_scimgroup.py +++ b/src/eduid/scimapi/tests/test_scimgroup.py @@ -7,6 +7,7 @@ from uuid import UUID, uuid4 from bson import ObjectId +from httpx import Response from eduid.common.config.base import DataOwnerName from eduid.common.models.scim_base import Meta, SCIMResourceType, SCIMSchema, WeakVersion @@ -55,8 +56,9 @@ def setUp(self) -> None: super().setUp() self.groupdb = self.context.get_groupdb(DataOwnerName("eduid.se")) - def tearDown(self): + def tearDown(self) -> None: super().tearDown() + assert self.groupdb self.groupdb._drop_whole_collection() def add_group(self, scim_id: UUID, display_name: str, extensions: GroupExtensions | None = None) -> ScimApiGroup: @@ -88,7 +90,7 @@ def _perform_search( expected_group: ScimApiGroup | None = None, expected_num_resources: int | None = None, expected_total_results: int | None = None, - ): + ) -> dict: logger.info(f"Searching for group(s) using filter {repr(filter)}") req = { "schemas": [SCIMSchema.API_MESSAGES_20_SEARCH_REQUEST.value], @@ -145,10 +147,10 @@ def _perform_search( return resources - def _assertGroupUpdateSuccess(self, req: Mapping, response, group: ScimApiGroup): + def _assertGroupUpdateSuccess(self, req: Mapping, response: Response, group: ScimApiGroup) -> None: """Function to validate successful responses to SCIM calls that update a group according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json}") expected_schemas = req.get("schemas", [SCIMSchema.CORE_20_GROUP.value]) if ( @@ -181,12 +183,13 @@ def _assertGroupUpdateSuccess(self, req: Mapping, response, group: ScimApiGroup) class TestGroupResource_GET(TestGroupResource): - def test_get_groups(self): + def test_get_groups(self) -> None: for i in range(9): self.add_group(uuid4(), f"Test Group {i}") response = self.client.get(url="/Groups", headers=self.headers) self.assertEqual([SCIMSchema.API_MESSAGES_20_LIST_RESPONSE.value], response.json().get("schemas")) resources = response.json().get("Resources") + assert self.groupdb expected_num_resources = self.groupdb.graphdb.db.count_nodes() self.assertEqual( expected_num_resources, @@ -199,18 +202,18 @@ def test_get_groups(self): f"Response totalResults does not match number of groups in the database: {expected_num_resources}", ) - def test_get_group(self): + def test_get_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") response = self.client.get(url=f"/Groups/{db_group.scim_id}", headers=self.headers) self._assertGroupUpdateSuccess({"members": []}, response, db_group) - def test_get_group_not_found(self): + def test_get_group_not_found(self) -> None: response = self.client.get(url=f"/Groups/{uuid4()}", headers=self.headers) self._assertScimError(response.json(), status=404, detail="Group not found") class TestGroupResource_POST(TestGroupResource): - def test_create_group(self): + def test_create_group(self) -> None: req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], "displayName": "Test Group 1", @@ -220,6 +223,8 @@ def test_create_group(self): response = self.client.post(url="/Groups/", json=req, headers=self.headers) # Load the created group from the database, ensuring it was in fact created + assert self.groupdb + assert isinstance(req["displayName"], str) # please mypy _groups, _count = self.groupdb.get_groups_by_property("display_name", req["displayName"]) self.assertEqual(1, _count, "More or less than one group found in the database after create") db_group = _groups[0] @@ -227,13 +232,14 @@ def test_create_group(self): self._assertGroupUpdateSuccess(req, response, db_group) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.GROUP, db_group.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.CREATED.value - def test_schema_violation(self): + def test_schema_violation(self) -> None: # request missing displayName req = {"schemas": [SCIMSchema.CORE_20_GROUP.value], "members": []} response = self.client.post(url="/Groups/", json=req, headers=self.headers) @@ -252,7 +258,7 @@ def test_schema_violation(self): class TestGroupResource_PUT(TestGroupResource): - def test_update_group(self): + def test_update_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") subgroup = self.add_group(uuid4(), "Test Group 2") user = self.add_user(identifier=str(uuid4()), external_id="not-used") @@ -279,7 +285,7 @@ def test_update_group(self): self._assertGroupUpdateSuccess(req, response, db_group) - def test_update_existing_group(self): + def test_update_existing_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") subgroup = self.add_group(uuid4(), "Test Group 2") user = self.add_user(identifier=str(uuid4()), external_id="not-used") @@ -318,13 +324,14 @@ def test_update_existing_group(self): self._assertGroupUpdateSuccess(req, response, db_group) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.GROUP, db_group.scim_id) assert len(events) == 2 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.UPDATED.value - def test_add_member_to_existing_group(self): + def test_add_member_to_existing_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") user = self.add_user(identifier=str(uuid4()), external_id="not-used") members = [ @@ -360,15 +367,18 @@ def test_add_member_to_existing_group(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertGroupUpdateSuccess(req, response, db_group) - def test_removing_group_member(self): + def test_removing_group_member(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") subgroup = self.add_group(uuid4(), "Test Group 2") db_group = self.add_member(db_group, subgroup, "Test User") user = self.add_user(identifier=str(uuid4()), external_id="not-used") db_group = self.add_member(db_group, user, "Test User") + assert self.groupdb + # Load group to verify it has two members _g1 = self.groupdb.get_group_by_scim_id(str(db_group.scim_id)) + assert _g1 self.assertEqual(2, len(_g1.graph.members), "Group loaded from database does not have two members") self.assertEqual(1, len(_g1.graph.member_users), "Group loaded from database does not have one member user") self.assertEqual(1, len(_g1.graph.member_groups), "Group loaded from database does not have one member group") @@ -395,11 +405,12 @@ def test_removing_group_member(self): # Load group to verify it has one less member now _g2 = self.groupdb.get_group_by_scim_id(str(db_group.scim_id)) + assert _g2 self.assertEqual(1, len(_g2.graph.members), "Group loaded from database does not have two members") self.assertEqual(1, len(_g2.graph.member_users), "Group loaded from database does not have one member user") self.assertEqual(0, len(_g2.graph.member_groups), "Group loaded from database does not have one member group") - def test_update_group_id_mismatch(self): + def test_update_group_id_mismatch(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], @@ -410,7 +421,7 @@ def test_update_group_id_mismatch(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail="Id mismatch") - def test_update_group_not_found(self): + def test_update_group_not_found(self) -> None: req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], "id": str(uuid4()), @@ -420,7 +431,7 @@ def test_update_group_not_found(self): response = self.client.put(url=f'/Groups/{req["id"]}', json=req, headers=self.headers) self._assertScimError(response.json(), status=404, detail="Group not found") - def test_version_mismatch(self): + def test_version_mismatch(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], @@ -431,7 +442,7 @@ def test_version_mismatch(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail="Version mismatch") - def test_update_group_member_does_not_exist(self): + def test_update_group_member_does_not_exist(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") _user_scim_id = str(uuid4()) members = [ @@ -451,7 +462,7 @@ def test_update_group_member_does_not_exist(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail=f"User {_user_scim_id} not found") - def test_update_group_subgroup_does_not_exist(self): + def test_update_group_subgroup_does_not_exist(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") _subgroup_scim_id = str(uuid4()) members = [ @@ -471,7 +482,7 @@ def test_update_group_subgroup_does_not_exist(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail=f"Group {_subgroup_scim_id} not found") - def test_schema_violation(self): + def test_schema_violation(self) -> None: # request missing displayName req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], @@ -493,20 +504,22 @@ def test_schema_violation(self): class TestGroupResource_DELETE(TestGroupResource): - def test_delete_group(self): + def test_delete_group(self) -> None: group = self.add_group(uuid4(), "Test Group 1") + assert self.groupdb + assert self.eventdb # Verify we can find the group in the database db_group1 = self.groupdb.get_group_by_scim_id(str(group.scim_id)) - self.assertIsNotNone(db_group1) + assert db_group1 is not None self.headers["IF-MATCH"] = make_etag(group.version) response = self.client.delete(url=f"/Groups/{group.scim_id}", headers=self.headers) self.assertEqual(204, response.status_code) # Verify the group is no longer in the database - db_group2 = self.groupdb.get_group_by_scim_id(group.scim_id) - self.assertIsNone(db_group2) + db_group2 = self.groupdb.get_group_by_scim_id(str(group.scim_id)) + assert db_group2 is None # check that the action resulted in an event in the database events = self.eventdb.get_events_by_resource(SCIMResourceType.GROUP, db_group1.scim_id) @@ -515,50 +528,51 @@ def test_delete_group(self): assert event.resource.external_id is None assert event.data["status"] == EventStatus.DELETED.value - def test_version_mismatch(self): + def test_version_mismatch(self) -> None: group = self.add_group(uuid4(), "Test Group 1") self.headers["IF-MATCH"] = make_etag(ObjectId()) response = self.client.delete(url=f"/Groups/{group.scim_id}", headers=self.headers) self._assertScimError(response.json(), detail="Version mismatch") - def test_group_not_found(self): + def test_group_not_found(self) -> None: response = self.client.delete(url=f"/Groups/{uuid4()}", headers=self.headers) self._assertScimError(response.json(), status=404, detail="Group not found") class TestGroupSearchResource(TestGroupResource): - def test_search_group_display_name(self): + def test_search_group_display_name(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") self.add_group(uuid4(), "Test Group 2") self._perform_search(filter='displayName eq "Test Group 1"', expected_group=db_group) - def test_search_group_display_name_not_found(self): + def test_search_group_display_name_not_found(self) -> None: self._perform_search(filter='displayName eq "Test No Such Group"', expected_total_results=0) - def test_search_group_display_name_bad_operator(self): + def test_search_group_display_name_bad_operator(self) -> None: json = self._perform_search(filter="displayName lt 1", return_json=True) self._assertScimError(json, scim_type="invalidFilter", detail="Unsupported operator") - def test_search_group_display_name_not_string(self): + def test_search_group_display_name_not_string(self) -> None: json = self._perform_search(filter="displayName eq 1", return_json=True) self._assertScimError(json, scim_type="invalidFilter", detail="Invalid displayName") - def test_search_group_unknown_attribute(self): + def test_search_group_unknown_attribute(self) -> None: json = self._perform_search(filter="no_such_attribute lt 1", return_json=True) self._assertScimError(json, scim_type="invalidFilter", detail="Can't filter on attribute no_such_attribute") - def test_search_group_start_index(self): + def test_search_group_start_index(self) -> None: for i in range(9): self.add_group(uuid4(), "Test Group") self._perform_search( filter='displayName eq "Test Group"', start=5, expected_num_resources=5, expected_total_results=9 ) - def test_search_group_count(self): + def test_search_group_count(self) -> None: for i in range(9): self.add_group(uuid4(), "Test Group") + assert self.groupdb groups = self.groupdb.get_groups() self.assertEqual(len(groups), 9) @@ -566,24 +580,24 @@ def test_search_group_count(self): filter='displayName eq "Test Group"', start=1, count=5, expected_num_resources=5, expected_total_results=9 ) - def test_search_group_extension_data_attribute_str(self): + def test_search_group_extension_data_attribute_str(self) -> None: ext = GroupExtensions(data={"some_key": "20072009"}) db_group = self.add_group(uuid4(), "Test Group with extension", extensions=ext) self._perform_search(filter='extensions.data.some_key eq "20072009"', expected_group=db_group) - def test_search_group_extension_data_bad_op(self): + def test_search_group_extension_data_bad_op(self) -> None: json = self._perform_search(filter='extensions.data.some_key XY "20072009"', return_json=True) self._assertScimError(json, detail="Unsupported operator") - def test_search_group_extension_data_invalid_key(self): + def test_search_group_extension_data_invalid_key(self) -> None: json = self._perform_search(filter='extensions.data.some.key eq "20072009"', return_json=True) self._assertScimError(json, detail="Unsupported extension search key") - def test_search_group_extension_data_not_found(self): + def test_search_group_extension_data_not_found(self) -> None: self._perform_search(filter='extensions.data.some_key eq "20072009"', expected_num_resources=0) - def test_search_group_extension_data_attribute_int(self): + def test_search_group_extension_data_attribute_int(self) -> None: ext1 = GroupExtensions(data={"some_key": 20072009}) group = self.add_group(uuid4(), "Test Group with extension", extensions=ext1) @@ -593,7 +607,7 @@ def test_search_group_extension_data_attribute_int(self): self._perform_search(filter="extensions.data.some_key eq 20072009", expected_group=group) - def test_search_group_last_modified(self): + def test_search_group_last_modified(self) -> None: group1 = self.add_group(uuid4(), "Test Group 1") group2 = self.add_group(uuid4(), "Test Group 2") self.assertGreater(group2.last_modified, group1.last_modified) @@ -604,15 +618,15 @@ def test_search_group_last_modified(self): self._perform_search(filter=f'meta.lastModified gt "{group1.last_modified.isoformat()}"', expected_group=group2) - def test_search_group_last_modified_invalid_datetime_1(self): + def test_search_group_last_modified_invalid_datetime_1(self) -> None: json = self._perform_search(filter="meta.lastModified ge 1", return_json=True) self._assertScimError(json, detail="Invalid datetime") - def test_search_group_last_modified_invalid_datetime_2(self): + def test_search_group_last_modified_invalid_datetime_2(self) -> None: json = self._perform_search(filter='meta.lastModified ge "2020-05-12_15:36:99+00:00"', return_json=True) self._assertScimError(json, detail="Invalid datetime") - def test_schema_violation(self): + def test_schema_violation(self) -> None: # request missing filter req = { "schemas": [SCIMSchema.API_MESSAGES_20_SEARCH_REQUEST.value], @@ -633,7 +647,7 @@ def test_schema_violation(self): class TestGroupExtensionData(TestGroupResource): - def test_nutid_extension(self): + def test_nutid_extension(self) -> None: display_name = "Test Group with Nutid extension" nutid_data = {"data": {"testing": "certainly"}} req = { @@ -647,7 +661,9 @@ def test_nutid_extension(self): # Load the newly created group from the database in order to validate the SCIM parsed_response better scim_id = post_resp.json().get("id") self.assertIsNotNone(scim_id, "Group creation parsed_response id not present") + assert self.groupdb db_group = self.groupdb.get_group_by_scim_id(scim_id) + assert db_group expected_schemas = [SCIMSchema.CORE_20_GROUP.value, SCIMSchema.NUTID_GROUP_V1.value] self._assertScimResponseProperties(post_resp, db_group, expected_schemas=expected_schemas) self.assertEqual([], post_resp.json().get("members"), "Group was not expected to have members") @@ -685,7 +701,9 @@ def test_nutid_extension(self): get_resp2 = self.client.get(url=f"/Groups/{scim_id}", headers=self.headers) self.assertEqual(put_resp.json(), get_resp2.json()) + assert self.groupdb db_group = self.groupdb.get_group_by_scim_id(scim_id) + assert db_group self._assertScimResponseProperties(get_resp2, db_group, expected_schemas=expected_schemas) self.assertEqual([], get_resp2.json().get("members"), "Group was not expected to have members") diff --git a/src/eduid/scimapi/tests/test_sciminvite.py b/src/eduid/scimapi/tests/test_sciminvite.py index d049e7d41..41ed68089 100644 --- a/src/eduid/scimapi/tests/test_sciminvite.py +++ b/src/eduid/scimapi/tests/test_sciminvite.py @@ -1,13 +1,14 @@ import json import logging import unittest -from collections.abc import Mapping +from collections.abc import Mapping, MutableMapping from copy import copy from dataclasses import asdict from datetime import datetime, timedelta from typing import Any from bson import ObjectId +from httpx import Response from eduid.common.misc.timeutil import utc_now from eduid.common.models.scim_base import Email, Meta, Name, PhoneNumber, SCIMResourceType, SCIMSchema @@ -55,7 +56,7 @@ def setUp(self) -> None: "profiles": {"student": {"attributes": {"displayName": "Test"}}}, } - def test_load_invite(self): + def test_load_invite(self) -> None: invite = ScimApiInvite.from_dict(self.invite_doc1) # test to-dict+from-dict consistency invite2 = ScimApiInvite.from_dict(invite.to_dict()) @@ -63,8 +64,9 @@ def test_load_invite(self): assert asdict(invite) == asdict(invite2) assert invite.to_dict() == invite2.to_dict() - def test_to_sciminvite_response(self): + def test_to_sciminvite_response(self) -> None: db_invite = ScimApiInvite.from_dict(self.invite_doc1) + assert db_invite meta = Meta( location=f"http://example.org/Invites/{db_invite.scim_id}", resource_type=SCIMResourceType.INVITE, @@ -73,6 +75,8 @@ def test_to_sciminvite_response(self): version=db_invite.version, ) + assert db_invite.emails[0].primary is not None + assert db_invite.emails[1].primary is not None signup_invite = SignupInvite( invite_type=InviteType.SCIM, invite_reference=SCIMReference(data_owner="test_data_owner", scim_id=db_invite.scim_id), @@ -239,10 +243,12 @@ def add_invite(self, data: dict[str, Any] | None = None, update: bool = False) - self.signup_invitedb.save(signup_invite, is_in_database=False) return db_invite - def _assertUpdateSuccess(self, req: Mapping, response, invite: ScimApiInvite, signup_invite: SignupInvite): + def _assertUpdateSuccess( + self, req: Mapping, response: Response, invite: ScimApiInvite, signup_invite: SignupInvite + ) -> None: """Function to validate successful responses to SCIM calls that update an invite according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json}") expected_schemas = req.get("schemas", [SCIMSchema.NUTID_INVITE_V1.value, SCIMSchema.NUTID_USER_V1.value]) @@ -288,7 +294,7 @@ def _perform_search( expected_invite: ScimApiInvite | None = None, expected_num_resources: int | None = None, expected_total_results: int | None = None, - ): + ) -> dict: logger.info(f"Searching for group(s) using filter {repr(filter)}") req = { "schemas": [SCIMSchema.API_MESSAGES_20_SEARCH_REQUEST.value], @@ -341,7 +347,7 @@ def _perform_search( resources = response.json().get("Resources") return resources - def test_create_invite(self): + def test_create_invite(self) -> None: req = { "schemas": [ SCIMSchema.NUTID_INVITE_CORE_V1.value, @@ -379,20 +385,24 @@ def test_create_invite(self): response = self.client.post(url="/Invites/", json=req, headers=self.headers) self._assertResponse(response, status_code=201) + assert self.invitedb db_invite = self.invitedb.get_invite_by_scim_id(response.json().get("id")) + assert db_invite reference = SCIMReference(data_owner=self.data_owner, scim_id=db_invite.scim_id) signup_invite = self.signup_invitedb.get_invite_by_reference(reference) + assert signup_invite self._assertUpdateSuccess(req, response, db_invite, signup_invite) self.assertEqual(1, self.messagedb.db_count()) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.INVITE, db_invite.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.CREATED.value - def test_create_invite_missing_mandatory_attributes(self): + def test_create_invite_missing_mandatory_attributes(self) -> None: req = { "schemas": [ SCIMSchema.NUTID_INVITE_CORE_V1.value, @@ -405,7 +415,7 @@ def test_create_invite_missing_mandatory_attributes(self): }, } - req1 = copy(req) + req1: MutableMapping[str, Any] = copy(req) del req1[SCIMSchema.NUTID_INVITE_V1.value]["inviterName"] response = self.client.post(url="/Invites/", json=req1, headers=self.headers) self._assertScimError( @@ -422,7 +432,7 @@ def test_create_invite_missing_mandatory_attributes(self): exclude_keys=["input", "url"], ) - req2 = copy(req) + req2: MutableMapping[str, Any] = copy(req) del req2[SCIMSchema.NUTID_INVITE_V1.value]["sendEmail"] response = self.client.post(url="/Invites/", json=req2, headers=self.headers) self._assertScimError( @@ -439,7 +449,7 @@ def test_create_invite_missing_mandatory_attributes(self): exclude_keys=["input", "url"], ) - def test_create_invite_do_not_send_email(self): + def test_create_invite_do_not_send_email(self) -> None: req = { "schemas": [ SCIMSchema.NUTID_INVITE_CORE_V1.value, @@ -476,13 +486,16 @@ def test_create_invite_do_not_send_email(self): response = self.client.post(url="/Invites/", json=req, headers=self.headers) self._assertResponse(response, status_code=201) + assert self.invitedb db_invite = self.invitedb.get_invite_by_scim_id(response.json().get("id")) + assert db_invite reference = SCIMReference(data_owner=self.data_owner, scim_id=db_invite.scim_id) signup_invite = self.signup_invitedb.get_invite_by_reference(reference) + assert signup_invite self._assertUpdateSuccess(req, response, db_invite, signup_invite) self.assertEqual(0, self.messagedb.db_count()) - def test_get_invite(self): + def test_get_invite(self) -> None: db_invite = self.add_invite() response = self.client.get(url=f"/Invites/{db_invite.scim_id}", headers=self.headers) expected_schemas = [ @@ -492,7 +505,7 @@ def test_get_invite(self): ] self._assertScimResponseProperties(response, resource=db_invite, expected_schemas=expected_schemas) - def test_update_invite(self): + def test_update_invite(self) -> None: # TODO: For now we only support updating completed db_invite = self.add_invite() invite = self.client.get(url=f"/Invites/{db_invite.scim_id}", headers=self.headers) @@ -505,7 +518,9 @@ def test_update_invite(self): response = self.client.put(url=f"/Invites/{db_invite.scim_id}", json=update_req, headers=self.headers) assert response.status_code == 200 + assert self.invitedb updated_invite = self.invitedb.get_invite_by_scim_id(str(db_invite.scim_id)) + assert updated_invite assert updated_invite.completed is not None expected_schemas = [ @@ -515,21 +530,23 @@ def test_update_invite(self): ] self._assertScimResponseProperties(response, resource=db_invite, expected_schemas=expected_schemas) - def test_delete_invite(self): + def test_delete_invite(self) -> None: db_invite = self.add_invite() self.headers["IF-MATCH"] = make_etag(db_invite.version) self.client.delete(url=f"/Invites/{db_invite.scim_id}", headers=self.headers) reference = SCIMReference(data_owner=self.data_owner, scim_id=db_invite.scim_id) + assert self.invitedb self.assertIsNone(self.invitedb.get_invite_by_scim_id(str(db_invite.scim_id))) self.assertIsNone(self.signup_invitedb.get_invite_by_reference(reference)) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.INVITE, db_invite.scim_id) assert len(events) == 1 event = events[0] assert event.data["status"] == EventStatus.DELETED.value - def test_search_user_last_modified(self): + def test_search_user_last_modified(self) -> None: db_invite1 = self.add_invite() db_invite2 = self.add_invite(data={"invite_code": "another_invite_code"}, update=True) self.assertGreater(db_invite2.last_modified, db_invite1.last_modified) diff --git a/src/eduid/scimapi/tests/test_scimuser.py b/src/eduid/scimapi/tests/test_scimuser.py index aad4f48c2..5ddaf91b7 100644 --- a/src/eduid/scimapi/tests/test_scimuser.py +++ b/src/eduid/scimapi/tests/test_scimuser.py @@ -19,7 +19,7 @@ from eduid.common.utils import make_etag from eduid.scimapi.testing import ScimApiTestCase from eduid.scimapi.utils import filter_none -from eduid.userdb.scimapi import EventStatus, ScimApiLinkedAccount +from eduid.userdb.scimapi import EventStatus, ScimApiGroup, ScimApiLinkedAccount from eduid.userdb.scimapi.userdb import ScimApiProfile, ScimApiUser logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def setUp(self) -> None: "profiles": {"student": {"attributes": {"displayName": "Test"}}}, } - def test_load_old_user(self): + def test_load_old_user(self) -> None: user = ScimApiUser.from_dict(self.user_doc1) self.assertEqual(user.profiles["student"].attributes["displayName"], "Test") @@ -57,7 +57,7 @@ def test_load_old_user(self): user2 = ScimApiUser.from_dict(user.to_dict()) self.assertEqual(asdict(user), asdict(user2)) - def test_to_scimuser_doc(self): + def test_to_scimuser_doc(self) -> None: db_user = ScimApiUser.from_dict(self.user_doc1) meta = Meta( location=f"http://example.org/Users/{db_user.scim_id}", @@ -114,7 +114,7 @@ def test_to_scimuser_doc(self): loaded_user_response = json.loads(user_response_json) assert loaded_user_response == expected - def test_to_scimuser_no_external_id(self): + def test_to_scimuser_no_external_id(self) -> None: user_doc2 = { "_id": ObjectId("5e81c5f849ac2cd87580e500"), "scim_id": "a7851d21-eab9-4caa-ba5d-49653d65c452", @@ -167,7 +167,7 @@ def test_to_scimuser_no_external_id(self): loaded_user_response = json.loads(user_response_json) assert loaded_user_response == expected - def test_bson_serialization(self): + def test_bson_serialization(self) -> None: user = ScimApiUser.from_dict(self.user_doc1) x = bson.encode(user.to_dict()) self.assertTrue(x) @@ -189,11 +189,11 @@ def setUp(self) -> None: attributes={"displayName": "Test User 2"}, data={"another_test_key": "another_test_value"} ) - def _assertUserUpdateSuccess(self, req: Mapping, response, user: ScimApiUser): + def _assertUserUpdateSuccess(self, req: Mapping, response: Response, user: ScimApiUser) -> None: """Function to validate successful responses to SCIM calls that update a user according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json}") expected_schemas = req.get("schemas", [SCIMSchema.CORE_20_USER.value]) if SCIMSchema.NUTID_USER_V1.value in response.json() and SCIMSchema.NUTID_USER_V1.value not in expected_schemas: @@ -226,7 +226,7 @@ def _assertUserUpdateSuccess(self, req: Mapping, response, user: ScimApiUser): resp_nutid, "Unexpected NUTID user data in parsed_response", ) - elif SCIMSchema.NUTID_USER_V1.value in response.json: + elif SCIMSchema.NUTID_USER_V1.value in response.json(): self.fail(f"Unexpected {SCIMSchema.NUTID_USER_V1.value} in the parsed_response") def _create_user(self, req: dict[str, Any], expect_success: bool = True) -> UserApiResult: @@ -278,7 +278,7 @@ def _update_user( class TestUserResource(ScimApiTestUserResourceBase): - def test_get_user(self): + def test_get_user(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self.client.get(url=f"/Users/{db_user.scim_id}", headers=self.headers) @@ -287,11 +287,11 @@ def test_get_user(self): } self._assertUserUpdateSuccess(_req, response, db_user) - def test_create_users_with_no_external_id(self): + def test_create_users_with_no_external_id(self) -> None: self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) - def test_create_user(self): + def test_create_user(self) -> None: req = { "externalId": "test-id-1", "name": {"familyName": "Testsson", "givenName": "Test", "middleName": "Testaren"}, @@ -308,19 +308,23 @@ def test_create_user(self): result = self._create_user(req) # Load the created user from the database, ensuring it was in fact created + assert self.userdb + assert isinstance(req["externalId"], str) db_user = self.userdb.get_user_by_external_id(req["externalId"]) + assert db_user self.assertIsNotNone(db_user, "Created user not found in the database") self._assertUserUpdateSuccess(result.request, result.response, db_user) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, db_user.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.CREATED.value - def test_create_and_update_user(self): + def test_create_and_update_user(self) -> None: """Test that creating a user and then updating it without changes only results in one event""" req = { "externalId": "test-id-1", @@ -335,8 +339,10 @@ def test_create_and_update_user(self): }, } result1 = self._create_user(req) + assert result1.parsed_response # check that the action resulted in an event in the database + assert self.eventdb events1 = self.eventdb.get_events_by_resource(SCIMResourceType.USER, result1.parsed_response.id) assert len(events1) == 1 event = events1[0] @@ -345,6 +351,7 @@ def test_create_and_update_user(self): # Update the user without making any changes result2 = self._update_user(req, result1.parsed_response.id, result1.parsed_response.meta.version) + assert result2.parsed_response # Make sure the version wasn't updated assert result1.parsed_response.meta.version == result2.parsed_response.meta.version # Make sure no additional event was created @@ -352,7 +359,7 @@ def test_create_and_update_user(self): assert len(events2) == 1 assert events1 == events2 - def test_create_user_no_external_id(self): + def test_create_user_no_external_id(self) -> None: req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], SCIMSchema.NUTID_USER_V1.value: { @@ -366,12 +373,14 @@ def test_create_user_no_external_id(self): self._assertResponse(response, status_code=201) # Load the created user from the database, ensuring it was in fact created + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self.assertIsNotNone(db_user, "Created user not found in the database") self._assertUserUpdateSuccess(req, response, db_user) - def test_create_user_duplicated_external_id(self): + def test_create_user_duplicated_external_id(self) -> None: external_id = "test-id-1" # Create an existing user in the db self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) @@ -388,8 +397,11 @@ def test_create_user_duplicated_external_id(self): response.json(), schemas=["urn:ietf:params:scim:api:messages:2.0:Error"], detail="externalID must be unique" ) - def test_update_user(self): - db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) + def test_update_user(self) -> None: + db_user: ScimApiUser | None = self.add_user( + identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile} + ) + assert db_user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], "id": str(db_user.scim_id), @@ -408,17 +420,20 @@ def test_update_user(self): self.headers["IF-MATCH"] = make_etag(db_user.version) response = self.client.put(url=f"/Users/{db_user.scim_id}", json=req, headers=self.headers) self._assertResponse(response) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self._assertUserUpdateSuccess(req, response, db_user) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, db_user.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.UPDATED.value - def test_update_user_change_properties(self): + def test_update_user_change_properties(self) -> None: # Create the user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], @@ -460,11 +475,14 @@ def test_update_user_change_properties(self): response = self.client.put(url=f'/Users/{create_response.json()["id"]}', json=req, headers=self.headers) self._assertResponse(response) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self._assertUserUpdateSuccess(req, response, db_user) - def test_update_user_set_external_id(self): - db_user = self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) + def test_update_user_set_external_id(self) -> None: + db_user: ScimApiUser | None = self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) + assert db_user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], "id": str(db_user.scim_id), @@ -479,10 +497,12 @@ def test_update_user_set_external_id(self): self.headers["IF-MATCH"] = make_etag(db_user.version) response = self.client.put(url=f"/Users/{db_user.scim_id}", json=req, headers=self.headers) self._assertResponse(response) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self._assertUserUpdateSuccess(req, response, db_user) - def test_update_user_duplicated_external_id(self): + def test_update_user_duplicated_external_id(self) -> None: external_id = "test-id-1" # Create two existing users with different external_id self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) @@ -504,40 +524,52 @@ def test_update_user_duplicated_external_id(self): response.json(), schemas=["urn:ietf:params:scim:api:messages:2.0:Error"], detail="externalID must be unique" ) - def test_delete_user(self): + def test_delete_user(self) -> None: external_id = "test-id-1" - db_user = self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) + db_user: ScimApiUser | None = self.add_user( + identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile} + ) + assert db_user user_scim_id = db_user.scim_id self.headers["IF-MATCH"] = make_etag(db_user.version) response = self.client.delete(url=f"/Users/{db_user.scim_id}", headers=self.headers) self._assertResponse(response, status_code=204) # No content - db_user = self.userdb.get_user_by_scim_id(user_scim_id) + assert self.userdb + db_user = self.userdb.get_user_by_scim_id(str(user_scim_id)) assert db_user is None # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, user_scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == external_id assert event.data["status"] == EventStatus.DELETED.value - def test_delete_user_with_groups(self): + def test_delete_user_with_groups(self) -> None: external_id = "test-id-1" - db_user = self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) + db_user: ScimApiUser | None = self.add_user( + identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile} + ) + assert db_user user_scim_id = db_user.scim_id - group1 = self.add_group_with_member( + group1: ScimApiGroup | None = self.add_group_with_member( group_identifier=str(uuid4()), display_name="Group 1", user_identifier=str(user_scim_id) ) + assert group1 group1 = self.add_owner_to_group(group_identifier=str(group1.scim_id), user_identifier=str(user_scim_id)) + assert group1 extra_user = self.add_user( identifier=str(uuid4()), external_id="other external id", profiles={"test": self.test_profile} ) - group2 = self.add_group_with_member( + group2: ScimApiGroup | None = self.add_group_with_member( group_identifier=str(uuid4()), display_name="Group 2", user_identifier=str(user_scim_id) ) + assert group2 group2 = self.add_member_to_group(group_identifier=str(group2.scim_id), user_identifier=str(extra_user.scim_id)) + assert group2 assert len(group1.members) == 1 assert len(group1.owners) == 1 @@ -547,28 +579,33 @@ def test_delete_user_with_groups(self): response = self.client.delete(url=f"/Users/{db_user.scim_id}", headers=self.headers) self._assertResponse(response, status_code=204) # No content - db_user = self.userdb.get_user_by_scim_id(user_scim_id) + assert self.userdb + db_user = self.userdb.get_user_by_scim_id(str(user_scim_id)) assert db_user is None + assert self.groupdb group1 = self.groupdb.get_group_by_scim_id(str(group1.scim_id)) + assert group1 assert len(group1.graph.members) == 0 assert len(group1.graph.owners) == 0 group2 = self.groupdb.get_group_by_scim_id(str(group2.scim_id)) + assert group2 assert len(group2.graph.members) == 1 # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, user_scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == external_id assert event.data["status"] == EventStatus.DELETED.value - def test_search_user_external_id(self): + def test_search_user_external_id(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) self.add_user(identifier=str(uuid4()), external_id="test-id-2", profiles={"test": self.test_profile}) self._perform_search(search_filter=f'externalId eq "{db_user.external_id}"', expected_user=db_user) - def test_search_user_last_modified(self): + def test_search_user_last_modified(self) -> None: db_user1 = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) db_user2 = self.add_user(identifier=str(uuid4()), external_id="test-id-2", profiles={"test": self.test_profile}) self.assertGreater(db_user2.last_modified, db_user1.last_modified) @@ -583,14 +620,15 @@ def test_search_user_last_modified(self): search_filter=f'meta.lastModified gt "{db_user1.last_modified.isoformat()}"', expected_user=db_user2 ) - def test_search_user_profile_data(self): + def test_search_user_profile_data(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) self.add_user(identifier=str(uuid4()), external_id="test-id-2", profiles={"test": self.test_profile2}) self._perform_search(search_filter='profiles.test.data.test_key eq "test_value"', expected_user=db_user) - def test_search_user_start_index(self): + def test_search_user_start_index(self) -> None: for i in range(9): self.add_user(identifier=str(uuid4()), external_id=f"test-id-{i}", profiles={"test": self.test_profile}) + assert self.userdb self.assertEqual(9, self.userdb.db_count()) last_modified = datetime.utcnow() - timedelta(hours=1) self._perform_search( @@ -601,9 +639,10 @@ def test_search_user_start_index(self): expected_total_results=9, ) - def test_search_user_count(self): + def test_search_user_count(self) -> None: for i in range(9): self.add_user(identifier=str(uuid4()), external_id=f"test-id-{i}", profiles={"test": self.test_profile}) + assert self.userdb self.assertEqual(9, self.userdb.db_count()) last_modified = datetime.utcnow() - timedelta(hours=1) self._perform_search( @@ -614,9 +653,10 @@ def test_search_user_count(self): expected_total_results=9, ) - def test_search_user_start_index_and_count(self): + def test_search_user_start_index_and_count(self) -> None: for i in range(9): self.add_user(identifier=str(uuid4()), external_id=f"test-id-{i}", profiles={"test": self.test_profile}) + assert self.userdb self.assertEqual(9, self.userdb.db_count()) last_modified = datetime.utcnow() - timedelta(hours=1) self._perform_search( @@ -628,7 +668,7 @@ def test_search_user_start_index_and_count(self): expected_total_results=9, ) - def test_create_and_update_user_with_linked_accounts(self): + def test_create_and_update_user_with_linked_accounts(self) -> None: """Test that creating a user and then updating it without changes only results in one event""" account = LinkedAccount(issuer="eduid.se", value="test@dev.eduid.se") _db_account = ScimApiLinkedAccount(issuer=account.issuer, value=account.value, parameters=account.parameters) @@ -636,6 +676,9 @@ def test_create_and_update_user_with_linked_accounts(self): SCIMSchema.NUTID_USER_V1.value: {"profiles": {}, "linked_accounts": [account.to_dict()]}, } result1 = self._create_user(req) + assert result1.parsed_response + + assert self.userdb self._assertResponse(result1.response, status_code=201) db_user = self.userdb.get_user_by_scim_id(str(result1.parsed_response.id)) @@ -652,6 +695,7 @@ def test_create_and_update_user_with_linked_accounts(self): # Update the user result2 = self._update_user(req, result1.parsed_response.id, result1.parsed_response.meta.version) + assert result2.parsed_response self._assertResponse(result2.response, status_code=200) db_user = self.userdb.get_user_by_scim_id(str(result2.parsed_response.id)) assert db_user @@ -663,7 +707,7 @@ def test_create_and_update_user_with_linked_accounts(self): # Verify the updated account made it into the database assert db_user.linked_accounts == [_db_account] - def test_create_user_with_invalid_linked_accounts_issuer(self): + def test_create_user_with_invalid_linked_accounts_issuer(self) -> None: """Test that creating a user with an invalid issuer and valid value fails""" account = LinkedAccount(issuer="NOT-eduid.se", value="test@dev.eduid.se") req = { @@ -674,7 +718,7 @@ def test_create_user_with_invalid_linked_accounts_issuer(self): result1 = self._create_user(req, expect_success=False) self._assertScimError(json=result1.response.json(), detail="Invalid nutid linked_accounts") - def test_create_user_with_invalid_linked_accounts_value(self): + def test_create_user_with_invalid_linked_accounts_value(self) -> None: """Test that creating a user with valid issuer and invalid value fails""" account = LinkedAccount(issuer="eduid.se", value="test@eduid.com") req = { @@ -685,7 +729,7 @@ def test_create_user_with_invalid_linked_accounts_value(self): result1 = self._create_user(req, expect_success=False) self._assertScimError(json=result1.response.json(), detail="Invalid nutid linked_accounts") - def test_update_user_set_linked_accounts(self): + def test_update_user_set_linked_accounts(self) -> None: db_account1 = ScimApiLinkedAccount(issuer="eduid.se", value="test1@dev.eduid.se") account2 = LinkedAccount(issuer="eduid.se", value="test2@eduid.se", parameters={"mfa_stepup": True}) db_user = self.add_user(identifier=str(uuid4()), linked_accounts=[db_account1]) @@ -698,12 +742,13 @@ def test_update_user_set_linked_accounts(self): self._assertResponse(result.response) self._assertUserUpdateSuccess(req, result.response, db_user) - def test_update_user_set_linked_accounts2(self): + def test_update_user_set_linked_accounts2(self) -> None: """Test updating linked accounts sorted 'wrong'""" db_account1 = ScimApiLinkedAccount(issuer="eduid.se", value="test1@dev.eduid.se") account1 = LinkedAccount(issuer=db_account1.issuer, value=db_account1.value) account2 = LinkedAccount(issuer="eduid.se", value="test2@eduid.se", parameters={"mfa_stepup": True}) - db_user = self.add_user(identifier=str(uuid4()), linked_accounts=[db_account1]) + db_user: ScimApiUser | None = self.add_user(identifier=str(uuid4()), linked_accounts=[db_account1]) + assert db_user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], "id": str(db_user.scim_id), @@ -715,6 +760,7 @@ def test_update_user_set_linked_accounts2(self): result = self._update_user(req, db_user.scim_id, version=db_user.version) self._assertResponse(result.response) self._assertUserUpdateSuccess(req, result.response, db_user) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(str(db_user.scim_id)) assert db_user db_account2 = ScimApiLinkedAccount(issuer=account2.issuer, value=account2.value, parameters=account2.parameters) @@ -729,7 +775,7 @@ def _perform_search( expected_user: ScimApiUser | None = None, expected_num_resources: int | None = None, expected_total_results: int | None = None, - ): + ) -> dict: logger.info(f"Searching for user(s) using filter {repr(search_filter)}") req = { "schemas": [SCIMSchema.API_MESSAGES_20_SEARCH_REQUEST.value], @@ -801,7 +847,8 @@ def setUp(self) -> None: for user in self.users[1:]: self.add_member_to_group(group_identifier=str(self.group.scim_id), user_identifier=str(user.scim_id)) - async def test_delete_user_with_groups(self): + async def test_delete_user_with_groups(self) -> None: + assert self.groupdb group = self.groupdb.get_group_by_scim_id(str(self.group.scim_id)) assert group is not None # please mypy assert len(group.members) == self.user_count @@ -820,8 +867,10 @@ async def test_delete_user_with_groups(self): for task in tasks: self._assertResponse(task.result(), status_code=204) # No content + assert self.userdb for user in self.users[: self.user_count // 2]: - assert self.userdb.get_user_by_scim_id(user.scim_id) is None + assert self.userdb.get_user_by_scim_id(str(user.scim_id)) is None group = self.groupdb.get_group_by_scim_id(str(group.scim_id)) + assert group assert len(group.graph.members) == self.user_count // 2 diff --git a/src/eduid/scimapi/tests/test_search_filter.py b/src/eduid/scimapi/tests/test_search_filter.py index 5e83bbf80..8e2c9da80 100644 --- a/src/eduid/scimapi/tests/test_search_filter.py +++ b/src/eduid/scimapi/tests/test_search_filter.py @@ -6,7 +6,7 @@ class TestSearchFilter(unittest.TestCase): - def test_lastmodified(self): + def test_lastmodified(self) -> None: now = datetime.utcnow() filter = f'meta.lastModified gt "{now.isoformat()}"' sf = parse_search_filter(filter) @@ -14,7 +14,7 @@ def test_lastmodified(self): self.assertEqual(sf.op, "gt") self.assertEqual(sf.val, now.isoformat()) - def test_lastmodified_with_tz(self): + def test_lastmodified_with_tz(self) -> None: nowstr = "2020-05-05T09:13:43.916000+00:00" filter = f'meta.lastModified gt "{nowstr}"' sf = parse_search_filter(filter) @@ -22,21 +22,21 @@ def test_lastmodified_with_tz(self): self.assertEqual(sf.op, "gt") self.assertEqual(sf.val, nowstr) - def test_str(self): + def test_str(self) -> None: filter = 'foo eq "123"' sf = parse_search_filter(filter) self.assertEqual(sf.attr, "foo") self.assertEqual(sf.op, "eq") self.assertEqual(sf.val, "123") - def test_int(self): + def test_int(self) -> None: filter = "foo eq 123" sf = parse_search_filter(filter) self.assertEqual(sf.attr, "foo") self.assertEqual(sf.op, "eq") self.assertEqual(sf.val, 123) - def test_not_printable(self): + def test_not_printable(self) -> None: filter = "foo eq 12\u00093" with self.assertRaises(BadRequest): parse_search_filter(filter) diff --git a/src/eduid/scimapi/utils.py b/src/eduid/scimapi/utils.py index 0f04ec0a2..2eec1ef40 100644 --- a/src/eduid/scimapi/utils.py +++ b/src/eduid/scimapi/utils.py @@ -3,7 +3,7 @@ import logging import time from collections.abc import Callable -from typing import AnyStr, TypeVar +from typing import Any, AnyStr, TypeVar from uuid import uuid4 from jwcrypto import jwk @@ -48,7 +48,7 @@ def filter_none(x: Filtered) -> Filtered: return x -def get_unique_hash(): +def get_unique_hash() -> str: return str(uuid4()) @@ -61,9 +61,9 @@ def load_jwks(config: ScimApiConfig) -> jwk.JWKSet: return jwks -def retryable_db_write(func: Callable): +def retryable_db_write(func: Callable) -> Callable: @functools.wraps(func) - def wrapper_run_func(*args, **kwargs): + def wrapper_run_func(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 max_retries = 10 retry = 0 while True: diff --git a/src/eduid/userdb/actions/tou/userdb.py b/src/eduid/userdb/actions/tou/userdb.py index e24a42a81..edc75a75c 100644 --- a/src/eduid/userdb/actions/tou/userdb.py +++ b/src/eduid/userdb/actions/tou/userdb.py @@ -7,7 +7,7 @@ class ToUUserDB(UserDB[ToUUser]): - def __init__(self, db_uri: str, db_name: str = "eduid_actions", collection: str = "tou"): + def __init__(self, db_uri: str, db_name: str = "eduid_actions", collection: str = "tou") -> None: super().__init__(db_uri, db_name, collection) @classmethod diff --git a/src/eduid/userdb/admin/__init__.py b/src/eduid/userdb/admin/__init__.py index 39d0a913e..d69f11db7 100644 --- a/src/eduid/userdb/admin/__init__.py +++ b/src/eduid/userdb/admin/__init__.py @@ -9,7 +9,6 @@ import time from collections.abc import Generator from copy import deepcopy -from typing import Any import bson import bson.json_util @@ -44,14 +43,14 @@ class RawDb: log detailing all the changes. """ - def __init__(self, myname: str | None = None, backupbase: str = "/root/raw_db_changes"): + def __init__(self, myname: str | None = None, backupbase: str = "/root/raw_db_changes") -> None: self._client = get_client() self._start_time: str = datetime.datetime.fromtimestamp(int(time.time())).isoformat(sep="_").replace(":", "") self._myname: str | None = myname self._backupbase: str = backupbase self._file_num: int = 0 - def find(self, db: str, collection: str, search_filter: Any) -> Generator[RawData, None, None]: + def find(self, db: str, collection: str, search_filter: object) -> Generator[RawData, None, None]: """ Look for documents matching search_filter in the specified database and collection. @@ -72,7 +71,7 @@ def find(self, db: str, collection: str, search_filter: Any) -> Generator[RawDat ) sys.exit(1) - def save_with_backup(self, raw: RawData, dry_run: bool = True) -> Any: + def save_with_backup(self, raw: RawData, dry_run: bool = True) -> str | None: """ Save a mongodb document while trying to carefully make a backup of the document before, after and what changed. @@ -113,7 +112,7 @@ def save_with_backup(self, raw: RawData, dry_run: bool = True) -> Any: if raw.before == raw.doc: sys.stderr.write(f"Document in {db_coll} with id {_id} not changed, aborting save_with_backup\n") - return + return None self._file_num = 0 backup_dir = self._make_backupdir(db_coll, _id) @@ -132,13 +131,13 @@ def save_with_backup(self, raw: RawData, dry_run: bool = True) -> Any: # Write changes.txt after saving, so it will also indicate a successful save return self._write_changes(raw, backup_dir, res) - def _write_changes(self, raw: RawData, backup_dir: str, res: Any) -> Any: + def _write_changes(self, raw: RawData, backup_dir: str, res: str) -> str: """ Write a file with one line per change between the before-doc and current doc. The format is intended to be easy to grep through. """ - def safe_encode(k2: Any, v2: Any) -> str: + def safe_encode(k2: object, v2: object) -> str: try: return bson.json_util.dumps({k2: v2}, json_options=PYTHON_UUID_LEGACY_JSON_OPTIONS) except: @@ -160,7 +159,7 @@ def safe_encode(k2: Any, v2: Any) -> str: fd.write(f"DB_RESULT: {res}\n") return res - def _write_before_and_after(self, raw: RawData, backup_dir: str): + def _write_before_and_after(self, raw: RawData, backup_dir: str) -> None: """ Write before- and after backup files of the document being saved, in JSON format. """ @@ -180,7 +179,7 @@ def _write_before_and_after(self, raw: RawData, backup_dir: str): + "\n" ) - def _get_backup_filename(self, dirname: str, filename: str, ext: str): + def _get_backup_filename(self, dirname: str, filename: str, ext: str) -> str: """ Look for a backup filename that hasn't been used. The use of self._file_num should mean we get matching before- after- and changes sets. @@ -237,7 +236,7 @@ class RawData: :param collection: Name of collection """ - def __init__(self, doc: TUserDbDocument, db: str, collection: str): + def __init__(self, doc: TUserDbDocument, db: str, collection: str) -> None: self._before = deepcopy(doc) self._db = db self._collection = collection diff --git a/src/eduid/userdb/authninfo.py b/src/eduid/userdb/authninfo.py index 1b3f92cb3..56ebcef5d 100644 --- a/src/eduid/userdb/authninfo.py +++ b/src/eduid/userdb/authninfo.py @@ -33,7 +33,7 @@ class AuthnInfoDB(BaseDB): TODO: We already have a database class to access this collection, in the IdP. Consolidate the two. """ - def __init__(self, db_uri, db_name="eduid_idp_authninfo", collection="authn_info"): + def __init__(self, db_uri: str, db_name: str = "eduid_idp_authninfo", collection: str = "authn_info") -> None: super().__init__(db_uri, db_name, collection) def get_authn_info(self, user: User) -> Mapping[ElementKey, AuthnInfoElement]: diff --git a/src/eduid/userdb/credentials/base.py b/src/eduid/userdb/credentials/base.py index b2a8acdf6..baf48d484 100644 --- a/src/eduid/userdb/credentials/base.py +++ b/src/eduid/userdb/credentials/base.py @@ -29,7 +29,7 @@ class Credential(VerifiedElement): proofing_method: CredentialProofingMethod | None = None - def __str__(self): + def __str__(self) -> str: if len(self.key) == 24: # probably an object id in string format, don't cut it shortkey = str(self.key) diff --git a/src/eduid/userdb/credentials/external.py b/src/eduid/userdb/credentials/external.py index 044c3bce0..b39042ba2 100644 --- a/src/eduid/userdb/credentials/external.py +++ b/src/eduid/userdb/credentials/external.py @@ -26,7 +26,7 @@ class ExternalCredential(Credential): @field_validator("credential_id", mode="before") @classmethod - def credential_id_objectid(cls, v): + def credential_id_objectid(cls, v: object) -> str: """Turn ObjectId into string""" if isinstance(v, ObjectId): v = str(v) diff --git a/src/eduid/userdb/credentials/password.py b/src/eduid/userdb/credentials/password.py index 291bda4b4..7a9edad4d 100644 --- a/src/eduid/userdb/credentials/password.py +++ b/src/eduid/userdb/credentials/password.py @@ -17,7 +17,7 @@ class Password(Credential): @field_validator("credential_id", mode="before") @classmethod - def credential_id_objectid(cls, v): + def credential_id_objectid(cls, v: object) -> str: """Turn ObjectId into string""" if isinstance(v, ObjectId): v = str(v) diff --git a/src/eduid/userdb/db/async_db.py b/src/eduid/userdb/db/async_db.py index 2247d04a3..292f82e1e 100644 --- a/src/eduid/userdb/db/async_db.py +++ b/src/eduid/userdb/db/async_db.py @@ -42,7 +42,7 @@ def __init__( db_uri: str, db_name: str | None = None, **kwargs: Any, - ): + ) -> None: super().__init__(db_uri=db_uri, db_name=db_name, **kwargs) try: self._client = AsyncClientCache().get_client(self) @@ -80,7 +80,7 @@ def get_collection(self, collection: str, database_name: str | None = None) -> A _db = self.get_database(database_name) return _db[collection] - async def is_healthy(self): + async def is_healthy(self) -> bool: """ From mongo_client.py: Starting with version 3.0 the :class:`MongoClient` @@ -111,7 +111,7 @@ async def is_healthy(self): logger.error(f"{self} not healthy: {e}") return False - async def close(self): + async def close(self) -> None: self._client.close() @@ -124,7 +124,7 @@ def __init__( db_name: str, collection: str, safe_writes: bool = False, - ): + ) -> None: self._db_uri = db_uri self._coll_name = collection self._db = AsyncMongoDB(db_uri, db_name=db_name) @@ -132,7 +132,7 @@ def __init__( if safe_writes: self._coll = self._coll.with_options(write_concern=pymongo.WriteConcern(w="majority")) - def __repr__(self): + def __repr__(self) -> str: return f"" __str__ = __repr__ @@ -149,7 +149,7 @@ def collection(self) -> AsyncIOMotorCollection: def connection(self) -> AsyncIOMotorClient: return self._db.get_connection() - async def _drop_whole_collection(self): + async def _drop_whole_collection(self) -> None: """ Drop the whole collection. Should ONLY be used in testing, obviously. :return: @@ -157,7 +157,7 @@ async def _drop_whole_collection(self): logger.warning(f"{self!s} Dropping collection {self._coll_name!r}") return await self._coll.drop() - async def _get_document_by_attr(self, attr: str, value: Any) -> Mapping[str, Any] | None: + async def _get_document_by_attr(self, attr: str, value: object) -> Mapping[str, Any] | None: """ Return the document in the MongoDB matching field=value diff --git a/src/eduid/userdb/db/base.py b/src/eduid/userdb/db/base.py index b6b0009ec..bd9c8b3c0 100644 --- a/src/eduid/userdb/db/base.py +++ b/src/eduid/userdb/db/base.py @@ -21,7 +21,7 @@ def __init__( db_uri: str, db_name: str | None = None, **kwargs: Any, - ): + ) -> None: if db_uri is None: raise ValueError("db_uri not supplied") @@ -61,7 +61,7 @@ def _parse_kwargs(self, **kwargs: Any) -> dict[Any, Any]: kwargs["connectTimeoutMS"] = 5000 return kwargs - def __repr__(self): + def __repr__(self) -> str: return "".format( self.__class__.__name__, getattr(self, "sanitized_uri", None), getattr(self, "_database_name", None) ) diff --git a/src/eduid/userdb/db/sync_db.py b/src/eduid/userdb/db/sync_db.py index a34dfe9c9..f151334e5 100644 --- a/src/eduid/userdb/db/sync_db.py +++ b/src/eduid/userdb/db/sync_db.py @@ -47,7 +47,7 @@ def __init__( db_uri: str, db_name: str | None = None, **kwargs: Any, - ): + ) -> None: super().__init__(db_uri=db_uri, db_name=db_name, **kwargs) try: self._client = MongoClientCache().get_client(db=self) @@ -88,7 +88,7 @@ def get_collection( _db = self.get_database(database_name) return _db[collection] - def is_healthy(self): + def is_healthy(self) -> bool: """ From mongo_client.py: Starting with version 3.0 the :class:`MongoClient` @@ -119,7 +119,7 @@ def is_healthy(self): logger.error(f"{self} not healthy: {e}") return False - def close(self): + def close(self) -> None: self._client.close() @@ -138,7 +138,7 @@ class SaveResult: updated: int = 0 doc_id: ObjectId | None = None - def __bool__(self): + def __bool__(self) -> bool: return bool(self.inserted or self.updated) @@ -151,7 +151,7 @@ def __init__( db_name: str, collection: str, safe_writes: bool = False, - ): + ) -> None: self._db_uri = db_uri self._coll_name = collection self._db = MongoDB(db_uri, db_name=db_name) @@ -159,12 +159,12 @@ def __init__( if safe_writes: self._coll = self._coll.with_options(write_concern=pymongo.WriteConcern(w="majority")) - def __repr__(self): + def __repr__(self) -> str: return f"" __str__ = __repr__ - def _drop_whole_collection(self): + def _drop_whole_collection(self) -> None: """ Drop the whole collection. Should ONLY be used in testing, obviously. :return: @@ -182,7 +182,7 @@ def _get_all_docs(self) -> pymongo.cursor.Cursor[TUserDbDocument]: """ return self._coll.find({}) - def _get_document_by_attr(self, attr: str, value: Any) -> TUserDbDocument | None: + def _get_document_by_attr(self, attr: str, value: Any) -> TUserDbDocument | None: # noqa: ANN401 """ Return the document in the MongoDB matching field=value diff --git a/src/eduid/userdb/element.py b/src/eduid/userdb/element.py index 758de0024..7a4eb3530 100644 --- a/src/eduid/userdb/element.py +++ b/src/eduid/userdb/element.py @@ -205,7 +205,7 @@ class VerifiedElement(Element, ABC): proofing_method: Enum | None = None proofing_version: str | None = None - def __str__(self): + def __str__(self) -> str: return f"" @classmethod @@ -245,7 +245,7 @@ class PrimaryElement(VerifiedElement, ABC): is_primary: bool = Field(default=False, alias="primary") # primary is the old name - def __setattr__(self, key: str, value: Any): + def __setattr__(self, key: str, value: object) -> None: """ raise PrimaryElementViolation when trying to set a primary element as unverified """ @@ -254,7 +254,7 @@ def __setattr__(self, key: str, value: Any): super().__setattr__(key, value) - def __str__(self): + def __str__(self) -> str: return ( f"" @@ -274,6 +274,7 @@ def _to_dict_transform(self, data: dict[str, Any]) -> dict[str, Any]: ListElement = TypeVar("ListElement", bound=Element) MatchingElement = TypeVar("MatchingElement", bound=Element) +TElementList = TypeVar("TElementList", bound="ElementList") class ElementList(BaseModel, Generic[ListElement], ABC): @@ -286,7 +287,7 @@ class ElementList(BaseModel, Generic[ListElement], ABC): elements: list[ListElement] = Field(default=[]) model_config = ConfigDict(validate_assignment=True, extra="forbid") - def __str__(self): + def __str__(self) -> str: return "".format(self.__class__.__name__, getattr(self, "elements", None)) @field_validator("elements", mode="before") @@ -312,7 +313,7 @@ def _validate_elements(cls, values: list[ListElement]) -> list[ListElement]: return values @classmethod - def from_list_of_dicts(cls, items): + def from_list_of_dicts(cls: type[TElementList], items: list[dict[str, Any]]) -> TElementList: # must be implemented by subclass to get correct type information raise NotImplementedError() @@ -349,7 +350,7 @@ def find(self, key: ElementKey | str | None) -> ListElement | None: raise EduIDUserDBError("More than one element found") return res[0] - def add(self, element: ListElement): + def add(self, element: ListElement) -> None: """ Add an element to the list. @@ -421,7 +422,7 @@ class PrimaryElementList(VerifiedElementList[ListElement], Generic[ListElement], """ @classmethod - def _validate_elements(cls, values: list[ListElement]): + def _validate_elements(cls, values: list[ListElement]) -> list[ListElement]: """ Validate elements. Since the 'elements' property is defined on subclasses (to get the right type information), a pydantic validator can't be placed here diff --git a/src/eduid/userdb/event.py b/src/eduid/userdb/event.py index b995839a7..728f4e8ca 100644 --- a/src/eduid/userdb/event.py +++ b/src/eduid/userdb/event.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from typing import Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from uuid import uuid4 from pydantic import Field @@ -11,6 +11,9 @@ TEventSubclass = TypeVar("TEventSubclass", bound="Event") +if TYPE_CHECKING: + from eduid.userdb.tou import ToUEvent + class Event(Element): """ """ @@ -69,7 +72,7 @@ class EventList(ElementList[ListElement], Generic[ListElement], ABC): pass -def event_from_dict(data: dict[str, Any]): +def event_from_dict(data: dict[str, Any]) -> ToUEvent: """ Create an Event instance (probably really a subclass of Event) from a dict. diff --git a/src/eduid/userdb/exceptions.py b/src/eduid/userdb/exceptions.py index c411f80f8..0cf281fc9 100644 --- a/src/eduid/userdb/exceptions.py +++ b/src/eduid/userdb/exceptions.py @@ -2,8 +2,6 @@ Exceptions thrown by the eduid.userdb database lookup functions. """ -from typing import Any - class EduIDDBError(Exception): """ @@ -13,11 +11,11 @@ class EduIDDBError(Exception): :type reason: object """ - def __init__(self, reason: Any): + def __init__(self, reason: object) -> None: Exception.__init__(self) self.reason = reason - def __str__(self): + def __str__(self) -> str: return f"<{self.__class__.__name__} instance at {hex(id(self))}: {self.reason!r}>" diff --git a/src/eduid/userdb/group_management/db.py b/src/eduid/userdb/group_management/db.py index 7e3ecc51e..2aff4f375 100644 --- a/src/eduid/userdb/group_management/db.py +++ b/src/eduid/userdb/group_management/db.py @@ -9,7 +9,9 @@ class GroupManagementInviteStateDB(BaseDB): - def __init__(self, db_uri: str, db_name: str = "eduid_group_management", collection: str = "group_invite_data"): + def __init__( + self, db_uri: str, db_name: str = "eduid_group_management", collection: str = "group_invite_data" + ) -> None: super().__init__(db_uri, db_name, collection=collection) # Create an index so that invites for group_scim_id, email_address and role is unique indexes = { diff --git a/src/eduid/userdb/idp/db.py b/src/eduid/userdb/idp/db.py index 8dd243fde..6e06a9d51 100644 --- a/src/eduid/userdb/idp/db.py +++ b/src/eduid/userdb/idp/db.py @@ -13,7 +13,7 @@ class IdPUserDb(UserDB[IdPUser]): - def __init__(self, db_uri: str, db_name: str = "eduid_am", collection: str = "attributes"): + def __init__(self, db_uri: str, db_name: str = "eduid_am", collection: str = "attributes") -> None: super().__init__(db_uri, db_name, collection=collection) @classmethod diff --git a/src/eduid/userdb/locked_identity.py b/src/eduid/userdb/locked_identity.py index dad1b1293..4a68c7e1e 100644 --- a/src/eduid/userdb/locked_identity.py +++ b/src/eduid/userdb/locked_identity.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Any +from typing import Any, NoReturn from pydantic import field_validator +from eduid.userdb.element import ElementKey from eduid.userdb.exceptions import EduIDUserDBError from eduid.userdb.identity import IdentityElement, IdentityList @@ -17,7 +18,7 @@ class LockedIdentityList(IdentityList): @field_validator("elements") @classmethod - def validate_is_verified(cls, v: list[IdentityElement]): + def validate_is_verified(cls, v: list[IdentityElement]) -> list[IdentityElement]: # If using a validator with a subclass that references a List type field on a parent class, using # each_item=True will cause the validator not to run; instead, the list must be iterated over programmatically. if not all([item.is_verified for item in v]): @@ -34,7 +35,7 @@ def replace(self, element: IdentityElement) -> None: self.add(element=element) return None - def remove(self, key): + def remove(self, key: ElementKey) -> NoReturn: """ Override remove method as an element should be set once, remove never. """ diff --git a/src/eduid/userdb/logs/db.py b/src/eduid/userdb/logs/db.py index 871e9c0f7..4d2a2a093 100644 --- a/src/eduid/userdb/logs/db.py +++ b/src/eduid/userdb/logs/db.py @@ -11,7 +11,7 @@ class LogDB(BaseDB): - def __init__(self, db_uri: str, collection: str): + def __init__(self, db_uri: str, collection: str) -> None: db_name = "eduid_logs" # Make sure writes reach a majority of replicas super().__init__(db_uri, db_name, collection, safe_writes=True) @@ -52,7 +52,7 @@ def exists(self, authenticator_id: str | UUID, last_status_change: datetime) -> class UserChangeLog(LogDB): - def __init__(self, db_uri: str, collection: str = "user_change_log"): + def __init__(self, db_uri: str, collection: str = "user_change_log") -> None: super().__init__(db_uri, collection) def get_by_eppn(self, eppn: str) -> list[UserChangeLogElement]: @@ -61,7 +61,7 @@ def get_by_eppn(self, eppn: str) -> list[UserChangeLogElement]: class ManagedAccountLog(LogDB): - def __init__(self, db_uri: str, collection: str = "managed_account_log"): + def __init__(self, db_uri: str, collection: str = "managed_account_log") -> None: super().__init__(db_uri, collection) # Create in index indexes = {"expiration-time": {"key": [("expire_at", 1)], "expireAfterSeconds": 0}} diff --git a/src/eduid/userdb/maccapi/userdb.py b/src/eduid/userdb/maccapi/userdb.py index f7ddb5936..d0812506f 100644 --- a/src/eduid/userdb/maccapi/userdb.py +++ b/src/eduid/userdb/maccapi/userdb.py @@ -40,7 +40,7 @@ def to_idp_user(self) -> IdPUser: class ManagedAccountDB(UserDB[ManagedAccount]): - def __init__(self, db_uri: str, db_name: str = "eduid_managed_accounts", collection: str = "users"): + def __init__(self, db_uri: str, db_name: str = "eduid_managed_accounts", collection: str = "users") -> None: super().__init__(db_uri, db_name, collection) indexes = { diff --git a/src/eduid/userdb/mail.py b/src/eduid/userdb/mail.py index c35316c40..886bdddba 100644 --- a/src/eduid/userdb/mail.py +++ b/src/eduid/userdb/mail.py @@ -14,7 +14,7 @@ class MailAddress(PrimaryElement): @field_validator("email", mode="before") @classmethod - def validate_email(cls, v): + def validate_email(cls, v: object) -> str: if not isinstance(v, str): raise ValueError("must be a string") return v.lower() @@ -53,7 +53,7 @@ def from_list_of_dicts(cls: type[MailAddressList], items: list[dict[str, Any]]) return cls(elements=[MailAddress.from_dict(this) for this in items]) -def address_from_dict(data): +def address_from_dict(data: dict[str, Any]) -> MailAddress: """ Create a MailAddress instance from a dict. diff --git a/src/eduid/userdb/meta.py b/src/eduid/userdb/meta.py index dd076aab2..b963a9fc0 100644 --- a/src/eduid/userdb/meta.py +++ b/src/eduid/userdb/meta.py @@ -24,5 +24,5 @@ class Meta(BaseModel): is_in_database: Annotated[bool, Field(exclude=True)] = False # this is set to True when userdb loads the object model_config = ConfigDict(arbitrary_types_allowed=True) - def new_version(self): + def new_version(self) -> None: self.version = ObjectId() diff --git a/src/eduid/userdb/personal_data/db.py b/src/eduid/userdb/personal_data/db.py index 7ea3270db..5d5240543 100644 --- a/src/eduid/userdb/personal_data/db.py +++ b/src/eduid/userdb/personal_data/db.py @@ -10,7 +10,7 @@ class PersonalDataUserDB(UserDB[PersonalDataUser]): - def __init__(self, db_uri, db_name="eduid_personal_data", collection="profiles"): + def __init__(self, db_uri: str, db_name: str = "eduid_personal_data", collection: str = "profiles") -> None: super().__init__(db_uri, db_name, collection=collection) @classmethod diff --git a/src/eduid/userdb/proofing/db.py b/src/eduid/userdb/proofing/db.py index 6b3c705ea..ed33a4388 100644 --- a/src/eduid/userdb/proofing/db.py +++ b/src/eduid/userdb/proofing/db.py @@ -26,7 +26,7 @@ class ProofingStateDB(BaseDB, Generic[ProofingStateVar], ABC): - def __init__(self, db_uri: str, db_name: str, collection: str = "proofing_data"): + def __init__(self, db_uri: str, db_name: str, collection: str = "proofing_data") -> None: super().__init__(db_uri, db_name, collection) @classmethod @@ -97,7 +97,7 @@ def remove_state(self, state: ProofingStateVar) -> None: class LetterProofingStateDB(ProofingStateDB[LetterProofingState]): - def __init__(self, db_uri: str, db_name: str = "eduid_idproofing_letter"): + def __init__(self, db_uri: str, db_name: str = "eduid_idproofing_letter") -> None: super().__init__(db_uri, db_name) @classmethod @@ -106,7 +106,7 @@ def state_from_dict(cls, data: Mapping[str, Any]) -> LetterProofingState: class EmailProofingStateDB(ProofingStateDB[EmailProofingState]): - def __init__(self, db_uri: str, db_name: str = "eduid_email"): + def __init__(self, db_uri: str, db_name: str = "eduid_email") -> None: super().__init__(db_uri, db_name) @classmethod @@ -136,7 +136,7 @@ def remove_state(self, state: ProofingStateVar) -> None: class PhoneProofingStateDB(ProofingStateDB[PhoneProofingState]): - def __init__(self, db_uri: str, db_name: str = "eduid_phone"): + def __init__(self, db_uri: str, db_name: str = "eduid_phone") -> None: super().__init__(db_uri, db_name) @classmethod @@ -189,7 +189,7 @@ def get_state_by_oidc_state(self, oidc_state: str) -> ProofingStateVar | None: class OidcProofingStateDB(OidcStateDB[OidcProofingState]): - def __init__(self, db_uri: str, db_name: str = "eduid_oidc_proofing"): + def __init__(self, db_uri: str, db_name: str = "eduid_oidc_proofing") -> None: super().__init__(db_uri, db_name) @classmethod @@ -200,7 +200,7 @@ def state_from_dict(cls, data: Mapping[str, Any]) -> OidcProofingState: class OrcidProofingStateDB(OidcStateDB[OrcidProofingState]): ProofingStateClass = OrcidProofingState - def __init__(self, db_uri: str, db_name: str = "eduid_orcid"): + def __init__(self, db_uri: str, db_name: str = "eduid_orcid") -> None: super().__init__(db_uri, db_name) @classmethod @@ -209,7 +209,7 @@ def state_from_dict(cls, data: Mapping[str, Any]) -> OrcidProofingState: class ProofingUserDB(UserDB[ProofingUser]): - def __init__(self, db_uri: str, db_name: str, collection: str = "profiles"): + def __init__(self, db_uri: str, db_name: str, collection: str = "profiles") -> None: super().__init__(db_uri, db_name, collection=collection) def save(self, user: ProofingUser) -> UserSaveResult: @@ -221,55 +221,55 @@ def user_from_dict(cls, data: TUserDbDocument) -> ProofingUser: class LetterProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_idproofing_letter"): + def __init__(self, db_uri: str, db_name: str = "eduid_idproofing_letter") -> None: super().__init__(db_uri, db_name) class OidcProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_oidc_proofing"): + def __init__(self, db_uri: str, db_name: str = "eduid_oidc_proofing") -> None: super().__init__(db_uri, db_name) class PhoneProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_phone"): + def __init__(self, db_uri: str, db_name: str = "eduid_phone") -> None: super().__init__(db_uri, db_name) class EmailProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_email"): + def __init__(self, db_uri: str, db_name: str = "eduid_email") -> None: super().__init__(db_uri, db_name) class LookupMobileProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_lookup_mobile_proofing"): + def __init__(self, db_uri: str, db_name: str = "eduid_lookup_mobile_proofing") -> None: super().__init__(db_uri, db_name) class OrcidProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_orcid"): + def __init__(self, db_uri: str, db_name: str = "eduid_orcid") -> None: super().__init__(db_uri, db_name) class EidasProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_eidas"): + def __init__(self, db_uri: str, db_name: str = "eduid_eidas") -> None: super().__init__(db_uri, db_name) class LadokProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_ladok"): + def __init__(self, db_uri: str, db_name: str = "eduid_ladok") -> None: super().__init__(db_uri, db_name) class SvideIDProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_svipe_id"): + def __init__(self, db_uri: str, db_name: str = "eduid_svipe_id") -> None: super().__init__(db_uri, db_name) class BankIDProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_bankid"): + def __init__(self, db_uri: str, db_name: str = "eduid_bankid") -> None: super().__init__(db_uri, db_name) class FrejaEIDProofingUserDB(ProofingUserDB): - def __init__(self, db_uri: str, db_name: str = "eduid_freja_eid"): + def __init__(self, db_uri: str, db_name: str = "eduid_freja_eid") -> None: super().__init__(db_uri, db_name) diff --git a/src/eduid/userdb/proofing/element.py b/src/eduid/userdb/proofing/element.py index 4ce26a902..a0e66a525 100644 --- a/src/eduid/userdb/proofing/element.py +++ b/src/eduid/userdb/proofing/element.py @@ -86,7 +86,7 @@ class EmailProofingElement(ProofingElement): @field_validator("email", mode="before") @classmethod - def validate_email(cls, v: Any): + def validate_email(cls, v: object) -> str: if not isinstance(v, str): raise ValueError("must be a string") return v.lower() diff --git a/src/eduid/userdb/proofing/state.py b/src/eduid/userdb/proofing/state.py index 310cf3512..c459d08c0 100644 --- a/src/eduid/userdb/proofing/state.py +++ b/src/eduid/userdb/proofing/state.py @@ -5,7 +5,7 @@ import logging from collections.abc import Mapping from dataclasses import asdict, dataclass -from typing import Any +from typing import Any, TypeVar import bson @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) +TProofingState = TypeVar("TProofingState", bound="ProofingState") + @dataclass() class ProofingState: @@ -33,12 +35,12 @@ class ProofingState: # None if ProofingState has never been written to the database. modified_ts: datetime.datetime | None - def __post_init__(self): + def __post_init__(self) -> None: if self.id is None: self.id = bson.ObjectId() @classmethod - def _default_from_dict(cls, data: Mapping[str, Any], fields: set[str]): + def _default_from_dict(cls: type[TProofingState], data: Mapping[str, Any], fields: set[str]) -> TProofingState: _data = copy.deepcopy(dict(data)) # to not modify callers data if "eduPersonPrincipalName" in _data: _data["eppn"] = _data.pop("eduPersonPrincipalName") @@ -60,7 +62,7 @@ def _default_from_dict(cls, data: Mapping[str, Any], fields: set[str]): return cls(**_data) @classmethod - def from_dict(cls, data: Mapping[str, Any]): + def from_dict(cls, data: Mapping[str, Any]) -> Any: # noqa: ANN401 raise NotImplementedError(f"from_dict not implemented for class {cls.__name__}") def to_dict(self) -> TUserDbDocument: @@ -71,7 +73,7 @@ def to_dict(self) -> TUserDbDocument: res["modified_ts"] = datetime.datetime.utcnow() return TUserDbDocument(res) - def __str__(self): + def __str__(self) -> str: return f"" @property diff --git a/src/eduid/userdb/reset_password/db.py b/src/eduid/userdb/reset_password/db.py index 659296b9e..57897aedc 100644 --- a/src/eduid/userdb/reset_password/db.py +++ b/src/eduid/userdb/reset_password/db.py @@ -16,7 +16,7 @@ class ResetPasswordUserDB(UserDB[ResetPasswordUser]): - def __init__(self, db_uri: str, db_name: str = "eduid_reset_password", collection: str = "profiles"): + def __init__(self, db_uri: str, db_name: str = "eduid_reset_password", collection: str = "profiles") -> None: super().__init__(db_uri, db_name, collection=collection) @classmethod @@ -25,7 +25,9 @@ def user_from_dict(cls, data: TUserDbDocument) -> ResetPasswordUser: class ResetPasswordStateDB(BaseDB): - def __init__(self, db_uri: str, db_name: str = "eduid_reset_password", collection: str = "password_reset_data"): + def __init__( + self, db_uri: str, db_name: str = "eduid_reset_password", collection: str = "password_reset_data" + ) -> None: super().__init__(db_uri, db_name, collection=collection) def get_state_by_email_code( diff --git a/src/eduid/userdb/reset_password/state.py b/src/eduid/userdb/reset_password/state.py index 695e23d3d..e6cd2d03d 100644 --- a/src/eduid/userdb/reset_password/state.py +++ b/src/eduid/userdb/reset_password/state.py @@ -31,10 +31,10 @@ class ResetPasswordState: extra_security: dict[str, Any] | None = None generated_password: bool = False - def __post_init__(self): + def __post_init__(self) -> None: self.reference = str(self.id) - def __str__(self): + def __str__(self) -> str: return f"" def to_dict(self) -> TUserDbDocument: @@ -83,12 +83,12 @@ class ResetPasswordEmailState(ResetPasswordState, _ResetPasswordEmailStateRequir email_reference: str = field(default_factory=lambda: str(uuid4())) - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() self.method = "email" self.email_code = CodeElement.parse(application="security", code_or_element=self.email_code) - def to_dict(self): + def to_dict(self) -> TUserDbDocument: res = super().to_dict() res["email_code"] = self.email_code.to_dict() return res @@ -106,7 +106,7 @@ class _ResetPasswordEmailAndPhoneStateRequired: class ResetPasswordEmailAndPhoneState(ResetPasswordEmailState, _ResetPasswordEmailAndPhoneStateRequired): """ """ - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() self.method = "email_and_phone" self.phone_code = CodeElement.parse(application="security", code_or_element=self.phone_code) diff --git a/src/eduid/userdb/scimapi/eventdb.py b/src/eduid/userdb/scimapi/eventdb.py index 3daf0f4ad..4e34b0031 100644 --- a/src/eduid/userdb/scimapi/eventdb.py +++ b/src/eduid/userdb/scimapi/eventdb.py @@ -87,7 +87,7 @@ def from_dict(cls: type[ScimApiEvent], data: Mapping[str, Any]) -> ScimApiEvent: class ScimApiEventDB(ScimApiBaseDB): - def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_scimapi"): + def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_scimapi") -> None: super().__init__(db_uri, db_name, collection=collection) indexes = { # Remove messages older than expires_at datetime diff --git a/src/eduid/userdb/scimapi/groupdb.py b/src/eduid/userdb/scimapi/groupdb.py index 3b65792e9..8a636f42a 100644 --- a/src/eduid/userdb/scimapi/groupdb.py +++ b/src/eduid/userdb/scimapi/groupdb.py @@ -51,7 +51,7 @@ class ScimApiGroup(ScimApiResourceBase, _ScimApiGroupRequired): extensions: GroupExtensions = field(default_factory=lambda: GroupExtensions()) graph: GraphGroup = field(init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.graph = GraphGroup(identifier=str(self.scim_id), display_name=self.display_name) @property @@ -59,7 +59,7 @@ def members(self) -> set[GraphGroup | GraphUser]: return self.graph.members @members.setter - def members(self, members: Iterable[GraphGroup | GraphUser]): + def members(self, members: Iterable[GraphGroup | GraphUser]) -> None: members = set(members) self.graph = replace(self.graph, members=members) @@ -71,7 +71,7 @@ def owners(self) -> set[GraphGroup | GraphUser]: return self.graph.owners @owners.setter - def owners(self, owners: Iterable[GraphGroup | GraphUser]): + def owners(self, owners: Iterable[GraphGroup | GraphUser]) -> None: owners = set(owners) self.graph = replace(self.graph, owners=owners) @@ -110,7 +110,7 @@ def __init__( mongo_collection: str, neo4j_config: dict[str, Any] | None = None, setup_indexes: bool = True, - ): + ) -> None: super().__init__(mongo_uri, mongo_dbname, collection=mongo_collection) self.graphdb = GroupDB(db_uri=neo4j_uri, scope=scope, config=neo4j_config) logger.info(f"{self} initialised") @@ -268,7 +268,9 @@ def get_group_by_display_name(self, display_name: str) -> ScimApiGroup | None: return group return None - def get_groups_by_property(self, key: str, value: str | int, skip=0, limit=100) -> tuple[list[ScimApiGroup], int]: + def get_groups_by_property( + self, key: str, value: str | int, skip: int = 0, limit: int = 100 + ) -> tuple[list[ScimApiGroup], int]: docs, count = self._get_documents_and_count_by_filter({key: value}, skip=skip, limit=limit) if not docs: return [], 0 diff --git a/src/eduid/userdb/scimapi/invitedb.py b/src/eduid/userdb/scimapi/invitedb.py index b72e58475..181cd6b84 100644 --- a/src/eduid/userdb/scimapi/invitedb.py +++ b/src/eduid/userdb/scimapi/invitedb.py @@ -70,7 +70,7 @@ def from_dict(cls: type[ScimApiInvite], data: Mapping[str, Any]) -> ScimApiInvit class ScimApiInviteDB(ScimApiBaseDB): - def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_scimapi"): + def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_scimapi") -> None: super().__init__(db_uri, db_name, collection=collection) # Create an index so that scim_id is unique per data owner indexes = { @@ -107,7 +107,7 @@ def save(self, invite: ScimApiInvite) -> bool: return result.acknowledged - def remove(self, invite: ScimApiInvite): + def remove(self, invite: ScimApiInvite) -> bool: return self.remove_document(invite.invite_id) def get_invite_by_scim_id(self, scim_id: str) -> ScimApiInvite | None: diff --git a/src/eduid/userdb/scimapi/userdb.py b/src/eduid/userdb/scimapi/userdb.py index 396904a85..20a25f37c 100644 --- a/src/eduid/userdb/scimapi/userdb.py +++ b/src/eduid/userdb/scimapi/userdb.py @@ -40,7 +40,7 @@ class ScimApiUser(ScimApiResourceBase): linked_accounts: list[ScimApiLinkedAccount] = field(default_factory=list) @property - def etag(self): + def etag(self) -> str: return f'W/"{self.version}"' def to_dict(self) -> TUserDbDocument: @@ -72,7 +72,9 @@ def from_dict(cls: type[ScimApiUser], data: Mapping[str, Any]) -> ScimApiUser: class ScimApiUserDB(ScimApiBaseDB): - def __init__(self, db_uri: str, collection: str, db_name="eduid_scimapi", setup_indexes: bool = True): + def __init__( + self, db_uri: str, collection: str, db_name: str = "eduid_scimapi", setup_indexes: bool = True + ) -> None: super().__init__(db_uri, db_name, collection=collection) if setup_indexes: # Create an index so that scim_id and external_id is unique per data owner @@ -128,7 +130,7 @@ def save(self, user: ScimApiUser) -> None: return None - def remove(self, user: ScimApiUser): + def remove(self, user: ScimApiUser) -> bool: return self.remove_document(user.user_id) def get_user_by_scim_id(self, scim_id: str) -> ScimApiUser | None: @@ -157,7 +159,7 @@ def get_user_by_profile_data( profile: str, key: str, operator: str, - value: datetime, + value: str | int, limit: int | None = None, skip: int | None = None, ) -> tuple[list[ScimApiUser], int]: @@ -174,7 +176,7 @@ def user_exists(self, scim_id: str) -> bool: class ScimEduidUserDB(UserDB[User]): """EduID userdb""" - def __init__(self, db_uri: str, db_name: str = "eduid_scimapi"): + def __init__(self, db_uri: str, db_name: str = "eduid_scimapi") -> None: super().__init__(db_uri, db_name) @classmethod diff --git a/src/eduid/userdb/security/db.py b/src/eduid/userdb/security/db.py index 54e27ffe4..1c6cc5edd 100644 --- a/src/eduid/userdb/security/db.py +++ b/src/eduid/userdb/security/db.py @@ -10,7 +10,7 @@ class SecurityUserDB(UserDB[SecurityUser]): - def __init__(self, db_uri: str, db_name: str = "eduid_security", collection: str = "profiles"): + def __init__(self, db_uri: str, db_name: str = "eduid_security", collection: str = "profiles") -> None: super().__init__(db_uri, db_name, collection=collection) @classmethod diff --git a/src/eduid/userdb/signup/invite.py b/src/eduid/userdb/signup/invite.py index 50e7d2b28..599c24d93 100644 --- a/src/eduid/userdb/signup/invite.py +++ b/src/eduid/userdb/signup/invite.py @@ -34,7 +34,7 @@ class InviteMailAddress: email: str primary: bool - def __post_init__(self): + def __post_init__(self) -> None: # Make sure email is lowercase on init as we had trouble with mixed case super().__setattr__("email", self.email.lower()) diff --git a/src/eduid/userdb/signup/invitedb.py b/src/eduid/userdb/signup/invitedb.py index 151cdd1b7..4745b63d0 100644 --- a/src/eduid/userdb/signup/invitedb.py +++ b/src/eduid/userdb/signup/invitedb.py @@ -15,7 +15,7 @@ class SignupInviteDB(BaseDB): - def __init__(self, db_uri: str, db_name: str = "eduid_signup", collection: str = "invites"): + def __init__(self, db_uri: str, db_name: str = "eduid_signup", collection: str = "invites") -> None: BaseDB.__init__(self, db_uri, db_name, collection) # Create an index so that invite_code is unique and invites are removed at the expires_at time indexes = { diff --git a/src/eduid/userdb/signup/userdb.py b/src/eduid/userdb/signup/userdb.py index f060458f3..03bd1d60c 100644 --- a/src/eduid/userdb/signup/userdb.py +++ b/src/eduid/userdb/signup/userdb.py @@ -14,7 +14,7 @@ def __init__( db_name: str = "eduid_signup", collection: str = "registered", auto_expire: timedelta | None = None, - ): + ) -> None: super().__init__(db_uri, db_name, collection=collection) if auto_expire is not None: diff --git a/src/eduid/userdb/support/db.py b/src/eduid/userdb/support/db.py index 37fb3fe40..5d3ce4bda 100644 --- a/src/eduid/userdb/support/db.py +++ b/src/eduid/userdb/support/db.py @@ -20,7 +20,7 @@ class SupportUserDB(UserDB[SupportUser]): - def __init__(self, db_uri: str, db_name: str = "eduid_am", collection: str = "attributes"): + def __init__(self, db_uri: str, db_name: str = "eduid_am", collection: str = "attributes") -> None: super().__init__(db_uri, db_name, collection=collection) @classmethod @@ -53,7 +53,7 @@ class SupportSignupUserDB(SignupUserDB): class SupportAuthnInfoDB(BaseDB): model = models.UserAuthnInfo - def __init__(self, db_uri: str): + def __init__(self, db_uri: str) -> None: db_name = "eduid_idp_authninfo" collection = "authn_info" super().__init__(db_uri, db_name, collection) @@ -86,7 +86,7 @@ def get_credential_info(self, credential_id: str) -> dict[str, Any]: class SupportProofingDB(BaseDB): model: type[GenericFilterDict] = GenericFilterDict - def __init__(self, db_uri: str, db_name: str, collection: str): + def __init__(self, db_uri: str, db_name: str, collection: str) -> None: super().__init__(db_uri, db_name, collection) def get_proofing_state(self, eppn: str) -> dict[str, Any]: @@ -111,7 +111,7 @@ def get_proofing_states(self, eppn: str) -> list[dict[str, Any]]: class SupportLetterProofingDB(SupportProofingDB): model = models.UserLetterProofing - def __init__(self, db_uri: str): + def __init__(self, db_uri: str) -> None: db_name = "eduid_idproofing_letter" collection = "proofing_data" super().__init__(db_uri, db_name, collection) @@ -131,7 +131,7 @@ def get_proofing_state(self, eppn: str) -> dict[str, Any]: class SupportOidcProofingDB(SupportProofingDB): model = models.UserOidcProofing - def __init__(self, db_uri: str): + def __init__(self, db_uri: str) -> None: db_name = "eduid_oidc_proofing" collection = "proofing_data" super().__init__(db_uri, db_name, collection) @@ -140,7 +140,7 @@ def __init__(self, db_uri: str): class SupportEmailProofingDB(SupportProofingDB): model = models.UserEmailProofing - def __init__(self, db_uri: str): + def __init__(self, db_uri: str) -> None: db_name = "eduid_email" collection = "proofing_data" super().__init__(db_uri, db_name, collection) @@ -149,7 +149,7 @@ def __init__(self, db_uri: str): class SupportPhoneProofingDB(SupportProofingDB): model = models.UserPhoneProofing - def __init__(self, db_uri: str): + def __init__(self, db_uri: str) -> None: db_name = "eduid_phone" collection = "proofing_data" super().__init__(db_uri, db_name, collection) @@ -158,7 +158,7 @@ def __init__(self, db_uri: str): class SupportProofingLogDB(BaseDB): model = models.ProofingLogEntry - def __init__(self, db_uri: str): + def __init__(self, db_uri: str) -> None: db_name = "eduid_logs" collection = "proofing_log" super().__init__(db_uri, db_name, collection) diff --git a/src/eduid/userdb/support/models.py b/src/eduid/userdb/support/models.py index a88f585d7..5022a9327 100644 --- a/src/eduid/userdb/support/models.py +++ b/src/eduid/userdb/support/models.py @@ -1,4 +1,7 @@ from copy import deepcopy +from typing import Any + +from eduid.userdb.db.base import TUserDbDocument __author__ = "lundberg" @@ -8,7 +11,7 @@ class GenericFilterDict(dict): add_keys: list[str] | None = None remove_keys: list[str] | None = None - def __init__(self, data): + def __init__(self, data: dict[str, Any] | None) -> None: """ Create a filtered dict with white- or blacklisting of keys @@ -35,7 +38,7 @@ def __init__(self, data): class SupportUserFilter(GenericFilterDict): remove_keys = ["_id", "letter_proofing_data"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument) -> None: _data = deepcopy(data) super().__init__(_data) @@ -47,7 +50,7 @@ def __init__(self, data): class SupportSignupUserFilter(GenericFilterDict): remove_keys = ["_id", "letter_proofing_data"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument) -> None: _data = deepcopy(data) super().__init__(_data) @@ -89,7 +92,7 @@ class ToU(GenericFilterDict): class UserAuthnInfo(GenericFilterDict): add_keys = ["success_ts", "fail_count", "success_count"] - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: _data = deepcopy(data) # Remove months with 0 failures or successes for attrib in ["fail_count", "success_count"]: @@ -110,7 +113,7 @@ class UserActions(GenericFilterDict): class ProofingLogEntry(GenericFilterDict): add_keys = ["verified_data", "created_ts", "proofing_method", "proofing_version", "created_by", "vetting_by"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument) -> None: _data = deepcopy(data) # Rename the verified data key to verified_data verified_data_names = ["nin", "mail_address", "phone_number", "orcid"] @@ -129,7 +132,7 @@ class Nin(GenericFilterDict): class ProofingLetter(GenericFilterDict): add_keys = ["sent_ts", "is_sent", "address"] - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: _data = deepcopy(data) super().__init__(_data) self["nin"] = self.Nin(self["nin"]) @@ -142,7 +145,7 @@ class UserOidcProofing(GenericFilterDict): class Nin(GenericFilterDict): add_keys = ["created_ts", "number"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument) -> None: _data = deepcopy(data) super().__init__(_data) self["nin"] = self.Nin(self["nin"]) @@ -154,7 +157,7 @@ class UserEmailProofing(GenericFilterDict): class Verification(GenericFilterDict): add_keys = ["created_ts", "email"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument) -> None: _data = deepcopy(data) super().__init__(_data) self["verification"] = self.Verification(self["verification"]) @@ -166,7 +169,7 @@ class UserPhoneProofing(GenericFilterDict): class Verification(GenericFilterDict): add_keys = ["created_ts", "number"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument) -> None: _data = deepcopy(data) super().__init__(_data) self["verification"] = self.Verification(self["verification"]) diff --git a/src/eduid/userdb/testing/__init__.py b/src/eduid/userdb/testing/__init__.py index 7627d0b24..1fcf76aec 100644 --- a/src/eduid/userdb/testing/__init__.py +++ b/src/eduid/userdb/testing/__init__.py @@ -8,6 +8,7 @@ import logging.config import unittest from collections.abc import Sequence +from dataclasses import dataclass from typing import Any, cast import pymongo @@ -57,10 +58,10 @@ def conn(self) -> pymongo.MongoClient[TUserDbDocument]: return self._conn @property - def uri(self): + def uri(self) -> str: return f"mongodb://localhost:{self.port}" - def shutdown(self): + def shutdown(self) -> None: if self._conn: logger.info(f"Closing connection {self._conn}") self._conn.close() @@ -72,6 +73,17 @@ def get_instance(cls: type[MongoTemporaryInstance], max_retry_seconds: int = 20) return cast(MongoTemporaryInstance, super().get_instance(max_retry_seconds=max_retry_seconds)) +@dataclass +class SetupConfig: + am_users: list[User] | None = None + am_settings: dict[str, Any] | None = None + want_mongo_uri: bool = True + users: list[str] | None = None + copy_user_to_private: bool = False + init_msg: bool = True + init_lookup_mobile: bool = True + + class MongoTestCase(unittest.TestCase): """TestCase with an embedded MongoDB temporary instance. @@ -82,7 +94,7 @@ class MongoTestCase(unittest.TestCase): A test can access the port using the attribute `port` """ - def setUp(self, *args: list[Any], am_users: list[User] | None = None, **kwargs: dict[str, Any]): + def setUp(self, config: SetupConfig | None = None) -> None: """ Test case initialization. :return: @@ -110,13 +122,15 @@ def setUp(self, *args: list[Any], am_users: list[User] | None = None, **kwargs: else: self.settings.update(mongo_settings) - if am_users: + if config is None: + config = SetupConfig() + if config.am_users: # Set up test users in the MongoDB. - for user in am_users: + for user in config.am_users: logger.debug(f"Adding test user {user} to the database") self.amdb.save(user) - def _init_logging(self): + def _init_logging(self) -> None: local_context = LocalContext( app_debug=True, app_name="testing", @@ -127,7 +141,7 @@ def _init_logging(self): logging_config = make_dictConfig(local_context) logging.config.dictConfig(logging_config) - def _reset_databases(self): + def _reset_databases(self) -> None: """ Reset databases for the next test class. @@ -139,7 +153,7 @@ def _reset_databases(self): self.tmp_db.conn.drop_database(db_name) self.amdb._drop_whole_collection() - def tearDown(self): + def tearDown(self) -> None: for userdoc in self.amdb._get_all_docs(): assert User.from_dict(data=userdoc) self._reset_databases() @@ -156,7 +170,7 @@ class AsyncMongoTestCase(unittest.IsolatedAsyncioTestCase): A test can access the port using the attribute `port` """ - def setUp(self, *args: list[Any], **kwargs: dict[str, Any]): + def setUp(self, *args: list[Any], **kwargs: dict[str, Any]) -> None: """ Test case initialization. :return: @@ -183,7 +197,7 @@ def setUp(self, *args: list[Any], **kwargs: dict[str, Any]): else: self.settings.update(mongo_settings) - def _init_logging(self): + def _init_logging(self) -> None: local_context = LocalContext( app_debug=True, app_name="testing", @@ -194,7 +208,7 @@ def _init_logging(self): logging_config = make_dictConfig(local_context) logging.config.dictConfig(logging_config) - def _reset_databases(self): + def _reset_databases(self) -> None: """ Reset databases for the next test class. @@ -205,6 +219,6 @@ def _reset_databases(self): if db_name not in ["local", "admin", "config"]: # Do not drop mongo internal dbs self.tmp_db.conn.drop_database(db_name) - def tearDown(self): + def tearDown(self) -> None: self._reset_databases() super().tearDown() diff --git a/src/eduid/userdb/testing/temp_instance.py b/src/eduid/userdb/testing/temp_instance.py index faa69f3aa..46661b33e 100644 --- a/src/eduid/userdb/testing/temp_instance.py +++ b/src/eduid/userdb/testing/temp_instance.py @@ -25,7 +25,7 @@ class EduidTemporaryInstance(ABC): _instance = None - def __init__(self, max_retry_seconds: int): + def __init__(self, max_retry_seconds: int) -> None: self._conn: Any | None = None # self._conn should be initialised by subclasses in `setup_conn' self._tmpdir = tempfile.mkdtemp() self._port = random.randint(40000, 65535) @@ -82,7 +82,7 @@ def setup_conn(self) -> bool: @property @abstractmethod - def conn(self) -> Any: + def conn(self) -> Any: # noqa: ANN401 """Return the initialised _conn instance. No default since it ought to be typed in the subclasses.""" raise NotImplementedError("All subclasses of EduidTemporaryInstance should implement the conn property") @@ -106,7 +106,7 @@ def output(self) -> str: _output = "".join(fd.readlines()) return _output - def shutdown(self): + def shutdown(self) -> None: logger.debug(f"{self} output at shutdown:\n{self.output}") if self._process: self._process.terminate() diff --git a/src/eduid/userdb/tests/test_app_user.py b/src/eduid/userdb/tests/test_app_user.py index 155d70325..6e6e0f752 100644 --- a/src/eduid/userdb/tests/test_app_user.py +++ b/src/eduid/userdb/tests/test_app_user.py @@ -12,41 +12,43 @@ class TestAppUser(TestCase): users: UserFixtures user_data: TUserDbDocument - def setUp(self): + def setUp(self) -> None: _users = UserFixtures() self.user = _users.new_user_example self.user_data = self.user.to_dict() - def test_proper_user(self): + def test_proper_user(self) -> None: user = User.from_dict(data=self.user_data) self.assertEqual(user.user_id, self.user_data["_id"]) self.assertEqual(user.eppn, self.user_data["eduPersonPrincipalName"]) - def test_proper_new_user(self): + def test_proper_new_user(self) -> None: user = User(user_id=self.user.user_id, eppn=self.user.eppn, credentials=self.user.credentials) self.assertEqual(user.user_id, self.user.user_id) self.assertEqual(user.eppn, self.user.eppn) - def test_missing_id(self): + def test_missing_id(self) -> None: user = User(eppn=self.user.eppn, credentials=self.user.credentials) self.assertNotEqual(user.user_id, self.user.user_id) - def test_missing_eppn(self): + def test_missing_eppn(self) -> None: _data = self.user.to_dict() _data.pop("eduPersonPrincipalName") with self.assertRaises(ValidationError): User.from_dict(_data) - def test_identities_created_ts_true(self): + def test_identities_created_ts_true(self) -> None: _data = self.user.to_dict() _data["identities"][0]["created_ts"] = True user = User.from_dict(_data) identity = user.identities.find(_data["identities"][0]["identity_type"]) + assert identity assert isinstance(identity.created_ts, datetime) is True - def test_locked_identity_created_ts_true(self): + def test_locked_identity_created_ts_true(self) -> None: _data = self.user.to_dict() _data["locked_identity"][0]["created_ts"] = True user = User.from_dict(_data) locked_identity = user.locked_identity.find(_data["locked_identity"][0]["identity_type"]) + assert locked_identity assert isinstance(locked_identity.created_ts, datetime) is True diff --git a/src/eduid/userdb/tests/test_async_db.py b/src/eduid/userdb/tests/test_async_db.py index b23dc41c5..f92f9fa3f 100644 --- a/src/eduid/userdb/tests/test_async_db.py +++ b/src/eduid/userdb/tests/test_async_db.py @@ -6,7 +6,7 @@ class TestAsyncMongoDB(IsolatedAsyncioTestCase): - async def test_full_uri(self): + async def test_full_uri(self) -> None: # full specified uri uri = "mongodb://db.example.com:1111/testdb" mdb = AsyncMongoDB(uri, db_name="testdb") @@ -17,7 +17,7 @@ async def test_full_uri(self): self.assertEqual(mdb._db_uri, uri) self.assertEqual(mdb._database_name, "testdb") - async def test_uri_without_path_component(self): + async def test_uri_without_path_component(self) -> None: uri = "mongodb://db.example.com:1111" mdb = AsyncMongoDB(uri, db_name="testdb") database = mdb.get_database() @@ -25,7 +25,7 @@ async def test_uri_without_path_component(self): self.assertEqual(mdb._db_uri, uri + "/testdb") self.assertEqual(mdb._database_name, "testdb") - async def test_uri_without_port(self): + async def test_uri_without_port(self) -> None: uri = "mongodb://db.example.com/" mdb = AsyncMongoDB(uri) self.assertEqual(mdb._db_uri, uri) @@ -33,7 +33,7 @@ async def test_uri_without_port(self): assert database is not None self.assertEqual(mdb.sanitized_uri, "mongodb://db.example.com/") - async def test_uri_with_username_and_password(self): + async def test_uri_with_username_and_password(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:1111/testdb" mdb = AsyncMongoDB(uri, db_name="testdb") conn = mdb.get_connection() @@ -47,7 +47,7 @@ async def test_uri_with_username_and_password(self): mdb.__repr__(), "" ) - async def test_uri_with_replicaset(self): + async def test_uri_with_replicaset(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com,db2.example.com:27017,db3.example.com:1234/?replicaSet=rs9" mdb = AsyncMongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9") @@ -56,7 +56,7 @@ async def test_uri_with_replicaset(self): "mongodb://john:s3cr3t@db.example.com,db2.example.com,db3.example.com:1234/testdb?replicaset=rs9", ) - async def test_uri_with_options(self): + async def test_uri_with_options(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:27017/?ssl=true&replicaSet=rs9" mdb = AsyncMongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9&tls=true") @@ -71,25 +71,25 @@ async def asyncSetUp(self) -> None: self.num_objs = 10 await self.db.collection.insert_many([{"x": i} for i in range(self.num_objs)]) - async def test_db_count(self): + async def test_db_count(self) -> None: self.assertEqual(self.num_objs, await self.db.db_count()) - async def test_db_count_limit(self): + async def test_db_count_limit(self) -> None: self.assertEqual(1, await self.db.db_count(limit=1)) self.assertEqual(2, await self.db.db_count(limit=2)) - async def test_db_count_spec(self): + async def test_db_count_spec(self) -> None: self.assertEqual(1, await self.db.db_count(spec={"x": 3})) - async def test_get_documents_by_filter_skip(self): + async def test_get_documents_by_filter_skip(self) -> None: docs = await self.db._get_documents_by_filter(spec={}, skip=2) self.assertEqual(8, len(docs)) - async def test_get_documents_by_filter_limit(self): + async def test_get_documents_by_filter_limit(self) -> None: docs = await self.db._get_documents_by_filter(spec={}, limit=1) self.assertEqual(1, len(docs)) - async def test_get_documents_by_aggregate(self): + async def test_get_documents_by_aggregate(self) -> None: match = { "x": 3, } diff --git a/src/eduid/userdb/tests/test_credentials.py b/src/eduid/userdb/tests/test_credentials.py index 35c01d8b2..3bbde0c47 100644 --- a/src/eduid/userdb/tests/test_credentials.py +++ b/src/eduid/userdb/tests/test_credentials.py @@ -46,12 +46,12 @@ } -def _keyid(key): +def _keyid(key: dict[str, str]) -> str: return "sha256:" + sha256(key["keyhandle"].encode("utf-8") + key["public_key"].encode("utf-8")).hexdigest() class TestCredentialList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # make pytest always show full diffs self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) @@ -59,14 +59,14 @@ def setUp(self): self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) self.four = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict, _four_dict]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) self.assertEqual(4, len(self.four.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) expected = [_one_dict] @@ -74,22 +74,23 @@ def test_to_list_of_dicts(self): assert obtained == expected, "Credential list with one password not as expected" - def test_find(self): + def test_find(self) -> None: match = self.two.find("222222222222222222222222") - self.assertIsInstance(match, Password) + assert isinstance(match, Password) self.assertEqual(match.credential_id, "222222222222222222222222") self.assertEqual(match.salt, "secondPasswordElement") self.assertEqual(match.created_by, "test") - def test_filter(self): + def test_filter(self) -> None: match = self.four.filter(U2F) assert len(match) == 1 token = match[0] assert token.key == _keyid(_four_dict) assert token.public_key == "foo" - def test_add(self): + def test_add(self) -> None: second = self.two.find(str(ObjectId("222222222222222222222222"))) + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -97,8 +98,9 @@ def test_add(self): assert obtained == expected, "List of credentials with added credential different than expected" - def test_add_duplicate(self): + def test_add_duplicate(self) -> None: dup = self.two.find(str(ObjectId("222222222222222222222222"))) + assert dup with pytest.raises(ValidationError) as exc_info: self.two.add(dup) @@ -113,15 +115,15 @@ def test_add_duplicate(self): ], ), f"Wrong error message: {exc_info.value.errors()}" - def test_add_password(self): - this = CredentialList.from_list_of_dicts([_one_dict, _two_dict] + [_three_dict]) + def test_add_password(self) -> None: + this = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) expected = self.three.to_list_of_dicts() obtained = this.to_list_of_dicts() assert obtained == expected, "List of credentials with added password different than expected" - def test_remove(self): + def test_remove(self) -> None: self.three.remove(ElementKey(str(ObjectId("333333333333333333333333")))) now_two = self.three @@ -130,11 +132,11 @@ def test_remove(self): assert obtained == expected, "List of credentials with removed credential different than expected" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.remove(ElementKey(str(ObjectId("55002741d00690878ae9b603")))) - def test_generated(self): + def test_generated(self) -> None: match = self.three.find("222222222222222222222222") assert isinstance(match, Password) assert match.is_generated is False @@ -142,7 +144,7 @@ def test_generated(self): assert isinstance(match, Password) assert match.is_generated is True - def test_external_credential(self): + def test_external_credential(self) -> None: _id = ElementKey(str(ObjectId())) # A SwedenConnectCredential as stored in the database data = {"framework": "SWECONN", "level": "loa3", "credential_id": _id} diff --git a/src/eduid/userdb/tests/test_db.py b/src/eduid/userdb/tests/test_db.py index fef175352..bcbeae231 100644 --- a/src/eduid/userdb/tests/test_db.py +++ b/src/eduid/userdb/tests/test_db.py @@ -5,11 +5,11 @@ import eduid.userdb.db as db from eduid.userdb.fixtures.users import UserFixtures from eduid.userdb.identity import IdentityType -from eduid.userdb.testing import MongoTestCase +from eduid.userdb.testing import MongoTestCase, SetupConfig class TestMongoDB(TestCase): - def test_full_uri(self): + def test_full_uri(self) -> None: # full specified uri uri = "mongodb://db.example.com:1111/testdb" mdb = db.MongoDB(uri, db_name="testdb") @@ -19,21 +19,21 @@ def test_full_uri(self): self.assertEqual(mdb._db_uri, uri) self.assertEqual(mdb._database_name, "testdb") - def test_uri_without_path_component(self): + def test_uri_without_path_component(self) -> None: uri = "mongodb://db.example.com:1111" mdb = db.MongoDB(uri, db_name="testdb") mdb.get_database() self.assertEqual(mdb._db_uri, uri + "/testdb") self.assertEqual(mdb._database_name, "testdb") - def test_uri_without_port(self): + def test_uri_without_port(self) -> None: uri = "mongodb://db.example.com/" mdb = db.MongoDB(uri) self.assertEqual(mdb._db_uri, uri) mdb.get_database("testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://db.example.com/") - def test_uri_with_username_and_password(self): + def test_uri_with_username_and_password(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:1111/testdb" mdb = db.MongoDB(uri, db_name="testdb") conn = mdb.get_connection() @@ -45,7 +45,7 @@ def test_uri_with_username_and_password(self): self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com:1111/testdb") self.assertEqual(mdb.__repr__(), "") - def test_uri_with_replicaset(self): + def test_uri_with_replicaset(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com,db2.example.com:27017,db3.example.com:1234/?replicaSet=rs9" mdb = db.MongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9") @@ -54,53 +54,56 @@ def test_uri_with_replicaset(self): "mongodb://john:s3cr3t@db.example.com,db2.example.com,db3.example.com:1234/testdb?replicaset=rs9", ) - def test_uri_with_options(self): + def test_uri_with_options(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:27017/?ssl=true&replicaSet=rs9" mdb = db.MongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9&tls=true") class TestDB(MongoTestCase): - def setUp(self): + def setUp(self, config: SetupConfig | None = None) -> None: _users = UserFixtures() self._am_users = [_users.new_unverified_user_example, _users.mocked_user_standard_2, _users.new_user_example] - super().setUp(am_users=self._am_users) + if config is None: + config = SetupConfig() + config.am_users = self._am_users + super().setUp(config=config) - def test_db_count(self): + def test_db_count(self) -> None: self.assertEqual(len(self._am_users), self.amdb.db_count()) - def test_db_count_limit(self): + def test_db_count_limit(self) -> None: self.assertEqual(1, self.amdb.db_count(limit=1)) self.assertEqual(2, self.amdb.db_count(limit=2)) - def test_db_count_spec(self): + def test_db_count_spec(self) -> None: self.assertEqual(1, self.amdb.db_count(spec={"_id": ObjectId("012345678901234567890123")})) - def test_get_documents_by_filter_skip(self): + def test_get_documents_by_filter_skip(self) -> None: docs = self.amdb._get_documents_by_filter(spec={}, skip=2) self.assertEqual(1, len(docs)) - def test_get_documents_by_filter_limit(self): + def test_get_documents_by_filter_limit(self) -> None: docs = self.amdb._get_documents_by_filter(spec={}, limit=1) self.assertEqual(1, len(docs)) - def test_get_verified_users_count_NIN(self): + def test_get_verified_users_count_NIN(self) -> None: count = self.amdb.get_verified_users_count(identity_type=IdentityType.NIN) assert count == 1 - def test_get_verified_users_count_EIDAS(self): + def test_get_verified_users_count_EIDAS(self) -> None: count = self.amdb.get_verified_users_count(identity_type=IdentityType.EIDAS) assert count == 1 - def test_get_verified_users_count_SVIPE(self): + def test_get_verified_users_count_SVIPE(self) -> None: count = self.amdb.get_verified_users_count(identity_type=IdentityType.SVIPE) assert count == 1 - def test_get_verified_users_count_None(self): + def test_get_verified_users_count_None(self) -> None: count = self.amdb.get_verified_users_count() assert count == 1 - def test_get_documents_by_aggregate(self): + def test_get_documents_by_aggregate(self) -> None: match = { "eduPersonPrincipalName": "hubba-bubba", } diff --git a/src/eduid/userdb/tests/test_element.py b/src/eduid/userdb/tests/test_element.py index 7f2c83a28..96d13166d 100644 --- a/src/eduid/userdb/tests/test_element.py +++ b/src/eduid/userdb/tests/test_element.py @@ -5,14 +5,14 @@ class TestElements(TestCase): - def test_create_element(self): + def test_create_element(self) -> None: elem = Element(created_by="test") assert elem.created_by == "test" assert isinstance(elem.created_ts, datetime) assert isinstance(elem.modified_ts, datetime) - def test_create_element_with_created_ts(self): + def test_create_element_with_created_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", created_ts=now) @@ -20,7 +20,7 @@ def test_create_element_with_created_ts(self): assert elem.created_ts == now assert isinstance(elem.modified_ts, datetime) - def test_create_element_with_modified_ts(self): + def test_create_element_with_modified_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", modified_ts=now) @@ -28,7 +28,7 @@ def test_create_element_with_modified_ts(self): assert elem.modified_ts == now assert isinstance(elem.modified_ts, datetime) - def test_create_element_with_created_and_modified_ts(self): + def test_create_element_with_created_and_modified_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", modified_ts=now, created_ts=now) @@ -36,7 +36,7 @@ def test_create_element_with_created_and_modified_ts(self): assert elem.created_ts == now assert elem.modified_ts == now - def test_element_reset_modified_ts(self): + def test_element_reset_modified_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", modified_ts=now, created_ts=now) @@ -47,7 +47,7 @@ def test_element_reset_modified_ts(self): class TestVerifiedElements(TestCase): - def test_create_verified_element(self): + def test_create_verified_element(self) -> None: elem = VerifiedElement(created_by="test") assert elem.created_by == "test" @@ -58,7 +58,7 @@ def test_create_verified_element(self): assert elem.verified_by is None assert elem.verified_ts is None - def test_modify_verified_element(self): + def test_modify_verified_element(self) -> None: elem = VerifiedElement(created_by="test") now = datetime.utcnow() @@ -74,7 +74,7 @@ def test_modify_verified_element(self): assert elem.verified_by == "test" assert elem.verified_ts == now - def test_create_full_verified_element(self): + def test_create_full_verified_element(self) -> None: now = datetime.utcnow() elem = VerifiedElement( @@ -91,7 +91,7 @@ def test_create_full_verified_element(self): class TestPrimaryElements(TestCase): - def test_create_primary_element(self): + def test_create_primary_element(self) -> None: elem = PrimaryElement(created_by="test") assert elem.created_by == "test" @@ -104,7 +104,7 @@ def test_create_primary_element(self): assert elem.is_primary is False - def test_modify_primary_element(self): + def test_modify_primary_element(self) -> None: elem = PrimaryElement(created_by="test") now = datetime.utcnow() @@ -124,7 +124,7 @@ def test_modify_primary_element(self): assert elem.is_primary is True - def test_create_full_primary_element(self): + def test_create_full_primary_element(self) -> None: now = datetime.utcnow() elem = PrimaryElement( @@ -147,7 +147,7 @@ def test_create_full_primary_element(self): assert elem.is_primary is True - def test_unverify_primary_element(self): + def test_unverify_primary_element(self) -> None: now = datetime.utcnow() elem = PrimaryElement( diff --git a/src/eduid/userdb/tests/test_event.py b/src/eduid/userdb/tests/test_event.py index 8dfa05196..413a19ad4 100644 --- a/src/eduid/userdb/tests/test_event.py +++ b/src/eduid/userdb/tests/test_event.py @@ -13,6 +13,7 @@ import eduid.userdb.exceptions from eduid.common.testing_base import normalised_data from eduid.userdb import PhoneNumber +from eduid.userdb.element import ElementKey from eduid.userdb.event import EventList from eduid.userdb.tou import ToUEvent @@ -55,42 +56,43 @@ def from_list_of_dicts(cls: type[SomeEventList], items: list[dict[str, Any]]) -> class TestEventList(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = SomeEventList() self.one = SomeEventList.from_list_of_dicts([_one_dict]) self.two = SomeEventList.from_list_of_dicts([_one_dict, _two_dict]) self.three = SomeEventList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): SomeEventList(elements="bad input data") with pytest.raises(ValidationError): SomeEventList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) _one_dict_copy = deepcopy(_one_dict) # Update id to event_id before comparing dicts _one_dict_copy["event_id"] = _one_dict_copy.pop("id") self.assertEqual([_one_dict_copy], self.one.to_list_of_dicts()) - def test_find(self): + def test_find(self) -> None: match = self.one.find(self.one.to_list()[0].key) + assert match self.assertIsInstance(match, ToUEvent) self.assertEqual(match.version, _one_dict["version"]) - def test_add(self): + def test_add(self) -> None: second = self.two.to_list()[-1] self.one.add(second) self.assertEqual(self.one.to_list_of_dicts(), self.two.to_list_of_dicts()) - def test_add_duplicate_key(self): + def test_add_duplicate_key(self) -> None: data = deepcopy(_two_dict) data["version"] = "other version" dup = ToUEvent.from_dict(data) @@ -108,26 +110,26 @@ def test_add_duplicate_key(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_add_event(self): + def test_add_event(self) -> None: third = self.three.to_list_of_dicts()[-1] this = SomeEventList.from_list_of_dicts([_one_dict, _two_dict, third]) self.assertEqual(this.to_list_of_dicts(), self.three.to_list_of_dicts()) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: new = PhoneNumber(number="+4612345678") with pytest.raises(ValidationError): - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] - def test_remove(self): + def test_remove(self) -> None: self.three.remove(self.three.to_list()[-1].key) now_two = self.three self.assertEqual(self.two.to_list_of_dicts(), now_two.to_list_of_dicts()) - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.remove("+46709999999") + self.one.remove(ElementKey("+46709999999")) - def test_unknown_event_type(self): + def test_unknown_event_type(self) -> None: e1 = { "event_type": "unknown_event", "id": str(bson.ObjectId()), @@ -151,7 +153,7 @@ def test_unknown_event_type(self): ], ), f"Wrong error message: {exc_info.value.errors()}" - def test_modified_ts_addition(self): + def test_modified_ts_addition(self) -> None: _event_no_modified_ts = { "event_type": "tou_event", "version": "1", @@ -171,10 +173,10 @@ def test_modified_ts_addition(self): else: self.assertIsInstance(event["modified_ts"], datetime) assert event["modified_ts"] == event["created_ts"] - for event in el.to_list(): - self.assertIsInstance(event.modified_ts, datetime) + for event2 in el.to_list(): + self.assertIsInstance(event2.modified_ts, datetime) - def test_update_modified_ts(self): + def test_update_modified_ts(self) -> None: _event_modified_ts = { "event_type": "tou_event", "version": "1", diff --git a/src/eduid/userdb/tests/test_exceptions.py b/src/eduid/userdb/tests/test_exceptions.py index 997489d17..0a01a2999 100644 --- a/src/eduid/userdb/tests/test_exceptions.py +++ b/src/eduid/userdb/tests/test_exceptions.py @@ -6,6 +6,6 @@ class TestEduIDUserDBError(TestCase): - def test_repr(self): + def test_repr(self) -> None: ex = eduid.userdb.exceptions.EduIDUserDBError("test") self.assertIsInstance(str(ex), str) diff --git a/src/eduid/userdb/tests/test_group_management.py b/src/eduid/userdb/tests/test_group_management.py index a20aa6323..ac5050374 100644 --- a/src/eduid/userdb/tests/test_group_management.py +++ b/src/eduid/userdb/tests/test_group_management.py @@ -4,7 +4,7 @@ from eduid.userdb.fixtures.users import UserFixtures from eduid.userdb.group_management import GroupInviteState, GroupManagementInviteStateDB, GroupRole -from eduid.userdb.testing import MongoTestCase +from eduid.userdb.testing import MongoTestCase, SetupConfig from eduid.userdb.user import User __author__ = "lundberg" @@ -13,12 +13,12 @@ class TestResetGroupInviteStateDB(MongoTestCase): user: User - def setUp(self, **kwargs): - super().setUp() + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) self.user = UserFixtures().mocked_user_standard self.invite_state_db = GroupManagementInviteStateDB(self.tmp_db.uri) - def test_invite_state(self): + def test_invite_state(self) -> None: # Member group_scim_id = str(uuid4()) invite_state = GroupInviteState( @@ -31,6 +31,7 @@ def test_invite_state(self): invite = self.invite_state_db.get_state( group_scim_id=group_scim_id, email_address="johnsmith@example.com", role=GroupRole.MEMBER ) + assert invite self.assertEqual(group_scim_id, invite.group_scim_id) self.assertEqual("johnsmith@example.com", invite.email_address) self.assertEqual(GroupRole.MEMBER, invite.role) @@ -47,11 +48,12 @@ def test_invite_state(self): invite = self.invite_state_db.get_state( group_scim_id=group_scim_id, email_address="johnsmith@example.com", role=GroupRole.OWNER ) + assert invite self.assertEqual(group_scim_id, invite.group_scim_id) self.assertEqual("johnsmith@example.com", invite.email_address) self.assertEqual(GroupRole.OWNER, invite.role) - def test_save_duplicate(self): + def test_save_duplicate(self) -> None: group_scim_id = str(uuid4()) invite_state1 = GroupInviteState( group_scim_id=group_scim_id, diff --git a/src/eduid/userdb/tests/test_identities.py b/src/eduid/userdb/tests/test_identities.py index 31da98eeb..2769c9fe1 100644 --- a/src/eduid/userdb/tests/test_identities.py +++ b/src/eduid/userdb/tests/test_identities.py @@ -48,37 +48,39 @@ class TestIdentityList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # Make pytest show full diffs self.empty = IdentityList() self.one = IdentityList.from_list_of_dicts([_one_dict]) self.two = IdentityList.from_list_of_dicts([_one_dict, _two_dict]) self.three = IdentityList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): IdentityList(elements="bad input data") with pytest.raises(ValidationError): IdentityList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) - def test_find(self): + def test_find(self) -> None: match = self.one.find("nin") - self.assertIsInstance(match, NinIdentity) + assert match + assert isinstance(match, NinIdentity) self.assertEqual(match.number, "197801011234") assert match.is_verified is True assert match.verified_ts is None - def test_add(self): + def test_add(self) -> None: second = self.two.find("svipe") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -105,8 +107,9 @@ def test_add_duplicate(self) -> None: ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_nin(self): + def test_add_nin(self) -> None: third = self.three.find("eidas") + assert third this = IdentityList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -114,13 +117,14 @@ def test_add_nin(self): assert normalised_data(obtained) == normalised_data(expected), "List with added nin has unexpected data" - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: """Test adding a phone number to the nin-list. Specifically phone, since pydantic can coerce it into a nin since they both have the 'number' field. """ new = PhoneNumber(number="+4612345678") + assert new with pytest.raises(ValidationError) as exc_info: - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] assert normalised_data(exc_info.value.errors(), exclude_keys=["input", "url"]) == normalised_data( [ { @@ -132,7 +136,7 @@ def test_add_wrong_type(self): ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_remove(self): + def test_remove(self) -> None: self.three.remove(ElementKey("eidas")) now_two = self.three @@ -141,26 +145,27 @@ def test_remove(self): assert normalised_data(obtained) == normalised_data(expected), "List with removed NIN has unexpected data" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.remove(ElementKey("+46709999999")) class TestIdentity(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = IdentityList() self.one = IdentityList.from_list_of_dicts([_one_dict]) self.two = IdentityList.from_list_of_dicts([_one_dict, _two_dict]) self.three = IdentityList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the Nin. """ nin = self.two.nin + assert nin self.assertEqual(IdentityType.NIN.value, nin.key) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -168,58 +173,67 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(IdentityList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("nin") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("svipe") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("eidas") + assert this this.verified_by = "unit test" this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("nin") + assert this now = utc_now() this.verified_ts = now self.assertEqual(this.verified_ts, now) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("svipe") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_modify_created_by(self): + def test_modify_created_by(self) -> None: this = self.three.find("eidas") + assert this this.created_by = "unit test" assert this.created_by == "unit test" - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("nin") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_ts_bool(self): + def test_ts_bool(self) -> None: # check that we can't set created_ts or modified_ts to a bool but that we # can read those from db to fix them this = self.three.find("nin") + assert this with self.assertRaises(ValidationError): - this.created_ts = True + this.created_ts = True # type: ignore[assignment] with self.assertRaises(ValidationError): - this.modified_ts = True + this.modified_ts = True # type: ignore[assignment] this_dict = this.to_dict() this_dict["created_ts"] = True this_dict["modified_ts"] = True assert NinIdentity.from_dict(this_dict) is not None - def test_get_missing_proofing_method(self): + def test_get_missing_proofing_method(self) -> None: this = self.three.find("nin") + assert this this.verified_by = "foo" assert this.get_missing_proofing_method() is None this.verified_by = "bankid" diff --git a/src/eduid/userdb/tests/test_idp_user.py b/src/eduid/userdb/tests/test_idp_user.py index b8e0b9ec0..a61310665 100644 --- a/src/eduid/userdb/tests/test_idp_user.py +++ b/src/eduid/userdb/tests/test_idp_user.py @@ -9,7 +9,7 @@ class TestIdpUser(TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.saml_attribute_settings = SAMLAttributeSettings( default_eppn_scope="example.com", @@ -21,7 +21,7 @@ def setUp(self): authn_context_class=EduidAuthnContextClass.PASSWORD_PT, ) - def test_idp_user_to_attributes_all(self): + def test_idp_user_to_attributes_all(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) attributes = idp_user.to_saml_attributes(settings=self.saml_attribute_settings) @@ -37,6 +37,7 @@ def test_idp_user_to_attributes_all(self): continue self.assertIsNotNone(attributes.get(key), f"{key} is unexpectedly None") + assert idp_user.ladok expected = { "c": "se", "cn": "John Smith", @@ -60,7 +61,7 @@ def test_idp_user_to_attributes_all(self): } assert normalised_data(expected) == normalised_data(attributes), f"expected did not match {attributes}" - def test_idp_user_chosen_given_name(self): + def test_idp_user_chosen_given_name(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) idp_user.given_name = "John Jack" idp_user.chosen_given_name = "Jack" @@ -79,13 +80,13 @@ def test_idp_user_chosen_given_name(self): assert attributes["displayName"] == "John Jack Smith" assert attributes["cn"] == "John Jack Smith" - def test_idp_user_display_name(self): + def test_idp_user_display_name(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) attributes = idp_user.to_saml_attributes(settings=self.saml_attribute_settings) assert attributes["displayName"] == "John Smith" assert attributes["cn"] == "John Smith" - def test_idp_user_display_name_no_given_name(self): + def test_idp_user_display_name_no_given_name(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) idp_user.given_name = None idp_user.chosen_given_name = None diff --git a/src/eduid/userdb/tests/test_ladok.py b/src/eduid/userdb/tests/test_ladok.py index 56f48760d..0364da02a 100644 --- a/src/eduid/userdb/tests/test_ladok.py +++ b/src/eduid/userdb/tests/test_ladok.py @@ -10,7 +10,7 @@ class LadokTest(TestCase): def setUp(self) -> None: self.external_uuid = uuid4() - def test_create_ladok(self): + def test_create_ladok(self) -> None: university = University( ladok_name="AB", name=UniversityName(sv="Lärosätesnamn", en="University Name"), created_by="test created_by" ) diff --git a/src/eduid/userdb/tests/test_logs.py b/src/eduid/userdb/tests/test_logs.py index f8145a220..f574892d1 100644 --- a/src/eduid/userdb/tests/test_logs.py +++ b/src/eduid/userdb/tests/test_logs.py @@ -30,15 +30,15 @@ class TestProofingLog(TestCase): user: User - def setUp(self): + def setUp(self) -> None: self.tmp_db = MongoTemporaryInstance.get_instance() self.proofing_log_db = ProofingLog(db_uri=self.tmp_db.uri) self.user = UserFixtures().mocked_user_standard - def tearDown(self): + def tearDown(self) -> None: self.proofing_log_db._drop_whole_collection() - def test_id_proofing_data(self): + def test_id_proofing_data(self) -> None: proofing_element = ProofingLogElement( eppn=self.user.eppn, created_by="test", proofing_method="test", proofing_version="test" ) @@ -52,7 +52,7 @@ def test_id_proofing_data(self): self.assertIsNotNone(hit["created_ts"]) self.assertEqual(hit["proofing_method"], "test") - def test_teleadress_proofing(self): + def test_teleadress_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -80,7 +80,7 @@ def test_teleadress_proofing(self): self.assertEqual(hit["proofing_method"], "TeleAdress") self.assertEqual(hit["proofing_version"], "test") - def test_teleadress_proofing_relation(self): + def test_teleadress_proofing_relation(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -111,7 +111,7 @@ def test_teleadress_proofing_relation(self): self.assertEqual(hit["proofing_method"], "TeleAdress") self.assertEqual(hit["proofing_version"], "test") - def test_letter_proofing(self): + def test_letter_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -140,7 +140,7 @@ def test_letter_proofing(self): self.assertEqual(hit["proofing_method"], "letter") self.assertEqual(hit["proofing_version"], "test") - def test_mail_address_proofing(self): + def test_mail_address_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -165,7 +165,7 @@ def test_mail_address_proofing(self): self.assertEqual(hit["proofing_method"], "e-mail") self.assertEqual(hit["mail_address"], "some_mail_address") - def test_phone_number_proofing(self): + def test_phone_number_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -191,7 +191,7 @@ def test_phone_number_proofing(self): self.assertEqual(hit["phone_number"], "some_phone_number") self.assertEqual(hit["proofing_version"], "test") - def test_se_leg_proofing(self): + def test_se_leg_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -222,7 +222,7 @@ def test_se_leg_proofing(self): self.assertEqual(hit["proofing_method"], "se-leg") self.assertEqual(hit["proofing_version"], "test") - def test_se_leg_proofing_freja(self): + def test_se_leg_proofing_freja(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -254,7 +254,7 @@ def test_se_leg_proofing_freja(self): self.assertEqual(hit["proofing_method"], "se-leg") self.assertEqual(hit["proofing_version"], "test") - def test_ladok_proofing(self): + def test_ladok_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -283,7 +283,7 @@ def test_ladok_proofing(self): self.assertEqual(hit["ladok_name"], "AB") self.assertEqual(hit["proofing_version"], "test") - def test_blank_string_proofing_data(self): + def test_blank_string_proofing_data(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -305,20 +305,20 @@ def test_blank_string_proofing_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_boolean_false_proofing_data(self): + def test_boolean_false_proofing_data(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", "proofing_version": "test", "reference": "reference id", } - proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=0) + proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=0) # type: ignore[arg-type] self.assertTrue(self.proofing_log_db.save(proofing_element)) - proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=False) + proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=False) # type: ignore[arg-type] self.assertTrue(self.proofing_log_db.save(proofing_element)) - def test_deregistered_proofing_data(self): + def test_deregistered_proofing_data(self) -> None: proofing_element = NinNavetProofingLogElement( eppn=self.user.eppn, created_by="test", @@ -343,14 +343,14 @@ def test_deregistered_proofing_data(self): class TestUserChangeLog(TestCase): - def setUp(self): + def setUp(self) -> None: self.tmp_db = MongoTemporaryInstance.get_instance() self.user_log_db = UserChangeLog(db_uri=self.tmp_db.uri) - def tearDown(self): + def tearDown(self) -> None: self.user_log_db._drop_whole_collection() - def _insert_log_fixtures(self): + def _insert_log_fixtures(self) -> None: data_1 = UserChangeLogElement( eppn="hubba-bubba", created_by="test", @@ -375,7 +375,7 @@ def _insert_log_fixtures(self): assert res[0]["eduPersonPrincipalName"] == "hubba-bubba" assert res[1]["eduPersonPrincipalName"] == "hubba-bubba" - def test_get_by_eppn(self): + def test_get_by_eppn(self) -> None: self._insert_log_fixtures() res_1 = self.user_log_db.get_by_eppn("hubba-bubba") diff --git a/src/eduid/userdb/tests/test_mail.py b/src/eduid/userdb/tests/test_mail.py index 7c8a5bc95..d5a42d182 100644 --- a/src/eduid/userdb/tests/test_mail.py +++ b/src/eduid/userdb/tests/test_mail.py @@ -11,6 +11,7 @@ from eduid.common.misc.timeutil import utc_now from eduid.common.testing_base import normalised_data from eduid.userdb import PhoneNumber +from eduid.userdb.element import ElementKey from eduid.userdb.mail import MailAddress, MailAddressList __author__ = "ft" @@ -35,37 +36,39 @@ class TestMailAddressList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None self.empty = MailAddressList() self.one = MailAddressList.from_list_of_dicts([_one_dict]) self.two = MailAddressList.from_list_of_dicts([_one_dict, _two_dict]) self.three = MailAddressList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): MailAddressList(elements="bad input data") with pytest.raises(ValidationError): MailAddressList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_empty_to_list_of_dicts(self): + def test_empty_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) - def test_find(self): + def test_find(self) -> None: match = self.one.find("ft@one.example.org") + assert match self.assertIsInstance(match, MailAddress) self.assertEqual(match.email, "ft@one.example.org") self.assertEqual(match.is_verified, True) self.assertEqual(match.verified_ts, None) - def test_add(self): + def test_add(self) -> None: second = self.two.find("ft@two.example.org") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -73,8 +76,10 @@ def test_add(self): assert obtained == expected, "Wrong data after adding mail address to list" - def test_add_duplicate(self): + def test_add_duplicate(self) -> None: + assert self.two.primary dup = self.two.find(self.two.primary.email) + assert dup with pytest.raises(ValidationError) as exc_info: self.two.add(dup) @@ -89,8 +94,9 @@ def test_add_duplicate(self): ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_mailaddress(self): + def test_add_mailaddress(self) -> None: third = self.three.find("ft@three.example.org") + assert third this = MailAddressList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -98,20 +104,20 @@ def test_add_mailaddress(self): assert obtained == expected, "Wrong data in mail address list" - def test_add_another_primary(self): + def test_add_another_primary(self) -> None: new = eduid.userdb.mail.address_from_dict( {"email": "ft@primary.example.org", "verified": True, "primary": True} ) with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.one.add(new) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: new = PhoneNumber(number="+4612345678") with pytest.raises(ValidationError): - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] - def test_remove(self): - self.three.remove("ft@three.example.org") + def test_remove(self) -> None: + self.three.remove(ElementKey("ft@three.example.org")) now_two = self.three expected = self.two.to_list_of_dicts() @@ -119,51 +125,58 @@ def test_remove(self): assert obtained == expected, "Wrong data after removing email from list" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.remove("foo@no-such-address.example.org") + self.one.remove(ElementKey("foo@no-such-address.example.org")) - def test_remove_primary(self): + def test_remove_primary(self) -> None: + assert self.two.primary with pytest.raises( eduid.userdb.element.PrimaryElementViolation, match="Removing the primary element is not allowed" ): self.two.remove(self.two.primary.key) - def test_remove_primary_single(self): - self.one.remove(self.one.primary.email) + def test_remove_primary_single(self) -> None: + assert self.one.primary + self.one.remove(ElementKey(self.one.primary.email)) now_empty = self.one self.assertEqual([], now_empty.to_list()) - def test_primary(self): + def test_primary(self) -> None: match = self.one.primary + assert match self.assertEqual(match.email, "ft@one.example.org") - def test_empty_primary(self): + def test_empty_primary(self) -> None: self.assertEqual(None, self.empty.primary) - def test_set_primary_to_same(self): + def test_set_primary_to_same(self) -> None: match = self.one.primary - self.one.set_primary(match.email) + assert match + self.one.set_primary(ElementKey(match.email)) match = self.two.primary - self.two.set_primary(match.email) + assert match + self.two.set_primary(ElementKey(match.email)) - def test_set_unknown_as_primary(self): + def test_set_unknown_as_primary(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.set_primary("foo@no-such-address.example.org") + self.one.set_primary(ElementKey("foo@no-such-address.example.org")) - def test_set_unverified_as_primary(self): + def test_set_unverified_as_primary(self) -> None: with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): - self.three.set_primary("ft@three.example.org") + self.three.set_primary(ElementKey("ft@three.example.org")) - def test_change_primary(self): + def test_change_primary(self) -> None: match = self.two.primary + assert match self.assertEqual(match.email, "ft@one.example.org") - self.two.set_primary("ft@two.example.org") + self.two.set_primary(ElementKey("ft@two.example.org")) updated = self.two.primary + assert updated self.assertEqual(updated.email, "ft@two.example.org") - def test_bad_input_two_primary(self): + def test_bad_input_two_primary(self) -> None: one = copy.deepcopy(_one_dict) two = copy.deepcopy(_two_dict) one["primary"] = True @@ -171,7 +184,7 @@ def test_bad_input_two_primary(self): with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): MailAddressList.from_list_of_dicts([one, two]) - def test_bad_input_unverified_primary(self): + def test_bad_input_unverified_primary(self) -> None: one = copy.deepcopy(_one_dict) one["verified"] = False with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): @@ -179,20 +192,21 @@ def test_bad_input_unverified_primary(self): class TestMailAddress(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = MailAddressList() self.one = MailAddressList.from_list_of_dicts([_one_dict]) self.two = MailAddressList.from_list_of_dicts([_one_dict, _two_dict]) self.three = MailAddressList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the MailAddress. """ address = self.two.primary + assert address self.assertEqual(address.key, address.email) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -204,7 +218,7 @@ def test_parse_cycle(self): assert cycled == expected - def test_unknown_input_data(self): + def test_unknown_input_data(self) -> None: one = copy.deepcopy(_one_dict) one["foo"] = "bar" with pytest.raises(ValidationError) as exc_info: @@ -219,7 +233,7 @@ def test_unknown_input_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_bad_input_type(self): + def test_bad_input_type(self) -> None: one = copy.deepcopy(_one_dict) one["email"] = False with pytest.raises(ValidationError) as exc_info: @@ -237,46 +251,54 @@ def test_bad_input_type(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_changing_is_verified_on_primary(self): + def test_changing_is_verified_on_primary(self) -> None: this = self.one.primary + assert this with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): this.is_verified = False - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_by = "unit test" this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_verified_ts(self): + def test_verified_ts(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_ts = utc_now() self.assertIsInstance(this.verified_ts, datetime.datetime) - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_ts = utc_now() - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("ft@three.example.org") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_uppercase_email_address(self): + def test_uppercase_email_address(self) -> None: address = "UPPERCASE@example.com" mail_address = MailAddress(email=address) self.assertEqual(address.lower(), mail_address.email) diff --git a/src/eduid/userdb/tests/test_nin.py b/src/eduid/userdb/tests/test_nin.py index a327be99e..8e0967917 100644 --- a/src/eduid/userdb/tests/test_nin.py +++ b/src/eduid/userdb/tests/test_nin.py @@ -34,37 +34,39 @@ class TestNinList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # Make pytest show full diffs self.empty = NinList() self.one = NinList.from_list_of_dicts([_one_dict]) self.two = NinList.from_list_of_dicts([_one_dict, _two_dict]) self.three = NinList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): NinList(elements="bad input data") with pytest.raises(ValidationError): NinList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) - def test_find(self): + def test_find(self) -> None: match = self.one.find("197801011234") + assert match self.assertIsInstance(match, Nin) self.assertEqual(match.number, "197801011234") self.assertEqual(match.is_verified, True) self.assertEqual(match.verified_ts, None) - def test_add(self): + def test_add(self) -> None: second = self.two.find("197802022345") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -91,8 +93,9 @@ def test_add_duplicate(self) -> None: ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_nin(self): + def test_add_nin(self) -> None: third = self.three.find("197803033456") + assert third this = NinList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -100,18 +103,18 @@ def test_add_nin(self): assert obtained == expected, "List with added nin has unexpected data" - def test_add_another_primary(self): + def test_add_another_primary(self) -> None: new = eduid.userdb.nin.nin_from_dict({"number": "+46700000009", "verified": True, "primary": True}) with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.one.add(new) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: """Test adding a phone number to the nin-list. Specifically phone, since pydantic can coerce it into a nin since they both have the 'number' field. """ new = PhoneNumber(number="+4612345678") with pytest.raises(ValidationError) as exc_info: - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] assert normalised_data(exc_info.value.errors(), exclude_keys=["input", "url"]) == normalised_data( [ { @@ -123,7 +126,7 @@ def test_add_wrong_type(self): ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_remove(self): + def test_remove(self) -> None: self.three.remove(ElementKey("197803033456")) now_two = self.three @@ -132,49 +135,56 @@ def test_remove(self): assert obtained == expected, "List with removed NIN has unexpected data" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.remove(ElementKey("+46709999999")) - def test_remove_primary(self): + def test_remove_primary(self) -> None: + assert self.two.primary with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.two.remove(self.two.primary.key) - def test_remove_primary_single(self): + def test_remove_primary_single(self) -> None: + assert self.one.primary self.one.remove(self.one.primary.key) now_empty = self.one self.assertEqual([], now_empty.to_list()) - def test_primary(self): + def test_primary(self) -> None: match = self.one.primary + assert match self.assertEqual(match.number, "197801011234") - def test_empty_primary(self): + def test_empty_primary(self) -> None: self.assertEqual(None, self.empty.primary) - def test_set_primary_to_same(self): + def test_set_primary_to_same(self) -> None: match = self.one.primary + assert match self.one.set_primary(match.key) match = self.two.primary + assert match self.two.set_primary(match.key) - def test_set_unknown_as_primary(self): + def test_set_unknown_as_primary(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.set_primary(ElementKey("+46709999999")) - def test_set_unverified_as_primary(self): + def test_set_unverified_as_primary(self) -> None: with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.three.set_primary(ElementKey("197803033456")) - def test_change_primary(self): + def test_change_primary(self) -> None: match = self.two.primary + assert match self.assertEqual(match.number, "197801011234") self.two.set_primary(ElementKey("197802022345")) updated = self.two.primary + assert updated self.assertEqual(updated.number, "197802022345") - def test_bad_input_two_primary(self): + def test_bad_input_two_primary(self) -> None: one = copy.deepcopy(_one_dict) two = copy.deepcopy(_two_dict) one["primary"] = True @@ -182,7 +192,7 @@ def test_bad_input_two_primary(self): with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): NinList.from_list_of_dicts([one, two]) - def test_bad_input_unverified_primary(self): + def test_bad_input_unverified_primary(self) -> None: one = copy.deepcopy(_one_dict) one["verified"] = False with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): @@ -190,20 +200,21 @@ def test_bad_input_unverified_primary(self): class TestNin(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = NinList() self.one = NinList.from_list_of_dicts([_one_dict]) self.two = NinList.from_list_of_dicts([_one_dict, _two_dict]) self.three = NinList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the Nin. """ address = self.two.primary + assert address self.assertEqual(address.key, address.number) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -211,44 +222,52 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(NinList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_changing_is_verified_on_primary(self): + def test_changing_is_verified_on_primary(self) -> None: this = self.one.primary + assert this with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): this.is_verified = False - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("197803033456") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("197803033456") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("197803033456") + assert this this.verified_by = "unit test" this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("197803033456") + assert this now = datetime.datetime.utcnow() this.verified_ts = now self.assertEqual(this.verified_ts, now) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("197803033456") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_modify_created_by(self): + def test_modify_created_by(self) -> None: this = self.three.find("197803033456") + assert this this.created_by = "unit test" assert this.created_by == "unit test" - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("197803033456") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) diff --git a/src/eduid/userdb/tests/test_orcid.py b/src/eduid/userdb/tests/test_orcid.py index 596f70bbf..3069fb846 100644 --- a/src/eduid/userdb/tests/test_orcid.py +++ b/src/eduid/userdb/tests/test_orcid.py @@ -37,7 +37,8 @@ class TestOrcid(unittest.TestCase): maxDiff = None - def test_id_token(self): + def test_id_token(self) -> None: + assert isinstance(token_response["id_token"], dict) id_token_data = token_response["id_token"] id_token_data["created_by"] = "test" id_token_1 = OidcIdToken.from_dict(id_token_data) @@ -65,9 +66,10 @@ def test_id_token(self): assert dict_1 == dict_2 with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - OidcIdToken.from_dict(None) + OidcIdToken.from_dict(None) # type: ignore[arg-type] - def test_oidc_authz(self): + def test_oidc_authz(self) -> None: + assert isinstance(token_response["id_token"], dict) id_token_data = token_response["id_token"] id_token_data["created_by"] = "test" id_token = OidcIdToken.from_dict(token_response["id_token"]) @@ -96,9 +98,10 @@ def test_oidc_authz(self): assert dict_1 == dict_2 with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - OidcAuthorization.from_dict(None) + OidcAuthorization.from_dict(None) # type: ignore[arg-type] - def test_orcid(self): + def test_orcid(self) -> None: + assert isinstance(token_response["id_token"], dict) token_response["id_token"]["created_by"] = "test" token_response["created_by"] = "test" oidc_authz = OidcAuthorization.from_dict(token_response) @@ -135,4 +138,4 @@ def test_orcid(self): ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" with pytest.raises(eduid.userdb.exceptions.UserDBValueError): - Orcid.from_dict(None) + Orcid.from_dict(None) # type: ignore[arg-type] diff --git a/src/eduid/userdb/tests/test_password.py b/src/eduid/userdb/tests/test_password.py index 867bfe602..48e267d88 100644 --- a/src/eduid/userdb/tests/test_password.py +++ b/src/eduid/userdb/tests/test_password.py @@ -3,7 +3,7 @@ from bson.objectid import ObjectId -from eduid.userdb.credentials import CredentialList +from eduid.userdb.credentials import CredentialList, Password __author__ = "lundberg" @@ -17,20 +17,22 @@ class TestPassword(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) self.two = CredentialList.from_list_of_dicts([_one_dict, _two_dict]) self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by CredentialList) works for the Password. """ password = self.one.find(str(ObjectId("55002741d00690878ae9b600"))) + assert password + assert isinstance(password, Password) self.assertEqual(password.key, password.credential_id) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -38,11 +40,13 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(CredentialList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find(str(ObjectId("55002741d00690878ae9b600"))) + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find(str(ObjectId("55002741d00690878ae9b600"))) + assert this self.assertIsInstance(this.created_ts, datetime.datetime) diff --git a/src/eduid/userdb/tests/test_phone.py b/src/eduid/userdb/tests/test_phone.py index ae29bc7a5..6314b86ec 100644 --- a/src/eduid/userdb/tests/test_phone.py +++ b/src/eduid/userdb/tests/test_phone.py @@ -10,6 +10,7 @@ from eduid.common.misc.timeutil import utc_now from eduid.common.testing_base import normalised_data from eduid.userdb import MailAddress +from eduid.userdb.element import ElementKey from eduid.userdb.phone import PhoneNumber, PhoneNumberList __author__ = "ft" @@ -40,26 +41,26 @@ class TestPhoneNumberList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = PhoneNumberList() self.one = PhoneNumberList.from_list_of_dicts([_one_dict]) self.two = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict]) self.three = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) self.four = PhoneNumberList.from_list_of_dicts([_three_dict, _four_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): PhoneNumberList(elements="bad input data") with pytest.raises(ValidationError): PhoneNumberList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: assert self.empty.to_list_of_dicts() == [] assert isinstance(self.one.to_list(), list) assert len(self.one.to_list()) == 1 - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: assert self.empty.to_list_of_dicts() == [] one_dict_list = self.one.to_list_of_dicts() @@ -67,23 +68,27 @@ def test_to_list_of_dicts(self): assert one_dict_list == expected - def test_find(self): + def test_find(self) -> None: match = self.one.find("+46700000001") + assert match self.assertIsInstance(match, PhoneNumber) self.assertEqual(match.number, "+46700000001") self.assertEqual(match.is_verified, True) self.assertEqual(match.verified_ts, None) - def test_add(self): + def test_add(self) -> None: second = self.two.find("+46700000002") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() got = self.one.to_list_of_dicts() assert got == expected, "Adding a phone number to a list results in wrong data" - def test_add_duplicate(self): + def test_add_duplicate(self) -> None: + assert self.two.primary dup = self.two.find(self.two.primary.number) + assert dup with pytest.raises(ValidationError) as exc_info: self.two.add(dup) @@ -98,8 +103,9 @@ def test_add_duplicate(self): ] ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_phonenumber(self): + def test_add_phonenumber(self) -> None: third = self.three.find("+46700000003") + assert third this = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -107,15 +113,15 @@ def test_add_phonenumber(self): assert got == expected, "Phone number list contains wrong data" - def test_add_another_primary(self): + def test_add_another_primary(self) -> None: new = PhoneNumber(number="+46700000009", is_verified=True, is_primary=True) with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.one.add(new) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: new = MailAddress(email="ft@example.org") with pytest.raises(ValidationError) as exc_info: - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] assert normalised_data(exc_info.value.errors(), exclude_keys=["input", "url"]) == normalised_data( [ { @@ -127,8 +133,8 @@ def test_add_wrong_type(self): ] ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_remove(self): - self.three.remove("+46700000003") + def test_remove(self) -> None: + self.three.remove(ElementKey("+46700000003")) now_two = self.three expected = self.two.to_list_of_dicts() @@ -136,41 +142,45 @@ def test_remove(self): assert got == expected, "Phone list has wrong data after removing phone" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.remove("+46709999999") + self.one.remove(ElementKey("+46709999999")) - def test_remove_primary(self): + def test_remove_primary(self) -> None: + assert self.two.primary with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): - self.two.remove(self.two.primary.number) + self.two.remove(ElementKey(self.two.primary.number)) - def test_remove_primary_single(self): - self.one.remove(self.one.primary.number) + def test_remove_primary_single(self) -> None: + assert self.one.primary + self.one.remove(ElementKey(self.one.primary.number)) now_empty = self.one assert now_empty.to_list() == [] - def test_remove_all_mix(self): + def test_remove_all_mix(self) -> None: # First, remove all numbers except the primary for mobile in self.three.to_list(): if not mobile.is_primary: self.three.remove(mobile.key) # Now, remove the primary number (which can't be removed until it is the last element) + assert self.three.primary self.three.remove(self.three.primary.key) assert self.three.to_list() == [] - def test_remove_all_no_verified(self): + def test_remove_all_no_verified(self) -> None: verified = self.four.verified if verified: for mobile in verified: if not mobile.is_primary: - self.four.remove(mobile.number) - self.four.remove(self.four.primary.number) + self.four.remove(ElementKey(mobile.number)) + assert self.four.primary + self.four.remove(ElementKey(self.four.primary.number)) for mobile in self.four.to_list(): - self.four.remove(mobile.number) + self.four.remove(ElementKey(mobile.number)) self.assertEqual([], self.four.to_list()) - def test_unverify_all(self): + def test_unverify_all(self) -> None: verified = self.three.verified for mobile in verified: @@ -180,36 +190,41 @@ def test_unverify_all(self): verified_now = self.three.verified assert verified_now == [] - def test_primary(self): + def test_primary(self) -> None: match = self.one.primary + assert match self.assertEqual(match.number, "+46700000001") - def test_empty_primary(self): + def test_empty_primary(self) -> None: self.assertEqual(None, self.empty.primary) - def test_set_primary_to_same(self): + def test_set_primary_to_same(self) -> None: match = self.one.primary - self.one.set_primary(match.number) + assert match + self.one.set_primary(ElementKey(match.number)) match = self.two.primary - self.two.set_primary(match.number) + assert match + self.two.set_primary(ElementKey(match.number)) - def test_set_unknown_as_primary(self): + def test_set_unknown_as_primary(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.set_primary("+46709999999") + self.one.set_primary(ElementKey("+46709999999")) - def test_set_unverified_as_primary(self): + def test_set_unverified_as_primary(self) -> None: with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): - self.three.set_primary("+46700000003") + self.three.set_primary(ElementKey("+46700000003")) - def test_change_primary(self): + def test_change_primary(self) -> None: match = self.two.primary + assert match self.assertEqual(match.number, "+46700000001") - self.two.set_primary("+46700000002") + self.two.set_primary(ElementKey("+46700000002")) updated = self.two.primary + assert updated self.assertEqual(updated.number, "+46700000002") - def test_bad_input_two_primary(self): + def test_bad_input_two_primary(self) -> None: one = copy.deepcopy(_one_dict) two = copy.deepcopy(_two_dict) one["primary"] = True @@ -217,7 +232,7 @@ def test_bad_input_two_primary(self): with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): PhoneNumberList.from_list_of_dicts([one, two]) - def test_unverified_primary(self): + def test_unverified_primary(self) -> None: one = copy.deepcopy(_one_dict) one["verified"] = False with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): @@ -225,22 +240,23 @@ def test_unverified_primary(self): class TestPhoneNumber(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = PhoneNumberList() self.one = PhoneNumberList.from_list_of_dicts([_one_dict]) self.two = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict]) self.three = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the PhoneNumber. """ address = self.two.primary + assert address self.assertEqual(address.key, address.number) - def test_create_phone_number(self): - one = copy.deepcopy(_one_dict) - one = PhoneNumber.from_dict(one) + def test_create_phone_number(self) -> None: + one_copy = copy.deepcopy(_one_dict) + one = PhoneNumber.from_dict(one_copy) # remove added timestamp one_dict = one.to_dict() @@ -248,7 +264,7 @@ def test_create_phone_number(self): assert _one_dict["verified"] == one_dict["verified"], "Created phone has wrong is_verified" assert _one_dict["number"] == one_dict["number"], "Created phone has wrong number" - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -256,7 +272,7 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(PhoneNumberList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_unknown_input_data(self): + def test_unknown_input_data(self) -> None: one = copy.deepcopy(_one_dict) one["foo"] = "bar" with pytest.raises(ValidationError) as exc_info: @@ -271,39 +287,46 @@ def test_unknown_input_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_changing_is_verified_on_primary(self): + def test_changing_is_verified_on_primary(self) -> None: this = self.one.primary + assert this with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): this.is_verified = False - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("+46700000003") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("+46700000003") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("+46700000003") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("+46700000003") + assert this now = utc_now() this.verified_ts = now self.assertEqual(this.verified_ts, now) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("+46700000003") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("+46700000003") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) diff --git a/src/eduid/userdb/tests/test_profile.py b/src/eduid/userdb/tests/test_profile.py index 7af2553f0..8b8d250e0 100644 --- a/src/eduid/userdb/tests/test_profile.py +++ b/src/eduid/userdb/tests/test_profile.py @@ -13,7 +13,7 @@ class ProfileTest(TestCase): - def test_create_profile(self): + def test_create_profile(self) -> None: profile = Profile( owner="test owner", profile_schema="test schema", @@ -28,7 +28,7 @@ def test_create_profile(self): self.assertIn(key, profile.profile_data) self.assertEqual(value, profile.profile_data[key]) - def test_profile_list(self): + def test_profile_list(self) -> None: profile = Profile( owner="test owner 1", profile_schema="test schema", @@ -48,12 +48,12 @@ def test_profile_list(self): self.assertIsNotNone(profile_list.find("test owner 1")) self.assertIsNotNone(profile_list.find("test owner 2")) - def test_empty_profile_list(self): + def test_empty_profile_list(self) -> None: profile_list = ProfileList() self.assertIsNotNone(profile_list) self.assertEqual(profile_list.count, 0) - def test_profile_list_owner_conflict(self): + def test_profile_list_owner_conflict(self) -> None: profile = Profile( owner="test owner 1", profile_schema="test schema", diff --git a/src/eduid/userdb/tests/test_proofing.py b/src/eduid/userdb/tests/test_proofing.py index c40cdd2e0..470bfc302 100644 --- a/src/eduid/userdb/tests/test_proofing.py +++ b/src/eduid/userdb/tests/test_proofing.py @@ -28,7 +28,7 @@ class ProofingStateTest(TestCase): - def _test_create_letterproofingstate(self, state: LetterProofingState, nin_expected_keys: list[str]): + def _test_create_letterproofingstate(self, state: LetterProofingState, nin_expected_keys: list[str]) -> None: """ { 'eppn': 'foob-arra', @@ -73,7 +73,7 @@ def _test_create_letterproofingstate(self, state: LetterProofingState, nin_expec sorted(_proofing_letter_expected_keys), ) - def test_create_letterproofingstate_with_ninproofingelement_from_dict(self): + def test_create_letterproofingstate_with_ninproofingelement_from_dict(self) -> None: """ """ state = LetterProofingState( eppn=EPPN, @@ -102,7 +102,7 @@ def test_create_letterproofingstate_with_ninproofingelement_from_dict(self): self._test_create_letterproofingstate(state, _nin_expected_keys) - def test_create_letterproofingstate_with_ninproofingelement_from_dict_with_created_ts(self): + def test_create_letterproofingstate_with_ninproofingelement_from_dict_with_created_ts(self) -> None: """ """ state = LetterProofingState( eppn=EPPN, @@ -130,7 +130,7 @@ def test_create_letterproofingstate_with_ninproofingelement_from_dict_with_creat self._test_create_letterproofingstate(state, _nin_expected_keys) - def test_create_letterproofingstate(self): + def test_create_letterproofingstate(self) -> None: """ """ state = LetterProofingState( eppn=EPPN, @@ -155,7 +155,7 @@ def test_create_letterproofingstate(self): self._test_create_letterproofingstate(state, _nin_expected_keys) - def test_create_oidcproofingstate(self): + def test_create_oidcproofingstate(self) -> None: """ { 'eduPersonPrincipalName': 'foob-arra', @@ -183,7 +183,7 @@ def test_create_oidcproofingstate(self): ["_id", "eduPersonPrincipalName", "modified_ts", "nin", "nonce", "state", "token"], ) - def test_proofing_state_expiration(self): + def test_proofing_state_expiration(self) -> None: state = ProofingState(id=None, eppn=EPPN, modified_ts=datetime.now(tz=None)) self.assertFalse(state.is_expired(1)) diff --git a/src/eduid/userdb/tests/test_resetpw.py b/src/eduid/userdb/tests/test_resetpw.py index cba508ab3..635b63efa 100644 --- a/src/eduid/userdb/tests/test_resetpw.py +++ b/src/eduid/userdb/tests/test_resetpw.py @@ -1,17 +1,20 @@ from datetime import timedelta from eduid.userdb.reset_password import ResetPasswordEmailAndPhoneState, ResetPasswordEmailState, ResetPasswordStateDB -from eduid.userdb.testing import MongoTestCase +from eduid.userdb.reset_password.element import CodeElement +from eduid.userdb.testing import MongoTestCase, SetupConfig class TestResetPasswordStateDB(MongoTestCase): - def setUp(self): - super().setUp() + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) self.resetpw_db = ResetPasswordStateDB(self.tmp_db.uri, "eduid_reset_password") - def test_email_state(self): + def test_email_state(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) self.resetpw_db.save(email_state, is_in_database=False) @@ -25,9 +28,11 @@ def test_email_state(self): self.assertTrue(state.email_code.is_expired(timedelta(0))) self.assertFalse(state.email_code.is_expired(timedelta(1))) - def test_email_state_get_by_code(self): + def test_email_state_get_by_code(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) self.resetpw_db.save(email_state, is_in_database=False) @@ -39,9 +44,11 @@ def test_email_state_get_by_code(self): self.assertEqual(state.eppn, "hubba-bubba") self.assertEqual(state.generated_password, False) - def test_email_state_generated_pw(self): + def test_email_state_generated_pw(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) email_state.generated_password = True @@ -52,9 +59,11 @@ def test_email_state_generated_pw(self): self.assertEqual(state.email_address, "johnsmith@example.com") self.assertEqual(state.generated_password, True) - def test_email_state_extra_security(self): + def test_email_state_extra_security(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) email_state.extra_security = {"phone_numbers": [{"number": "+99999999999", "primary": True, "verified": True}]} @@ -66,19 +75,20 @@ def test_email_state_extra_security(self): self.assertEqual(state.email_address, "johnsmith@example.com") self.assertEqual(state.extra_security["phone_numbers"][0]["number"], "+99999999999") - def test_email_and_phone_state(self): + def test_email_and_phone_state(self) -> None: email_state = ResetPasswordEmailAndPhoneState( eppn="hubba-bubba", email_address="johnsmith@example.com", - email_code="dummy-code", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), phone_number="+99999999999", - phone_code="dummy-phone-code", + phone_code=CodeElement.parse(application="test", code_or_element="dummy-phone-code"), ) self.resetpw_db.save(email_state, is_in_database=False) state = self.resetpw_db.get_state_by_eppn("hubba-bubba") assert state is not None + assert isinstance(state, ResetPasswordEmailAndPhoneState) self.assertEqual(state.email_address, "johnsmith@example.com") self.assertEqual(state.email_code.code, "dummy-code") self.assertEqual(state.phone_number, "+99999999999") diff --git a/src/eduid/userdb/tests/test_signup_invite.py b/src/eduid/userdb/tests/test_signup_invite.py index 5df717f11..684c439a8 100644 --- a/src/eduid/userdb/tests/test_signup_invite.py +++ b/src/eduid/userdb/tests/test_signup_invite.py @@ -6,7 +6,7 @@ class TestSignupInvite(TestCase): - def test_scim_invite(self): + def test_scim_invite(self) -> None: invite = Invite( invite_type=InviteType.SCIM, invite_reference=SCIMReference(data_owner="test_data_owner", scim_id=uuid4()), diff --git a/src/eduid/userdb/tests/test_signup_user.py b/src/eduid/userdb/tests/test_signup_user.py index 322a93efd..3483a243b 100644 --- a/src/eduid/userdb/tests/test_signup_user.py +++ b/src/eduid/userdb/tests/test_signup_user.py @@ -7,23 +7,23 @@ class TestSignupUser(TestCase): - def setUp(self): + def setUp(self) -> None: self.user = UserFixtures().new_signup_user_example self.user_data = self.user.to_dict() - def test_proper_user(self): + def test_proper_user(self) -> None: self.assertEqual(self.user.user_id, self.user_data["_id"]) self.assertEqual(self.user.eppn, self.user_data["eduPersonPrincipalName"]) - def test_proper_new_user(self): + def test_proper_new_user(self) -> None: user = SignupUser(user_id=self.user.user_id, eppn=self.user.eppn) self.assertEqual(user.user_id, self.user.user_id) self.assertEqual(user.eppn, self.user.eppn) - def test_missing_id(self): + def test_missing_id(self) -> None: user = SignupUser(eppn=self.user.eppn) self.assertNotEqual(user.user_id, self.user.user_id) - def test_missing_eppn(self): + def test_missing_eppn(self) -> None: with self.assertRaises(ValidationError): - SignupUser(user_id=self.user.user_id) + SignupUser(user_id=self.user.user_id) # type: ignore[call-arg] diff --git a/src/eduid/userdb/tests/test_support_models.py b/src/eduid/userdb/tests/test_support_models.py index 7bfeae679..1b6b21062 100644 --- a/src/eduid/userdb/tests/test_support_models.py +++ b/src/eduid/userdb/tests/test_support_models.py @@ -5,24 +5,24 @@ class TestSupportUsers(TestCase): - def setUp(self): + def setUp(self) -> None: self.users = UserFixtures() - def test_support_user(self): + def test_support_user(self) -> None: user = models.SupportUserFilter(self.users.new_user_example.to_dict()) self.assertNotIn("_id", user) self.assertNotIn("letter_proofing_data", user) for password in user["passwords"]: self.assertNotIn("salt", password) - def test_support_signup_user(self): + def test_support_signup_user(self) -> None: user = models.SupportSignupUserFilter(self.users.new_signup_user_example.to_dict()) self.assertNotIn("_id", user) self.assertNotIn("letter_proofing_data", user) for password in user["passwords"]: self.assertNotIn("salt", password) - def test_support_completed_signup_user(self): + def test_support_completed_signup_user(self) -> None: user = models.SupportSignupUserFilter(self.users.new_completed_signup_user_example.to_dict()) self.assertNotIn("_id", user) self.assertNotIn("letter_proofing_data", user) @@ -32,12 +32,14 @@ def test_support_completed_signup_user(self): The assertion is here only for good measure to make sure that the right example data is being used. """ - self.assertTrue(len(user.get("pending_mail_address")) == 0) + pending = user.get("pending_mail_address") + assert pending is not None + self.assertTrue(len(pending) == 0) for password in user["passwords"]: self.assertNotIn("salt", password) - def test_support_user_authn_info(self): + def test_support_user_authn_info(self) -> None: raw_data = { "_id": "5c5b027c20d6b6000db13187", "fail_count": {"201902": 1, "201903": 0}, diff --git a/src/eduid/userdb/tests/test_tou.py b/src/eduid/userdb/tests/test_tou.py index b70a34292..428738135 100644 --- a/src/eduid/userdb/tests/test_tou.py +++ b/src/eduid/userdb/tests/test_tou.py @@ -8,6 +8,7 @@ from eduid.userdb.actions.tou import ToUUser from eduid.userdb.credentials import CredentialList +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.event import Event, EventList from eduid.userdb.exceptions import UserMissingData from eduid.userdb.fixtures.users import UserFixtures @@ -46,20 +47,20 @@ class TestToUEvent(TestCase): - def setUp(self): - self.empty = EventList() + def setUp(self) -> None: + self.empty: EventList = EventList() self.one = ToUList.from_list_of_dicts([_one_dict]) self.two = ToUList.from_list_of_dicts([_one_dict, _two_dict]) self.three = ToUList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by ElementList) works for the ToUEvent. """ event = self.two.to_list()[0] self.assertEqual(event.key, event.event_id) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -68,22 +69,24 @@ def test_parse_cycle(self): new_list = ToUList.from_list_of_dicts(this_dict) assert new_list.to_list_of_dicts() == this.to_list_of_dicts() - def test_created_by(self): + def test_created_by(self) -> None: this = Event.from_dict(dict(created_by=None, event_type="test_event")) this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_event_type(self): + def test_event_type(self) -> None: this = self.one.to_list()[0] self.assertEqual(this.event_type, "tou_event") - def test_reaccept_tou(self): + def test_reaccept_tou(self) -> None: three_years = timedelta(days=3 * 365) one_day = timedelta(days=1) # set modified_ts to both sides of three years ago _two_dict["modified_ts"] = utc_now() - three_years + one_day _three_dict["modified_ts"] = utc_now() - three_years - one_day + assert isinstance(_two_dict["modified_ts"], datetime) assert _two_dict["modified_ts"] + three_years > utc_now() + assert isinstance(_three_dict["modified_ts"], datetime) assert _three_dict["modified_ts"] + three_years < utc_now() # check if the TOU needs to be accepted with an interval of three years @@ -95,16 +98,16 @@ def test_reaccept_tou(self): class TestTouUser(TestCase): user: User - def setUp(self): + def setUp(self) -> None: self.user = UserFixtures().new_user_example - def test_proper_user(self): + def test_proper_user(self) -> None: userdata = self.user.to_dict() userdata["tou"] = [copy.deepcopy(_one_dict)] user = ToUUser.from_dict(data=userdata) self.assertEqual(user.tou.to_list_of_dicts()[0]["version"], "1") - def test_proper_new_user(self): + def test_proper_new_user(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList.from_list_of_dicts([one]) userdata = self.user.to_dict() @@ -114,31 +117,31 @@ def test_proper_new_user(self): user = ToUUser(user_id=userid, eppn=eppn, tou=tou, credentials=passwords) self.assertEqual(user.tou.to_list_of_dicts()[0]["version"], "1") - def test_proper_new_user_no_id(self): + def test_proper_new_user_no_id(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList(elements=[ToUEvent.from_dict(one)]) userdata = self.user.to_dict() passwords = CredentialList.from_list_of_dicts(userdata["passwords"]) with self.assertRaises(ValidationError): - ToUUser(tou=tou, credentials=passwords) + ToUUser(tou=tou, credentials=passwords) # type: ignore[call-arg] - def test_proper_new_user_no_eppn(self): + def test_proper_new_user_no_eppn(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList.from_list_of_dicts([one]) userdata = self.user.to_dict() userid = userdata.pop("_id") passwords = CredentialList.from_list_of_dicts(userdata["passwords"]) with self.assertRaises(ValidationError): - ToUUser(user_id=userid, tou=tou, credentials=passwords) + ToUUser(user_id=userid, tou=tou, credentials=passwords) # type: ignore[call-arg] - def test_missing_eppn(self): + def test_missing_eppn(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList.from_list_of_dicts([one]) with self.assertRaises(UserMissingData): - ToUUser.from_dict(data=dict(tou=tou, userid=self.user.user_id)) + ToUUser.from_dict(data=TUserDbDocument({"tou": tou, "userid": self.user.user_id})) - def test_missing_userid(self): + def test_missing_userid(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUEvent.from_dict(one) with self.assertRaises(UserMissingData): - ToUUser.from_dict(data=dict(tou=[tou], eppn=self.user.eppn)) + ToUUser.from_dict(data=TUserDbDocument({"tou": [tou], "eppn": self.user.eppn})) diff --git a/src/eduid/userdb/tests/test_u2f.py b/src/eduid/userdb/tests/test_u2f.py index 084698888..660c0985d 100644 --- a/src/eduid/userdb/tests/test_u2f.py +++ b/src/eduid/userdb/tests/test_u2f.py @@ -38,22 +38,24 @@ } -def _keyid(key): +def _keyid(key: dict[str, str]) -> str: return "sha256:" + sha256(key["keyhandle"].encode("utf-8") + key["public_key"].encode("utf-8")).hexdigest() class TestU2F(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) self.two = CredentialList.from_list_of_dicts([_one_dict, _two_dict]) self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by CredentialList) works for the credential. """ this = self.one.find(_keyid(_one_dict)) + assert this + assert isinstance(this, U2F) self.assertEqual( this.key, _keyid( @@ -64,7 +66,7 @@ def test_key(self): ), ) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -72,7 +74,7 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(CredentialList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_unknown_input_data(self): + def test_unknown_input_data(self) -> None: one = copy.deepcopy(_one_dict) one["foo"] = "bar" with pytest.raises(ValidationError) as exc_info: @@ -86,17 +88,20 @@ def test_unknown_input_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_proofing_method(self): + def test_proofing_method(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_method = CredentialProofingMethod.SWAMID_AL2_MFA_HI self.assertEqual(this.proofing_method, CredentialProofingMethod.SWAMID_AL2_MFA_HI) this.proofing_method = CredentialProofingMethod.SWAMID_AL3_MFA @@ -104,8 +109,9 @@ def test_proofing_method(self): this.proofing_method = None self.assertEqual(this.proofing_method, None) - def test_proofing_version(self): + def test_proofing_version(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_version = "TEST" self.assertEqual(this.proofing_version, "TEST") this.proofing_version = "TEST2" @@ -113,10 +119,12 @@ def test_proofing_version(self): this.proofing_version = None self.assertEqual(this.proofing_version, None) - def test_swamid_al2_hi_to_swamid_al3_migration(self): + def test_swamid_al2_hi_to_swamid_al3_migration(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_method = CredentialProofingMethod.SWAMID_AL2_MFA_HI this.is_verified = True load_save_cred_list = CredentialList.from_list_of_dicts([this.to_dict()]) load_save_cred = load_save_cred_list.find(_keyid(_three_dict)) + assert load_save_cred self.assertEqual(load_save_cred.proofing_method, CredentialProofingMethod.SWAMID_AL3_MFA) diff --git a/src/eduid/userdb/tests/test_user.py b/src/eduid/userdb/tests/test_user.py index 975a9b36d..898fed235 100644 --- a/src/eduid/userdb/tests/test_user.py +++ b/src/eduid/userdb/tests/test_user.py @@ -8,6 +8,7 @@ from eduid.userdb import NinIdentity, OidcAuthorization, OidcIdToken, Orcid from eduid.userdb.credentials import U2F, CredentialList, CredentialProofingMethod, Password +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.exceptions import EduIDUserDBError, UserHasNotCompletedSignup, UserIsRevoked from eduid.userdb.fixtures.identity import verified_nin_identity from eduid.userdb.fixtures.users import UserFixtures @@ -22,106 +23,110 @@ __author__ = "ft" -def _keyid(kh): +def _keyid(kh: str) -> str: return "sha256:" + sha256(kh.encode("utf-8")).hexdigest() class TestNewUser(unittest.TestCase): - def setUp(self): - self.data1 = { - "_id": ObjectId("547357c3d00690878ae9b620"), - "eduPersonPrincipalName": "guvat-nalif", - "givenName": "User", - "chosen_given_name": "User", - "legal_name": "User One", - "mail": "user@example.net", - "mailAliases": [ - { - "added_timestamp": datetime.fromisoformat("2014-12-18T11:25:19.804000"), - "email": "user@example.net", - "verified": True, - "primary": True, - } - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-11-24T16:22:49.188000"), - "credential_id": "54735b588a7d2a2c4ec3e7d0", - "salt": "$NDNv1H1$315d7$32$32$", - "created_by": "dashboard", - "is_generated": False, - } - ], - "identities": [verified_nin_identity.to_dict()], - "subject": "physical person", - "surname": "One", - "eduPersonEntitlement": ["http://foo.example.org"], - "preferredLanguage": "en", - } + def setUp(self) -> None: + self.data1 = TUserDbDocument( + { + "_id": ObjectId("547357c3d00690878ae9b620"), + "eduPersonPrincipalName": "guvat-nalif", + "givenName": "User", + "chosen_given_name": "User", + "legal_name": "User One", + "mail": "user@example.net", + "mailAliases": [ + { + "added_timestamp": datetime.fromisoformat("2014-12-18T11:25:19.804000"), + "email": "user@example.net", + "verified": True, + "primary": True, + } + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-11-24T16:22:49.188000"), + "credential_id": "54735b588a7d2a2c4ec3e7d0", + "salt": "$NDNv1H1$315d7$32$32$", + "created_by": "dashboard", + "is_generated": False, + } + ], + "identities": [verified_nin_identity.to_dict()], + "subject": "physical person", + "surname": "One", + "eduPersonEntitlement": ["http://foo.example.org"], + "preferredLanguage": "en", + } + ) - self.data2 = { - "_id": ObjectId("549190b5d00690878ae9b622"), - "displayName": "Some \xf6ne", - "eduPersonPrincipalName": "birub-gagoz", - "givenName": "Some", - "mail": "some.one@gmail.com", - "mailAliases": [ - {"email": "someone+test1@gmail.com", "verified": True}, - { - "added_timestamp": datetime.fromisoformat("2014-12-17T14:35:14.728000"), - "email": "some.one@gmail.com", - "verified": True, - }, - ], - "phone": [ - { - "created_ts": datetime.fromisoformat("2014-12-18T09:11:35.078000"), - "number": "+46702222222", - "primary": True, - "verified": True, - } - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2015-02-11T13:58:42.327000"), - "id": ObjectId("54db60128a7d2a26e8690cda"), - "salt": "$NDNv1H1$db011fc$32$32$", - "is_generated": False, - "source": "dashboard", - }, - { - "version": "U2F_V2", - "app_id": "unit test", - "keyhandle": "U2F SWAMID AL3", - "public_key": "foo", - "verified": True, - "proofing_method": CredentialProofingMethod.SWAMID_AL3_MFA, - "proofing_version": "testing", - }, - ], - "profiles": [ - { - "created_by": "test application", - "created_ts": datetime.fromisoformat("2020-02-04T17:42:33.696751"), - "owner": "test owner 1", - "schema": "test schema", - "profile_data": { - "a_string": "I am a string", - "an_int": 3, - "a_list": ["eins", 2, "drei"], - "a_map": {"some": "data"}, + self.data2 = TUserDbDocument( + { + "_id": ObjectId("549190b5d00690878ae9b622"), + "displayName": "Some \xf6ne", + "eduPersonPrincipalName": "birub-gagoz", + "givenName": "Some", + "mail": "some.one@gmail.com", + "mailAliases": [ + {"email": "someone+test1@gmail.com", "verified": True}, + { + "added_timestamp": datetime.fromisoformat("2014-12-17T14:35:14.728000"), + "email": "some.one@gmail.com", + "verified": True, }, - } - ], - "preferredLanguage": "sv", - "surname": "\xf6ne", - "subject": "physical person", - } + ], + "phone": [ + { + "created_ts": datetime.fromisoformat("2014-12-18T09:11:35.078000"), + "number": "+46702222222", + "primary": True, + "verified": True, + } + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2015-02-11T13:58:42.327000"), + "id": ObjectId("54db60128a7d2a26e8690cda"), + "salt": "$NDNv1H1$db011fc$32$32$", + "is_generated": False, + "source": "dashboard", + }, + { + "version": "U2F_V2", + "app_id": "unit test", + "keyhandle": "U2F SWAMID AL3", + "public_key": "foo", + "verified": True, + "proofing_method": CredentialProofingMethod.SWAMID_AL3_MFA, + "proofing_version": "testing", + }, + ], + "profiles": [ + { + "created_by": "test application", + "created_ts": datetime.fromisoformat("2020-02-04T17:42:33.696751"), + "owner": "test owner 1", + "schema": "test schema", + "profile_data": { + "a_string": "I am a string", + "an_int": 3, + "a_list": ["eins", 2, "drei"], + "a_map": {"some": "data"}, + }, + } + ], + "preferredLanguage": "sv", + "surname": "\xf6ne", + "subject": "physical person", + } + ) self._setup_user1() self._setup_user2() - def _setup_user1(self): + def _setup_user1(self) -> None: mailAliases_list = [ MailAddress( created_ts=datetime.fromisoformat("2014-12-18T11:25:19.804000"), @@ -163,7 +168,7 @@ def _setup_user1(self): language="en", ) - def _setup_user2(self): + def _setup_user2(self) -> None: mailAliases_list = [ MailAddress(email="someone+test1@gmail.com", is_verified=True), MailAddress( @@ -225,28 +230,29 @@ def _setup_user2(self): subject=SubjectType("physical person"), ) - def test_user_id(self): + def test_user_id(self) -> None: self.assertEqual(self.user1.user_id, self.data1["_id"]) - def test_eppn(self): + def test_eppn(self) -> None: self.assertEqual(self.user1.eppn, self.data1["eduPersonPrincipalName"]) - def test_given_name(self): + def test_given_name(self) -> None: self.assertEqual(self.user2.given_name, self.data2["givenName"]) - def test_chosen_given_name(self): + def test_chosen_given_name(self) -> None: self.assertEqual(self.user1.chosen_given_name, self.data1["chosen_given_name"]) - def test_surname(self): + def test_surname(self) -> None: self.assertEqual(self.user2.surname, self.data2["surname"]) - def test_legal_name(self): + def test_legal_name(self) -> None: self.assertEqual(self.user1.legal_name, self.data1["legal_name"]) - def test_mail_addresses(self): + def test_mail_addresses(self) -> None: + assert self.user1.mail_addresses.primary is not None self.assertEqual(self.user1.mail_addresses.primary.email, self.data1["mailAliases"][0]["email"]) - def test_passwords(self): + def test_passwords(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -260,7 +266,7 @@ def test_passwords(self): assert obtained == expected - def test_unknown_attributes(self): + def test_unknown_attributes(self) -> None: """ Test parsing a document with unknown data in it. """ @@ -270,16 +276,18 @@ def test_unknown_attributes(self): with self.assertRaises(ValidationError): User.from_dict(data) - def test_incomplete_signup_user(self): + def test_incomplete_signup_user(self) -> None: """ Test parsing the incomplete documents left in the central userdb by older Signup application. """ - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "vohon-mufus", - "mail": "olle@example.org", - "mailAliases": [{"email": "olle@example.org", "verified": False}], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "vohon-mufus", + "mail": "olle@example.org", + "mailAliases": [{"email": "olle@example.org", "verified": False}], + } + ) with self.assertRaises(UserHasNotCompletedSignup): User.from_dict(data) data["subject"] = "physical person" # later signup added this attribute @@ -304,108 +312,122 @@ def test_incomplete_signup_user(self): assert obtained == expected - def test_revoked_user(self): + def test_revoked_user(self) -> None: """ Test ability to identify revoked users. """ - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "binib-mufus", - "revoked_ts": datetime.fromisoformat("2015-05-26T08:33:56.826000"), - "passwords": [], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "binib-mufus", + "revoked_ts": datetime.fromisoformat("2015-05-26T08:33:56.826000"), + "passwords": [], + } + ) with self.assertRaises(UserIsRevoked): User.from_dict(data) - def test_user_with_no_primary_mail(self): + def test_user_with_no_primary_mail(self) -> None: mail = "yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "lutol-bafim", - "mailAliases": [{"email": mail, "verified": True}], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "lutol-bafim", + "mailAliases": [{"email": mail, "verified": True}], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(mail, user.mail_addresses.primary.email) - def test_user_with_indirectly_verified_primary_mail(self): + def test_user_with_indirectly_verified_primary_mail(self) -> None: """ If a user has passwords set, the 'mail' attribute will be considered indirectly verified. """ mail = "yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "lutol-bafim", - "mail": mail, - "mailAliases": [{"email": mail, "verified": False}], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "lutol-bafim", + "mail": mail, + "mailAliases": [{"email": mail, "verified": False}], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(mail, user.mail_addresses.primary.email) - def test_user_with_indirectly_verified_primary_mail_and_explicit_primary_mail(self): + def test_user_with_indirectly_verified_primary_mail_and_explicit_primary_mail(self) -> None: """ If a user has manage to verify a mail address in the new style with the same address still set in old style mail property. Do not make old mail address primary if a primary all ready exists. """ old_mail = "yahoo@example.com" new_mail = "not_yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "lutol-bafim", - "mail": old_mail, - "mailAliases": [ - {"email": old_mail, "verified": True, "primary": False}, - {"email": new_mail, "verified": True, "primary": True}, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "lutol-bafim", + "mail": old_mail, + "mailAliases": [ + {"email": old_mail, "verified": True, "primary": False}, + {"email": new_mail, "verified": True, "primary": True}, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(new_mail, user.mail_addresses.primary.email) - def test_user_with_csrf_junk_in_mail_address(self): + def test_user_with_csrf_junk_in_mail_address(self) -> None: """ For a long time, Dashboard leaked CSRF tokens into the mail address dicts. """ mail = "yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "test-test", - "mailAliases": [{"email": mail, "verified": True, "csrf": "6ae1d4e95305b72318a683883e70e3b8e302cd75"}], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "test-test", + "mailAliases": [{"email": mail, "verified": True, "csrf": "6ae1d4e95305b72318a683883e70e3b8e302cd75"}], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(mail, user.mail_addresses.primary.email) - def test_to_dict(self): + def test_to_dict(self) -> None: """ Test that User objects can be recreated. """ @@ -414,7 +436,7 @@ def test_to_dict(self): d2 = u2.to_dict() self.assertEqual(d1, d2) - def test_modified_ts(self): + def test_modified_ts(self) -> None: """ Test the modified_ts property. """ @@ -428,148 +450,158 @@ def test_modified_ts(self): self.user1.modified_ts = datetime.utcnow() self.assertNotEqual(_time2, self.user1.modified_ts) - def test_two_unverified_non_primary_phones(self): + def test_two_unverified_non_primary_phones(self) -> None: """ Test that the first entry in the `phone' list is chosen as primary when none are verified. """ number1 = "+9112345678" number2 = "+9123456789" - data = { - "_id": ObjectId(), - "displayName": "xxx yyy", - "eduPersonPrincipalName": "pohig-test", - "givenName": "xxx", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number1, - "primary": False, - "verified": False, - }, - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number2, - "primary": False, - "verified": False, - }, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "credential_id": str(ObjectId()), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - "preferredLanguage": "en", - "surname": "yyy", - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "displayName": "xxx yyy", + "eduPersonPrincipalName": "pohig-test", + "givenName": "xxx", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number1, + "primary": False, + "verified": False, + }, + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number2, + "primary": False, + "verified": False, + }, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "credential_id": str(ObjectId()), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + "preferredLanguage": "en", + "surname": "yyy", + } + ) user = User.from_dict(data) self.assertEqual(user.phone_numbers.primary, None) - def test_two_non_primary_phones(self): + def test_two_non_primary_phones(self) -> None: """ Test that the first verified number is chosen as primary, if there is a verified number. """ number1 = "+9112345678" number2 = "+9123456789" - data = { - "_id": ObjectId(), - "displayName": "xxx yyy", - "eduPersonPrincipalName": "pohig-test", - "givenName": "xxx", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number1, - "primary": False, - "verified": False, - }, - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number2, - "primary": False, - "verified": True, - }, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "credential_id": str(ObjectId()), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - "preferredLanguage": "en", - "surname": "yyy", - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "displayName": "xxx yyy", + "eduPersonPrincipalName": "pohig-test", + "givenName": "xxx", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number1, + "primary": False, + "verified": False, + }, + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number2, + "primary": False, + "verified": True, + }, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "credential_id": str(ObjectId()), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + "preferredLanguage": "en", + "surname": "yyy", + } + ) user = User.from_dict(data) + assert user.phone_numbers.primary self.assertEqual(user.phone_numbers.primary.number, number2) - def test_primary_non_verified_phone(self): + def test_primary_non_verified_phone(self) -> None: """ Test that if a non verified phone number is primary, due to earlier error, then that primary flag is removed. """ - data = { - "_id": ObjectId(), - "displayName": "xxx yyy", - "eduPersonPrincipalName": "pohig-test", - "givenName": "xxx", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": "+9112345678", - "primary": True, - "verified": False, - } - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "credential_id": str(ObjectId()), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - "preferredLanguage": "en", - "surname": "yyy", - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "displayName": "xxx yyy", + "eduPersonPrincipalName": "pohig-test", + "givenName": "xxx", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": "+9112345678", + "primary": True, + "verified": False, + } + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "credential_id": str(ObjectId()), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + "preferredLanguage": "en", + "surname": "yyy", + } + ) user = User.from_dict(data) for number in user.phone_numbers.to_list(): self.assertEqual(number.is_primary, False) - def test_primary_non_verified_phone2(self): + def test_primary_non_verified_phone2(self) -> None: """ Test that if a non verified phone number is primary, due to earlier error, then that primary flag is removed. """ - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "pohig-test", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - {"number": "+11111111111", "primary": True, "verified": False}, - {"number": "+22222222222", "primary": False, "verified": True}, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "id": ObjectId(), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "pohig-test", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + {"number": "+11111111111", "primary": True, "verified": False}, + {"number": "+22222222222", "primary": False, "verified": True}, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "id": ObjectId(), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.phone_numbers.primary self.assertEqual(user.phone_numbers.primary.number, "+22222222222") - def test_user_tou_no_created_ts(self): + def test_user_tou_no_created_ts(self) -> None: """ Basic test for user ToU. """ @@ -587,7 +619,7 @@ def test_user_tou_no_created_ts(self): # attr set to True, and therefore the to_dict method will wipe out the created_ts key self.assertFalse(user.tou.has_accepted("1", reaccept_interval=94608000)) # reaccept_interval seconds (3 years) - def test_user_tou(self): + def test_user_tou(self) -> None: """ Basic test for user ToU. """ @@ -605,7 +637,7 @@ def test_user_tou(self): self.assertTrue(user.tou.has_accepted("1", reaccept_interval=94608000)) # reaccept_interval seconds (3 years) self.assertFalse(user.tou.has_accepted("2", reaccept_interval=94608000)) # reaccept_interval seconds (3 years) - def test_locked_identity_load(self): + def test_locked_identity_load(self) -> None: created_ts = datetime.fromisoformat("2013-09-02T10:23:25") locked_identity = { "created_by": "test", @@ -624,7 +656,7 @@ def test_locked_identity_load(self): assert user.locked_identity.nin.number == "197801012345" assert user.locked_identity.nin.is_verified is True - def test_locked_identity_load_legacy_format(self): + def test_locked_identity_load_legacy_format(self) -> None: created_ts = datetime.fromisoformat("2013-09-02T10:23:25") locked_identity = { "created_by": "test", @@ -642,7 +674,7 @@ def test_locked_identity_load_legacy_format(self): assert user.locked_identity.nin.number == "197801012345" assert user.locked_identity.nin.is_verified is True - def test_locked_identity_set(self): + def test_locked_identity_set(self) -> None: user = User.from_dict(self.data1) locked_nin = NinIdentity( number="197801012345", @@ -658,14 +690,14 @@ def test_locked_identity_set(self): assert user.locked_identity.nin.number == "197801012345" assert user.locked_identity.nin.is_verified is True - def test_locked_identity_set_not_verified(self): + def test_locked_identity_set_not_verified(self) -> None: locked_identity = {"created_by": "test", "identity_type": IdentityType.NIN.value, "number": "197801012345"} user = User.from_dict(self.data1) locked_nin = NinIdentity(number=locked_identity["number"], created_by=locked_identity["created_by"]) with pytest.raises(ValidationError): user.locked_identity.add(locked_nin) - def test_locked_identity_to_dict(self): + def test_locked_identity_to_dict(self) -> None: user = User.from_dict(self.data1) locked_nin = NinIdentity( number="197801012345", @@ -688,7 +720,7 @@ def test_locked_identity_to_dict(self): assert new_user.locked_identity.nin.number == "197801012345" assert new_user.locked_identity.nin.is_verified is True - def test_locked_identity_remove(self): + def test_locked_identity_remove(self) -> None: user = User.from_dict(self.data1) locked_nin = NinIdentity( number="197801012345", @@ -699,7 +731,7 @@ def test_locked_identity_remove(self): with self.assertRaises(EduIDUserDBError): user.locked_identity.remove(locked_nin.key) - def test_orcid(self): + def test_orcid(self) -> None: id_token = { "aud": ["APP_ID"], "auth_time": 1526389879, @@ -727,7 +759,8 @@ def test_orcid(self): user.orcid = orcid_element old_user = User.from_dict(user.to_dict()) - self.assertIsNotNone(old_user.orcid) + assert old_user + assert old_user.orcid self.assertIsInstance(old_user.orcid.created_by, str) self.assertIsInstance(old_user.orcid.created_ts, datetime) self.assertIsInstance(old_user.orcid.id, str) @@ -735,25 +768,26 @@ def test_orcid(self): self.assertIsInstance(old_user.orcid.oidc_authz.id_token, OidcIdToken) new_user = User.from_dict(user.to_dict()) - self.assertIsNotNone(new_user.orcid) + assert new_user + assert new_user.orcid self.assertIsInstance(new_user.orcid.created_by, str) self.assertIsInstance(new_user.orcid.created_ts, datetime) self.assertIsInstance(new_user.orcid.id, str) self.assertIsInstance(new_user.orcid.oidc_authz, OidcAuthorization) self.assertIsInstance(new_user.orcid.oidc_authz.id_token, OidcIdToken) - def test_profiles(self): + def test_profiles(self) -> None: self.assertIsNotNone(self.user1.profiles) self.assertEqual(self.user1.profiles.count, 0) self.assertIsNotNone(self.user2.profiles) self.assertEqual(self.user2.profiles.count, 1) - def test_user_verified_credentials(self): + def test_user_verified_credentials(self) -> None: ver = [x for x in self.user2.credentials.to_list() if x.is_verified] keys = [x.key for x in ver] self.assertEqual(keys, [_keyid("U2F SWAMID AL3" + "foo")]) - def test_user_unverified_credential(self): + def test_user_unverified_credential(self) -> None: cred = [x for x in self.user2.credentials.to_list() if x.is_verified][0] self.assertEqual(cred.proofing_method, CredentialProofingMethod.SWAMID_AL3_MFA) _dict1 = cred.to_dict() @@ -766,63 +800,67 @@ def test_user_unverified_credential(self): self.assertFalse("proofing_method" in _dict2) self.assertFalse("proofing_version" in _dict2) - def test_both_mobile_and_phone(self): + def test_both_mobile_and_phone(self) -> None: """Test user that has both 'mobile' and 'phone'""" phone = [ {"number": "+4673123", "primary": True, "verified": True}, {"created_by": "phone", "number": "+4670999", "primary": False, "verified": False}, ] user = User.from_dict( - data={ - "_id": ObjectId(), - "eduPersonPrincipalName": "test-test", - "passwords": [], - "mobile": [{"mobile": "+4673123", "primary": True, "verified": True}], - "phone": phone, - } + data=TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "test-test", + "passwords": [], + "mobile": [{"mobile": "+4673123", "primary": True, "verified": True}], + "phone": phone, + } + ) ) out = user.to_dict()["phone"] assert phone == out, "The phone objects differ when using both phone and mobile" - def test_both_sn_and_surname(self): + def test_both_sn_and_surname(self) -> None: """Test user that has both 'sn' and 'surname'""" user = User.from_dict( - data={ - "_id": ObjectId(), - "eduPersonPrincipalName": "test-test", - "passwords": [], - "surname": "Right", - "sn": "Wrong", - } + data=TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "test-test", + "passwords": [], + "surname": "Right", + "sn": "Wrong", + } + ) ) self.assertEqual("Right", user.to_dict()["surname"]) - def test_terminated_user(self): + def test_terminated_user(self) -> None: data = self.user1.to_dict() data["terminated"] = utc_now() user = User.from_dict(data) assert user.terminated is not None assert isinstance(user.terminated, datetime) is True - def test_terminated_user_false(self): + def test_terminated_user_false(self) -> None: # users can have terminated set to False due to a bug in the past data = self.user1.to_dict() data["terminated"] = False user = User.from_dict(data) assert user.terminated is None - def test_rebuild_user1(self): + def test_rebuild_user1(self) -> None: data = self.user1.to_dict() new_user1 = User.from_dict(data) self.assertEqual(new_user1.eppn, "guvat-nalif") - def test_rebuild_user2(self): + def test_rebuild_user2(self) -> None: data = self.user2.to_dict() new_user2 = User.from_dict(data) self.assertEqual(new_user2.eppn, "birub-gagoz") - def test_mail_addresses_from_dict(self): + def test_mail_addresses_from_dict(self) -> None: """ Test that we get back a correct list of dicts for old-style userdb data. """ @@ -851,7 +889,7 @@ def test_mail_addresses_from_dict(self): assert to_dict_output == mailAliases_list - def test_phone_numbers_from_dict(self): + def test_phone_numbers_from_dict(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -867,7 +905,7 @@ def test_phone_numbers_from_dict(self): to_dict_result = phone_numbers.to_list_of_dicts() assert to_dict_result == phone_list - def test_passwords_from_dict(self): + def test_passwords_from_dict(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -902,7 +940,7 @@ def test_passwords_from_dict(self): assert to_dict_result == expected - def test_phone_numbers(self): + def test_phone_numbers(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -918,7 +956,7 @@ def test_phone_numbers(self): assert obtained == expected - def test_user_meta(self): + def test_user_meta(self) -> None: version = ObjectId() _utc_now = utc_now() user_dict = self.user1.to_dict() @@ -938,21 +976,21 @@ def test_user_meta(self): } assert user_dict2["meta"] == expected - def test_user_meta_version(self): + def test_user_meta_version(self) -> None: assert self.user1.meta.is_in_database is False assert self.user1.meta.version is None self.user1.meta.new_version() assert self.user1.meta.is_in_database is False assert isinstance(self.user1.meta.version, ObjectId) is True - def test_user_meta_modified_ts(self): + def test_user_meta_modified_ts(self) -> None: assert self.user1.meta.modified_ts is None # TODO: remove below check when removing user.modified_ts # verify that user.modified_ts is synced with meta.modified_ts self.user1.modified_ts = utc_now() assert self.user1.meta.modified_ts == self.user1.modified_ts - def test_letter_proofing_data_to_list(self): + def test_letter_proofing_data_to_list(self) -> None: letter_proofing = { "created_by": "eduid-idproofing-letter", "created_ts": datetime(2015, 12, 18, 12, 0, 46), @@ -981,14 +1019,14 @@ def test_letter_proofing_data_to_list(self): user = User.from_dict(user_dict) assert user.to_dict()["letter_proofing_data"] == [letter_proofing] - def test_nins_and_identities_on_user(self): + def test_nins_and_identities_on_user(self) -> None: user_dict = UserFixtures().mocked_user_standard.to_dict() assert user_dict["identities"] != [] user_dict = User.from_dict(user_dict).to_dict() assert user_dict.get("nins") is None assert user_dict.get("identities") is not None - def test_empty_nins_list(self): + def test_empty_nins_list(self) -> None: user_dict = UserFixtures().mocked_user_standard.to_dict() del user_dict["identities"] user_dict["nins"] = [] diff --git a/src/eduid/userdb/tests/test_userdb.py b/src/eduid/userdb/tests/test_userdb.py index cc8448672..fd7ca4665 100644 --- a/src/eduid/userdb/tests/test_userdb.py +++ b/src/eduid/userdb/tests/test_userdb.py @@ -5,26 +5,32 @@ from eduid.common.testing_base import normalised_data from eduid.userdb import User +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.exceptions import UserDoesNotExist, UserOutOfSync from eduid.userdb.fixtures.passwords import signup_password from eduid.userdb.fixtures.users import UserFixtures -from eduid.userdb.testing import MongoTestCase +from eduid.userdb.testing import MongoTestCase, SetupConfig from eduid.userdb.util import format_dict_for_debug logger = logging.getLogger(__name__) class TestUserDB(MongoTestCase): - def setUp(self, *args, **kwargs): + def setUp(self, config: SetupConfig | None = None) -> None: self.user = UserFixtures().mocked_user_standard - super().setUp(am_users=[self.user], **kwargs) + if config is None: + config = SetupConfig() + config.am_users = [self.user] + super().setUp(config=config) - def test_get_user_by_id(self): + def test_get_user_by_id(self) -> None: """Test get_user_by_id""" res = self.amdb.get_user_by_id(self.user.user_id) + assert res assert self.user.user_id == res.user_id res = self.amdb.get_user_by_id(str(self.user.user_id)) + assert res assert self.user.user_id == res.user_id res = self.amdb.get_user_by_id(str(bson.ObjectId())) @@ -33,7 +39,7 @@ def test_get_user_by_id(self): res = self.amdb.get_user_by_id("not-a-valid-object-id") assert res is None - def test_get_user_by_nin(self): + def test_get_user_by_nin(self) -> None: """Test get_user_by_nin""" test_user = self.amdb.get_user_by_id(self.user.user_id) assert test_user is not None @@ -44,25 +50,28 @@ def test_get_user_by_nin(self): assert res is not None assert test_user.given_name == res.given_name - def test_remove_user_by_id(self): + def test_remove_user_by_id(self) -> None: """Test removing a user from the database NOTE: remove_user_by_id() should be moved to SignupUserDb """ test_user = self.amdb.get_user_by_id(self.user.user_id) + assert test_user + assert test_user.identities.nin res = self.amdb.get_users_by_nin(test_user.identities.nin.number) assert normalised_data(res[0].to_dict()) == normalised_data(test_user.to_dict()) self.amdb.remove_user_by_id(test_user.user_id) res = self.amdb.get_users_by_nin(test_user.identities.nin.number) assert res == [] - def test_get_user_by_eppn(self): + def test_get_user_by_eppn(self) -> None: """Test user lookup using eppn""" test_user = self.amdb.get_user_by_id(self.user.user_id) + assert test_user res = self.amdb.get_user_by_eppn(test_user.eppn) assert test_user.user_id == res.user_id - def test_get_user_by_eppn_not_found(self): + def test_get_user_by_eppn_not_found(self) -> None: """Test user lookup using unknown""" with pytest.raises(UserDoesNotExist): self.amdb.get_user_by_eppn("abc123") @@ -71,9 +80,12 @@ def test_get_user_by_eppn_not_found(self): class UserMissingMeta(MongoTestCase): user: User - def setUp(self, *args, **kwargs): + def setUp(self, config: SetupConfig | None = None) -> None: self.user = UserFixtures().mocked_user_standard - super().setUp(*args, am_users=[self.user], **kwargs) + if config is None: + config = SetupConfig() + config.am_users = [self.user] + super().setUp(config=config) self._remove_meta_from_user_in_db(self.user) @@ -88,7 +100,7 @@ def _remove_meta_from_user_in_db(self, user: User) -> None: user_doc.pop("meta") self.amdb._coll.replace_one({"_id": user.user_id}, user_doc) - def test_update_user_new(self): + def test_update_user_new(self) -> None: db_user = self.amdb.get_user_by_id(self.user.user_id) assert db_user is not None logger.debug(f"Loaded user.meta from database:\n{format_dict_for_debug(db_user.meta.dict())}") @@ -96,19 +108,23 @@ def test_update_user_new(self): db_user.given_name = "test" self.amdb.save(user=db_user) - def test_update_user_old(self): + def test_update_user_old(self) -> None: db_user = self.amdb.get_user_by_id(self.user.user_id) + assert db_user db_user.given_name = "test" self.amdb.save(user=db_user) class UpdateUser(MongoTestCase): - def setUp(self, *args, **kwargs): + def setUp(self, config: SetupConfig | None = None) -> None: _users = UserFixtures() self.user = _users.mocked_user_standard - super().setUp(am_users=[self.user, _users.mocked_user_standard_2], **kwargs) + if config is None: + config = SetupConfig() + config.am_users = [self.user, _users.mocked_user_standard_2] + super().setUp(config=config) - def test_stale_user_meta_version(self): + def test_stale_user_meta_version(self) -> None: test_user = self.amdb.get_user_by_eppn(self.user.eppn) test_user.given_name = "new_given_name" test_user.meta.new_version() @@ -116,8 +132,9 @@ def test_stale_user_meta_version(self): with self.assertRaises(UserOutOfSync): self.amdb.save(test_user) - def test_ok(self): + def test_ok(self) -> None: test_user = self.amdb.get_user_by_id(self.user.user_id) + assert test_user test_user.given_name = "new_given_name" old_meta_version = test_user.meta.version @@ -127,31 +144,36 @@ def test_ok(self): assert res.success is True db_user = self.amdb.get_user_by_id(test_user.user_id) + assert db_user assert db_user.meta.version != old_meta_version assert db_user.modified_ts != old_modified_ts assert db_user.given_name == "new_given_name" class TestUserDB_mail(MongoTestCase): - def setUp(self, *args, **kwargs): - super().setUp(*args, **kwargs) - data1 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "mail-test1", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "passwords": [signup_password.to_dict()], - } - - data2 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "mail-test2", - "mailAliases": [ - {"email": "test2@gmail.com", "primary": True, "verified": True}, - {"email": "test@gmail.com", "verified": False}, - ], - "passwords": [signup_password.to_dict()], - } + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) + data1: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "mail-test1", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "passwords": [signup_password.to_dict()], + } + ) + + data2: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "mail-test2", + "mailAliases": [ + {"email": "test2@gmail.com", "primary": True, "verified": True}, + {"email": "test@gmail.com", "verified": False}, + ], + "passwords": [signup_password.to_dict()], + } + ) self.user1 = User.from_dict(data1) self.user2 = User.from_dict(data2) @@ -159,16 +181,19 @@ def setUp(self, *args, **kwargs): self.amdb.save(self.user1) self.amdb.save(self.user2) - def test_get_user_by_mail(self): + def test_get_user_by_mail(self) -> None: test_user = self.amdb.get_user_by_id(self.user1.user_id) + assert test_user + assert test_user.mail_addresses.primary res = self.amdb.get_user_by_mail(test_user.mail_addresses.primary.email) + assert res assert test_user.user_id == res.user_id - def test_get_user_by_mail_unknown(self): + def test_get_user_by_mail_unknown(self) -> None: """Test searching for unknown e-mail address""" assert self.amdb.get_user_by_mail("abc123@example.edu") is None - def test_get_user_by_mail_multiple(self): + def test_get_user_by_mail_multiple(self) -> None: res = self.amdb.get_users_by_mail("test@gmail.com") ids = [x.user_id for x in res] assert ids == [self.user1.user_id] @@ -179,53 +204,61 @@ def test_get_user_by_mail_multiple(self): class TestUserDB_phone(MongoTestCase): - def setUp(self, *args, **kwargs): - super().setUp(*args, **kwargs) - data1 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "phone-test1", - "mail": "kalle@example.com", - "phone": [ - {"number": "+11111111111", "primary": True, "verified": True}, - {"number": "+22222222222", "primary": False, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } - data2 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "phone-test2", - "mail": "anka@example.com", - "phone": [ - {"number": "+11111111111", "primary": True, "verified": False}, - {"number": "+22222222222", "primary": False, "verified": False}, - {"number": "+33333333333", "primary": False, "verified": False}, - ], - "passwords": [signup_password.to_dict()], - } + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) + data1: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "phone-test1", + "mail": "kalle@example.com", + "phone": [ + {"number": "+11111111111", "primary": True, "verified": True}, + {"number": "+22222222222", "primary": False, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) + data2: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "phone-test2", + "mail": "anka@example.com", + "phone": [ + {"number": "+11111111111", "primary": True, "verified": False}, + {"number": "+22222222222", "primary": False, "verified": False}, + {"number": "+33333333333", "primary": False, "verified": False}, + ], + "passwords": [signup_password.to_dict()], + } + ) self.user1 = User.from_dict(data1) self.user2 = User.from_dict(data2) self.amdb.save(self.user1) self.amdb.save(self.user2) - def test_get_user_by_phone(self): + def test_get_user_by_phone(self) -> None: test_user = self.amdb.get_user_by_id(self.user1.user_id) + assert test_user + assert test_user.phone_numbers.primary res = self.amdb.get_user_by_phone(test_user.phone_numbers.primary.number) + assert res assert res.user_id == test_user.user_id res = self.amdb.get_user_by_phone("+22222222222") + assert res assert res.user_id == test_user.user_id assert self.amdb.get_user_by_phone("+33333333333") is None - res = self.amdb.get_users_by_phone("+33333333333", include_unconfirmed=True) - assert [x.user_id for x in res] == [self.user2.user_id] + res_list = self.amdb.get_users_by_phone("+33333333333", include_unconfirmed=True) + assert [x.user_id for x in res_list] == [self.user2.user_id] - def test_get_user_by_phone_unknown(self): + def test_get_user_by_phone_unknown(self) -> None: """Test searching for unknown e-phone address""" assert self.amdb.get_user_by_phone("abc123@example.edu") is None - def test_get_user_by_phone_multiple(self): + def test_get_user_by_phone_multiple(self) -> None: res = self.amdb.get_users_by_phone("+11111111111") ids = [x.user_id for x in res] assert ids == [self.user1.user_id] @@ -237,36 +270,41 @@ def test_get_user_by_phone_multiple(self): class TestUserDB_nin(MongoTestCase): # TODO: Keep for a while to make sure the conversion to identities work as expected - def setUp(self, *args, **kwargs): - super().setUp(*args, **kwargs) - data1 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test1", - "mail": "kalle@example.com", - "nins": [ - {"number": "11111111111", "primary": True, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } - data2 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test2", - "mail": "anka@example.com", - "nins": [ - {"number": "22222222222", "primary": True, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } - - data3 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test3", - "mail": "anka@example.com", - "nins": [ - {"number": "33333333333", "primary": False, "verified": False}, - ], - "passwords": [signup_password.to_dict()], - } + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) + data1: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test1", + "mail": "kalle@example.com", + "nins": [ + {"number": "11111111111", "primary": True, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) + data2: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test2", + "mail": "anka@example.com", + "nins": [ + {"number": "22222222222", "primary": True, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) + data3: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test3", + "mail": "anka@example.com", + "nins": [ + {"number": "33333333333", "primary": False, "verified": False}, + ], + "passwords": [signup_password.to_dict()], + } + ) self.user1 = User.from_dict(data1) self.user2 = User.from_dict(data2) @@ -276,37 +314,44 @@ def setUp(self, *args, **kwargs): self.amdb.save(self.user2) self.amdb.save(self.user3) - def test_get_user_by_nin(self): + def test_get_user_by_nin(self) -> None: test_user = self.amdb.get_user_by_id(self.user1.user_id) + assert test_user + assert test_user.identities.nin res = self.amdb.get_user_by_nin(test_user.identities.nin.number) + assert res assert res.user_id == test_user.user_id, "alpha" res = self.amdb.get_user_by_nin("11111111111") + assert res assert res.user_id == test_user.user_id, "beta" res = self.amdb.get_user_by_nin("22222222222") + assert res assert res.user_id == self.user2.user_id, "gamma" assert self.amdb.get_user_by_nin("33333333333") is None, "delta" - res = self.amdb.get_users_by_nin("33333333333", include_unconfirmed=True) - assert [x.user_id for x in res] == [self.user3.user_id], "epsilon" + res_list = self.amdb.get_users_by_nin("33333333333", include_unconfirmed=True) + assert [x.user_id for x in res_list] == [self.user3.user_id], "epsilon" - def test_get_user_by_nin_unknown(self): + def test_get_user_by_nin_unknown(self) -> None: """Test searching for unknown e-nin address""" assert self.amdb.get_user_by_nin("77777777777") is None - def test_get_user_by_nin_multiple(self): + def test_get_user_by_nin_multiple(self) -> None: # create another user with nin 33333333333, this one verified - data4 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test4", - "mail": "anka@example.com", - "nins": [ - {"number": "33333333333", "primary": True, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } + data4 = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test4", + "mail": "anka@example.com", + "nins": [ + {"number": "33333333333", "primary": True, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) user4 = User.from_dict(data4) self.amdb.save(user4) diff --git a/src/eduid/userdb/tests/test_webauthn.py b/src/eduid/userdb/tests/test_webauthn.py index dfb41a39a..facbf1449 100644 --- a/src/eduid/userdb/tests/test_webauthn.py +++ b/src/eduid/userdb/tests/test_webauthn.py @@ -2,7 +2,7 @@ from hashlib import sha256 from unittest import TestCase -from eduid.userdb.credentials import CredentialList, CredentialProofingMethod +from eduid.userdb.credentials import CredentialList, CredentialProofingMethod, Webauthn __author__ = "lundberg" @@ -23,22 +23,24 @@ } -def _keyid(key): +def _keyid(key: dict[str, str]) -> str: return "sha256:" + sha256(key["keyhandle"].encode("utf-8") + key["credential_data"].encode("utf-8")).hexdigest() class TestWebauthn(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) self.two = CredentialList.from_list_of_dicts([_one_dict, _two_dict]) self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by CredentialList) works for the credential. """ this = self.one.find(_keyid(_one_dict)) + assert this + assert isinstance(this, Webauthn) self.assertEqual( this.key, _keyid( @@ -49,7 +51,7 @@ def test_key(self): ), ) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -57,17 +59,20 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(CredentialList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_proofing_method(self): + def test_proofing_method(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_method = CredentialProofingMethod.SWAMID_AL2_MFA_HI self.assertEqual(this.proofing_method, CredentialProofingMethod.SWAMID_AL2_MFA_HI) this.proofing_method = CredentialProofingMethod.SWAMID_AL3_MFA @@ -75,8 +80,9 @@ def test_proofing_method(self): this.proofing_method = None self.assertEqual(this.proofing_method, None) - def test_proofing_version(self): + def test_proofing_version(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_version = "TEST" self.assertEqual(this.proofing_version, "TEST") this.proofing_version = "TEST2" diff --git a/src/eduid/userdb/tou.py b/src/eduid/userdb/tou.py index fd83bd4c9..64d222f23 100644 --- a/src/eduid/userdb/tou.py +++ b/src/eduid/userdb/tou.py @@ -21,7 +21,7 @@ class ToUEvent(Event): @field_validator("version") @classmethod - def _validate_tou_version(cls, v): + def _validate_tou_version(cls, v: object) -> str: if not v: raise ValueError("ToU must have a version") if not isinstance(v, str): diff --git a/src/eduid/userdb/user.py b/src/eduid/userdb/user.py index 7f12fbe96..aee68847d 100644 --- a/src/eduid/userdb/user.py +++ b/src/eduid/userdb/user.py @@ -126,7 +126,7 @@ def __str__(self) -> str: return f"" return f"" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: Any) -> bool: # noqa: ANN401 if self.__class__ is not other.__class__: raise TypeError(f"Trying to compare objects of different class {other.__class__} != {self.__class__}") return self.to_dict() == other.to_dict() diff --git a/src/eduid/userdb/user_cleaner/db.py b/src/eduid/userdb/user_cleaner/db.py index 570c47667..87255dab2 100644 --- a/src/eduid/userdb/user_cleaner/db.py +++ b/src/eduid/userdb/user_cleaner/db.py @@ -21,7 +21,7 @@ class CleanerQueueUser(User): class CleanerQueueDB(UserDB[CleanerQueueUser]): - def __init__(self, db_uri: str, db_name: str = "eduid_user_cleaner", collection: str = "cleaner_queue"): + def __init__(self, db_uri: str, db_name: str = "eduid_user_cleaner", collection: str = "cleaner_queue") -> None: super().__init__(db_uri, db_name, collection) indexes = { diff --git a/src/eduid/userdb/user_cleaner/userdb.py b/src/eduid/userdb/user_cleaner/userdb.py index 97844efe1..337e7b9d8 100644 --- a/src/eduid/userdb/user_cleaner/userdb.py +++ b/src/eduid/userdb/user_cleaner/userdb.py @@ -8,7 +8,7 @@ class CleanerUser(User): class CleanerUserDB(UserDB[CleanerUser]): - def __init__(self, db_uri: str, db_name: str = "eduid_user_cleaner", collection: str = "profiles"): + def __init__(self, db_uri: str, db_name: str = "eduid_user_cleaner", collection: str = "profiles") -> None: super().__init__(db_uri, db_name, collection) @classmethod diff --git a/src/eduid/userdb/userdb.py b/src/eduid/userdb/userdb.py index de5d8fea6..e53654d7c 100644 --- a/src/eduid/userdb/userdb.py +++ b/src/eduid/userdb/userdb.py @@ -43,7 +43,7 @@ class UserDB(BaseDB, Generic[UserVar], ABC): :param collection: mongodb collection name """ - def __init__(self, db_uri: str, db_name: str, collection: str = "userdb"): + def __init__(self, db_uri: str, db_name: str, collection: str = "userdb") -> None: if db_name == "eduid_am" and collection == "userdb": # Hack to get right collection name while the configuration points to the old database collection = "attributes" @@ -53,7 +53,7 @@ def __init__(self, db_uri: str, db_name: str, collection: str = "userdb"): logger.debug(f"{self} connected to database") - def __repr__(self): + def __repr__(self) -> str: return f"" __str__ = __repr__ @@ -182,7 +182,7 @@ def get_users_by_nin(self, nin: str, include_unconfirmed: bool = False) -> list[ def get_users_by_identity( self, identity_type: IdentityType, key: str, value: str, include_unconfirmed: bool = False - ): + ) -> list[UserVar]: match = {"identity_type": identity_type.value, key: value, "verified": True} if include_unconfirmed: del match["verified"] @@ -234,7 +234,7 @@ def get_user_by_eppn(self, eppn: str | None) -> UserVar: raise UserDoesNotExist(f"No user with eppn {repr(eppn)}") return res - def _get_user_by_attr(self, attr: str, value: Any) -> UserVar | None: + def _get_user_by_attr(self, attr: str, value: Any) -> UserVar | None: # noqa: ANN401 """ Locate a user in the userdb using any attribute and value. @@ -328,7 +328,7 @@ def update_user(self, obj_id: ObjectId, operations: Mapping[str, Any]) -> None: class AmDB(UserDB[User]): """Central userdb, aka. AM DB""" - def __init__(self, db_uri: str, db_name: str = "eduid_am"): + def __init__(self, db_uri: str, db_name: str = "eduid_am") -> None: super().__init__(db_uri, db_name) @classmethod diff --git a/src/eduid/userdb/util.py b/src/eduid/userdb/util.py index df29d27bf..cd61d4292 100644 --- a/src/eduid/userdb/util.py +++ b/src/eduid/userdb/util.py @@ -14,13 +14,13 @@ class UTC(datetime.tzinfo): """UTC""" - def utcoffset(self, dt): + def utcoffset(self, dt: datetime.datetime | None) -> datetime.timedelta: return datetime.timedelta(0) - def tzname(self, dt): + def tzname(self, dt: datetime.datetime | None) -> str: return "UTC" - def dst(self, dt): + def dst(self, dt: datetime.datetime | None) -> datetime.timedelta: return datetime.timedelta(0) diff --git a/src/eduid/vccs/client/__init__.py b/src/eduid/vccs/client/__init__.py index e61aa767a..e41ff214d 100644 --- a/src/eduid/vccs/client/__init__.py +++ b/src/eduid/vccs/client/__init__.py @@ -57,7 +57,7 @@ class VCCSClientException(Exception): Base exception class for VCCS client. """ - def __init__(self, reason: str): + def __init__(self, reason: str) -> None: Exception.__init__(self) self.reason = reason @@ -69,11 +69,11 @@ class VCCSClientHTTPError(VCCSClientException): library is used by the VCCS client. """ - def __init__(self, reason: str, http_code: int): + def __init__(self, reason: str, http_code: int) -> None: VCCSClientException.__init__(self, reason) self.http_code = http_code - def __str__(self): + def __str__(self) -> str: return f"<{self.__class__.__name__} instance at {hex(id(self))}: {self.http_code!r} {self.reason!r}>" @@ -82,7 +82,7 @@ class VCCSFactor: Base class for authentication factors. Do not use directly. """ - def __init__(self): + def __init__(self) -> None: pass def to_dict(self, _action: str) -> dict[str, Any]: @@ -101,7 +101,9 @@ class VCCSPasswordFactor(VCCSFactor): Object representing an ordinary password authentication factor. """ - def __init__(self, password: str, credential_id: str, salt: str | None = None, strip_whitespace: bool = True): + def __init__( + self, password: str, credential_id: str, salt: str | None = None, strip_whitespace: bool = True + ) -> None: """ :param password: string, password as plaintext :param credential_id: unique id of credential in the authentication backend database @@ -174,7 +176,7 @@ def _get_random_bytes(self, num_bytes: int) -> bytes: """ return os.urandom(num_bytes) - def to_dict(self, _action): + def to_dict(self, _action: str) -> dict[str, Any]: """ Return factor as dictionary, transmittable to authentiation backends. :param _action: 'auth', 'add_creds' or 'revoke_creds' @@ -193,8 +195,16 @@ class VCCSOathFactor(VCCSFactor): """ def __init__( - self, oath_type, credential_id, user_code=None, nonce=None, aead=None, key_handle=None, digits=6, oath_counter=0 - ): + self, + oath_type: str, + credential_id: int, + user_code: int | None = None, + nonce: str | None = None, + aead: str | None = None, + key_handle: int | None = None, + digits: int = 6, + oath_counter: int = 0, + ) -> None: """ :param oath_type: 'oath-totp' or 'oath-hotp' (time based or event based OATH) :param credential_id: integer, unique index of credential @@ -263,7 +273,7 @@ class VCCSRevokeFactor(VCCSFactor): Object representing a factor to be revoked. """ - def __init__(self, credential_id: str, reason: str, reference: str = ""): + def __init__(self, credential_id: str, reason: str, reference: str = "") -> None: """ :param credential_id: unique index of credential :param reason: reason for revocation @@ -302,7 +312,7 @@ class VCCSClient: credentials (authentication factors). """ - def __init__(self, base_url: str | None = None): + def __init__(self, base_url: str | None = None) -> None: self.base_url = base_url if base_url else "http://localhost:8550/" def authenticate(self, user_id: str, factors: Sequence[VCCSFactor]) -> bool: @@ -362,7 +372,7 @@ def revoke_credentials(self, user_id: str, factors: Sequence[VCCSRevokeFactor]) raise TypeError(f"Operation success value type error : {success!r}") return success is True - def _execute(self, data, response_label: str): + def _execute(self, data: str, response_label: str) -> dict: """ Make a HTTP POST request to the authentication backend, and parse the result. @@ -391,7 +401,7 @@ def _execute(self, data, response_label: str): raise AssertionError(f"Received response of unknown version {resp_ver!r}") return resp[response_label] - def _execute_request_response(self, service: str, values): + def _execute_request_response(self, service: str, values: dict[str, Any]) -> str: """ The part of _execute that has actual side effects. In a separate function to make everything else easily testable. diff --git a/src/eduid/vccs/client/tests/test_client.py b/src/eduid/vccs/client/tests/test_client.py index 8cd7138ec..f177db310 100644 --- a/src/eduid/vccs/client/tests/test_client.py +++ b/src/eduid/vccs/client/tests/test_client.py @@ -1,16 +1,14 @@ -#!/usr/bin/python - -""" -Test VCCS client. -""" - -import os import unittest +from typing import Any import simplejson as json from eduid.vccs.client import VCCSClient, VCCSOathFactor, VCCSPasswordFactor, VCCSRevokeFactor +""" +Test VCCS client. +""" + class FakeVCCSClient(VCCSClient): """ @@ -18,11 +16,11 @@ class FakeVCCSClient(VCCSClient): in order to fake HTTP communication. """ - def __init__(self, fake_response): + def __init__(self, fake_response: str) -> None: self.fake_response = fake_response VCCSClient.__init__(self) - def _execute_request_response(self, service, values): + def _execute_request_response(self, service: str, values: dict[str, Any]) -> str: self.last_service = service self.last_values = values return self.fake_response @@ -33,17 +31,12 @@ class FakeVCCSPasswordFactor(VCCSPasswordFactor): Sub-class that overrides the get_random_bytes function to make certain things testable. """ - def _get_random_bytes(self, num_bytes): - b = os.urandom(1) - if isinstance(b, str): - # Python2 - return chr(0xA) * num_bytes - # Python3 + def _get_random_bytes(self, num_bytes: int) -> bytes: return b"\x0a" * num_bytes class TestVCCSClient(unittest.TestCase): - def test_password_factor(self): + def test_password_factor(self) -> None: """ Test creating a VCCSPasswordFactor instance. """ @@ -58,7 +51,7 @@ def test_password_factor(self): }, ) - def test_utf8_password_factor(self): + def test_utf8_password_factor(self) -> None: """ Test creating a VCCSPasswordFactor instance. """ @@ -73,22 +66,22 @@ def test_utf8_password_factor(self): }, ) - def test_OATH_factor_auth(self): + def test_OATH_factor_auth(self) -> None: """ Test creating a VCCSOathFactor instance. """ aead = "aa" * 20 - o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, user_code="123456") + o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, user_code=123456) self.assertEqual( o.to_dict("auth"), { "type": "oath-hotp", "credential_id": 4712, - "user_code": "123456", + "user_code": 123456, }, ) - def test_OATH_factor_add(self): + def test_OATH_factor_add(self) -> None: """ Test creating a VCCSOathFactor instance for an add_creds request. """ @@ -107,24 +100,24 @@ def test_OATH_factor_add(self): }, ) - def test_missing_parts_of_OATH_factor(self): + def test_missing_parts_of_OATH_factor(self) -> None: """ Test creating a VCCSOathFactor instance with missing parts. """ aead = "aa" * 20 - o = VCCSOathFactor("oath-hotp", 4712, user_code="123456") + o = VCCSOathFactor("oath-hotp", 4712, user_code=123456) # missing AEAD with self.assertRaises(ValueError): o.to_dict("add_creds") - o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, key_handle=0x1234, user_code="123456") + o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, key_handle=0x1234, user_code=123456) # with AEAD o should be OK self.assertEqual(type(o.to_dict("add_creds")), dict) # unknown to_dict 'action' should raise with self.assertRaises(ValueError): o.to_dict("bad_action") - def test_authenticate1(self): + def test_authenticate1(self) -> None: """ Test parsing of successful authentication response. """ @@ -138,7 +131,7 @@ def test_authenticate1(self): f = VCCSPasswordFactor("password", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertTrue(c.authenticate("ft@example.net", [f])) - def test_authenticate1_utf8(self): + def test_authenticate1_utf8(self) -> None: """ Test parsing of successful authentication response with a password in UTF-8. """ @@ -152,7 +145,7 @@ def test_authenticate1_utf8(self): f = VCCSPasswordFactor("passwordåäöхэж", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertTrue(c.authenticate("ft@example.net", [f])) - def test_authenticate2(self): + def test_authenticate2(self) -> None: """ Test unknown response version """ @@ -166,7 +159,7 @@ def test_authenticate2(self): with self.assertRaises(AssertionError): c.authenticate("ft@example.net", [f]) - def test_authenticate2_utf8(self): + def test_authenticate2_utf8(self) -> None: """ Test unknown response version with a password in UTF-8. """ @@ -180,7 +173,7 @@ def test_authenticate2_utf8(self): with self.assertRaises(AssertionError): c.authenticate("ft@example.net", [f]) - def test_add_creds1(self): + def test_add_creds1(self) -> None: """ Test parsing of successful add_creds response. """ @@ -208,7 +201,7 @@ def test_add_creds1(self): } self.assertEqual(expected, values) - def test_add_creds1_utf8(self): + def test_add_creds1_utf8(self) -> None: """ Test parsing of successful add_creds response with a password in UTF-8. """ @@ -236,7 +229,7 @@ def test_add_creds1_utf8(self): } self.assertEqual(expected, values) - def test_add_creds2(self): + def test_add_creds2(self) -> None: """ Test parsing of unsuccessful add_creds response. """ @@ -250,7 +243,7 @@ def test_add_creds2(self): f = VCCSPasswordFactor("password", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertFalse(c.add_credentials("ft@example.net", [f])) - def test_add_creds2_utf8(self): + def test_add_creds2_utf8(self) -> None: """ Test parsing of unsuccessful add_creds response with a password in UTF-8. """ @@ -264,7 +257,7 @@ def test_add_creds2_utf8(self): f = VCCSPasswordFactor("passwordåäöхэж", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertFalse(c.add_credentials("ft@example.net", [f])) - def test_revoke_creds1(self): + def test_revoke_creds1(self) -> None: """ Test parsing of unsuccessful revoke_creds response. """ @@ -278,24 +271,24 @@ def test_revoke_creds1(self): r = VCCSRevokeFactor("4712", "testing revoke", "foobar") self.assertFalse(c.revoke_credentials("ft@example.net", [r])) - def test_revoke_creds2(self): + def test_revoke_creds2(self) -> None: """ Test revocation reason/reference bad types. """ - FakeVCCSClient(None) + FakeVCCSClient("Fake response not used in test") with self.assertRaises(TypeError): - VCCSRevokeFactor(4712, 1234, "foobar") + VCCSRevokeFactor(4712, 1234, "foobar") # type: ignore[arg-type] with self.assertRaises(TypeError): - VCCSRevokeFactor(4712, "foobar", 2345) + VCCSRevokeFactor(4712, "foobar", 2345) # type: ignore[arg-type] - def test_unknown_salt_version(self): + def test_unknown_salt_version(self) -> None: """Test unknown salt version""" with self.assertRaises(ValueError): VCCSPasswordFactor("anything", "4711", "$NDNvFOO$aaaaaaaaaaaaaaaa$12$32$") - def test_generate_salt1(self): + def test_generate_salt1(self) -> None: """Test salt generation.""" f = VCCSPasswordFactor("anything", "4711") self.assertEqual(len(f.salt), 80) @@ -304,7 +297,7 @@ def test_generate_salt1(self): self.assertEqual(rounds, 32) self.assertEqual(len(random), length) - def test_generate_salt2(self): + def test_generate_salt2(self) -> None: """Test salt generation with fake RNG.""" f = FakeVCCSPasswordFactor("anything", "4711") diff --git a/src/eduid/vccs/server/db.py b/src/eduid/vccs/server/db.py index 4c3c1ef23..429ee83d5 100644 --- a/src/eduid/vccs/server/db.py +++ b/src/eduid/vccs/server/db.py @@ -176,7 +176,7 @@ def from_dict_backwards_compat(cls: type[RevokedCredential], data: Mapping[str, class CredentialDB(BaseDB): - def __init__(self, db_uri: str, db_name: str = "vccs_auth_credstore", collection: str = "credentials"): + def __init__(self, db_uri: str, db_name: str = "vccs_auth_credstore", collection: str = "credentials") -> None: super().__init__(db_uri, db_name, collection=collection) indexes = { diff --git a/src/eduid/vccs/server/endpoints/add_creds.py b/src/eduid/vccs/server/endpoints/add_creds.py index 37a3c7dc0..542cd1d8d 100644 --- a/src/eduid/vccs/server/endpoints/add_creds.py +++ b/src/eduid/vccs/server/endpoints/add_creds.py @@ -67,7 +67,9 @@ async def add_creds(req: Request, request: AddCredsRequestV1) -> AddCredsRespons return response -async def _add_password_credential(_config, factor, req, request): +async def _add_password_credential( + _config: VCCSConfig, factor: RequestFactor, req: Request, request: AddCredsRequestV1 +) -> bool: _salt = (await req.app.state.hasher.safe_random(_config.add_creds_password_salt_bytes)).hex() cred = PasswordCredential( credential_id=factor.credential_id, diff --git a/src/eduid/vccs/server/endpoints/misc.py b/src/eduid/vccs/server/endpoints/misc.py index 3345c5836..210fa0e15 100644 --- a/src/eduid/vccs/server/endpoints/misc.py +++ b/src/eduid/vccs/server/endpoints/misc.py @@ -43,6 +43,6 @@ class HMACResponse(BaseModel): @misc_router.get("/hmac/{keyhandle}/{data}", response_model=HMACResponse) -async def hmac(request: Request, keyhandle: int, data: bytes): +async def hmac(request: Request, keyhandle: int, data: bytes) -> HMACResponse: hmac = await request.app.state.hasher.hmac_sha1(key_handle=keyhandle, data=data) return HMACResponse(keyhandle=keyhandle, hmac=hmac.hex()) diff --git a/src/eduid/vccs/server/hasher.py b/src/eduid/vccs/server/hasher.py index 68b460fe7..e998ed9de 100644 --- a/src/eduid/vccs/server/hasher.py +++ b/src/eduid/vccs/server/hasher.py @@ -2,46 +2,62 @@ import os import stat from abc import ABC +from asyncio.locks import Lock from binascii import unhexlify from collections.abc import Mapping from hashlib import sha1 -from typing import Any +from typing import Literal import pyhsm import yaml +class NoOpLock: + """ + A No-op lock class, to avoid a lot of "if self.lock:" in code using locks. + """ + + def __init__(self) -> None: + pass + + async def acquire(self) -> None: + pass + + async def release(self) -> None: + pass + + class VCCSHasher(ABC): - def __init__(self, lock): + def __init__(self, lock: Lock | NoOpLock) -> None: self.lock = lock def unlock(self, password: str) -> None: raise NotImplementedError("Subclass should implement unlock") - def info(self) -> Any: + def info(self) -> str | bytes | None: raise NotImplementedError("Subclass should implement info") - def hmac_sha1(self, _key_handle, _data): + async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: raise NotImplementedError("Subclass should implement safe_hmac_sha1") - def unsafe_hmac_sha1(self, _key_handle, _data): + def unsafe_hmac_sha1(self, key_handle: int | None, _data: bytes) -> bytes: raise NotImplementedError("Subclass should implement hmac_sha1") - def load_temp_key(self, _nonce, _key_handle, _aead): + def load_temp_key(self, nonce: str, _key_handle: int, _aead: bytes) -> bool: raise NotImplementedError("Subclass should implement load_temp_key") - def safe_random(self, _byte_count): + async def safe_random(self, byte_count: int) -> bytes: raise NotImplementedError("Subclass should implement safe_random") - async def lock_acquire(self): + async def lock_acquire(self) -> Literal[True] | None: return await self.lock.acquire() - async def lock_release(self): - return self.lock.release() + def lock_release(self) -> None: + self.lock.release() class VCCSYHSMHasher(VCCSHasher): - def __init__(self, device, lock, debug=False): + def __init__(self, device: str, lock: Lock | NoOpLock, debug: bool = False) -> None: VCCSHasher.__init__(self, lock) self._yhsm = pyhsm.base.YHSM(device, debug) @@ -49,10 +65,11 @@ def unlock(self, password: str) -> None: """Unlock YubiHSM on startup. The password is supposed to be hex encoded.""" self._yhsm.unlock(unhexlify(password)) - def info(self) -> Any: - return self._yhsm.info() + def info(self) -> str: + ret: bytes = self._yhsm.info() + return ret.decode() - async def hmac_sha1(self, key_handle: int, data: bytes) -> bytes: + async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: """ Perform HMAC-SHA-1 operation using YubiHSM. @@ -62,14 +79,14 @@ async def hmac_sha1(self, key_handle: int, data: bytes) -> bytes: try: return self.unsafe_hmac_sha1(key_handle, data) finally: - await self.lock_release() + self.lock_release() - def unsafe_hmac_sha1(self, key_handle: int, data: bytes) -> bytes: + def unsafe_hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: if key_handle is None: key_handle = pyhsm.defines.YSM_TEMP_KEY_HANDLE return self._yhsm.hmac_sha1(key_handle, data).get_hash() - def load_temp_key(self, nonce, key_handle, aead): + def load_temp_key(self, nonce: str, key_handle: int, aead: bytes) -> bool: return self._yhsm.load_temp_key(nonce, key_handle, aead) async def safe_random(self, byte_count: int) -> bytes: @@ -85,7 +102,7 @@ async def safe_random(self, byte_count: int) -> bytes: xored = bytes([a ^ b for (a, b) in zip(from_hsm, from_os)]) return xored finally: - await self.lock_release() + self.lock_release() class VCCSSoftHasher(VCCSHasher): @@ -94,7 +111,7 @@ class VCCSSoftHasher(VCCSHasher): (except perhaps separating HMAC keys from credential store). """ - def __init__(self, keys: Mapping[int, str], lock, debug=False): + def __init__(self, keys: Mapping[int, str], lock: Lock | NoOpLock, debug: bool = False) -> None: super().__init__(lock) self.debug = debug # Covert keys from strings to bytes when loading @@ -106,7 +123,7 @@ def __init__(self, keys: Mapping[int, str], lock, debug=False): def unlock(self, password: str) -> None: return None - def info(self) -> Any: + def info(self) -> str: return f"key handles loaded: {list(self.keys.keys())}" async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: @@ -119,7 +136,7 @@ async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: try: return self.unsafe_hmac_sha1(key_handle, data) finally: - await self.lock_release() + self.lock_release() def unsafe_hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: if key_handle is None: @@ -142,22 +159,9 @@ async def safe_random(self, byte_count: int) -> bytes: return os.urandom(byte_count) -class NoOpLock: - """ - A No-op lock class, to avoid a lot of "if self.lock:" in code using locks. - """ - - def __init__(self): - pass - - async def acquire(self): - pass - - async def release(self): - pass - - -def hasher_from_string(name: str, lock=None, debug=False): +def hasher_from_string( + name: str, lock: Lock | NoOpLock | None = None, debug: bool = False +) -> VCCSSoftHasher | VCCSYHSMHasher: """ Create a hasher instance from a name. Name can currently only be a name of a YubiHSM device, such as '/dev/ttyACM0'. diff --git a/src/eduid/vccs/server/log.py b/src/eduid/vccs/server/log.py index 6babb08df..ba96a5c14 100644 --- a/src/eduid/vccs/server/log.py +++ b/src/eduid/vccs/server/log.py @@ -1,12 +1,14 @@ import logging import sys +from loguru import Logger from loguru import logger as loguru_logger class InterceptHandler(logging.Handler): - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: # Get corresponding Loguru level if it exists + level: str | int try: level = loguru_logger.level(record.levelname).name except ValueError: @@ -21,7 +23,7 @@ def emit(self, record): loguru_logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) -def init_logging(): +def init_logging() -> Logger: loguru_logger.remove() fmt = ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <7} | {module: <11}:" diff --git a/src/eduid/vccs/server/password.py b/src/eduid/vccs/server/password.py index bbc8c20b0..c1ece43dd 100644 --- a/src/eduid/vccs/server/password.py +++ b/src/eduid/vccs/server/password.py @@ -10,7 +10,7 @@ async def authenticate_password( cred: PasswordCredential, factor: RequestFactor, user_id: str, hasher: VCCSYHSMHasher, kdf: NDNKDF -): +) -> bool: res = False H2 = await calculate_cred_hash(user_id=user_id, H1=factor.H1, cred=cred, hasher=hasher, kdf=kdf) # XXX need to log successful login in credential_store to be able to ban diff --git a/src/eduid/vccs/server/run.py b/src/eduid/vccs/server/run.py index 64c3bb935..6fa32a892 100644 --- a/src/eduid/vccs/server/run.py +++ b/src/eduid/vccs/server/run.py @@ -3,7 +3,7 @@ from collections.abc import Mapping from typing import Any -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from ndnkdf import ndnkdf from starlette.responses import JSONResponse @@ -51,7 +51,7 @@ def __init__(self, test_config: Mapping[str, Any] | None = None) -> None: @app.on_event("startup") -async def startup_event(): +async def startup_event() -> None: """ Uvicorn mucks with the logging config on startup, particularly the access log. Rein it in. """ @@ -75,7 +75,7 @@ async def startup_event(): @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request, exc): +async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: request.app.logger.warning(f"Failed parsing request: {exc}") return JSONResponse({"errors": exc.errors()}, status_code=HTTP_422_UNPROCESSABLE_ENTITY) diff --git a/src/eduid/vccs/server/tests/test_db.py b/src/eduid/vccs/server/tests/test_db.py index bd7493863..120cbeefe 100644 --- a/src/eduid/vccs/server/tests/test_db.py +++ b/src/eduid/vccs/server/tests/test_db.py @@ -6,7 +6,7 @@ class TestCredential(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.data = { "_id": ObjectId("54042b7a9b3f2299bb9d5546"), "credential": { @@ -24,11 +24,11 @@ def setUp(self): "revision": 1, } - def test_from_dict(self): + def test_from_dict(self) -> None: cred = PasswordCredential.from_dict(self.data) assert cred.key_handle == 8192 - def test_to_dict_from_dict(self): + def test_to_dict_from_dict(self) -> None: cred1 = PasswordCredential.from_dict(self.data) cred2 = PasswordCredential.from_dict(cred1.to_dict()) assert cred1.to_dict() == cred2.to_dict() diff --git a/src/eduid/webapp/authn/app.py b/src/eduid/webapp/authn/app.py index eb27a3fd5..ef344552b 100644 --- a/src/eduid/webapp/authn/app.py +++ b/src/eduid/webapp/authn/app.py @@ -10,7 +10,7 @@ class AuthnApp(EduIDBaseApp): - def __init__(self, config: AuthnConfig, **kwargs): + def __init__(self, config: AuthnConfig, **kwargs: Any) -> None: super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/authn/tests/test_authn.py b/src/eduid/webapp/authn/tests/test_authn.py index 3074a4ee1..ba405b76a 100644 --- a/src/eduid/webapp/authn/tests/test_authn.py +++ b/src/eduid/webapp/authn/tests/test_authn.py @@ -6,7 +6,8 @@ from dataclasses import dataclass from typing import Any -from flask import Blueprint +from flask import Blueprint, Response +from flask.typing import ResponseReturnValue from saml2.s_utils import deflate_and_base64_encode from werkzeug.exceptions import NotFound from werkzeug.http import dump_cookie @@ -15,6 +16,7 @@ from eduid.common.config.parsers import load_config from eduid.common.misc.timeutil import utc_now from eduid.common.models.saml2 import EduidAuthnContextClass +from eduid.userdb.testing import SetupConfig from eduid.webapp.authn.app import AuthnApp, authn_init_app from eduid.webapp.authn.settings.common import AuthnConfig from eduid.webapp.common.api.testing import EduidAPITestCase @@ -41,14 +43,8 @@ class AuthnAPITestBase(EduidAPITestCase): app: AuthnApp - def setUp( # type: ignore[override] - self, - *args: list[Any], - users: list[str] | None = None, - copy_user_to_private: bool = False, - **kwargs: dict[str, Any], - ) -> None: - super().setUp(*args, users=users, copy_user_to_private=copy_user_to_private, **kwargs) + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) self.idp_url = "https://idp.example.com/simplesaml/saml2/idp/SSOService.php" def update_config(self, config: dict[str, Any]) -> dict[str, Any]: @@ -251,22 +247,27 @@ class AuthnAPITestCase(AuthnAPITestBase): app: AuthnApp - def setUp(self, **kwargs): - super().setUp(users=["hubba-bubba", "hubba-fooo"], **kwargs) + def setUp(self, config: SetupConfig | None = None) -> None: + if config is None: + config = SetupConfig() + config.users = ["hubba-bubba", "hubba-fooo"] + super().setUp(config=config) - def test_login_authn(self): + def test_login_authn(self) -> None: self.authn("/authenticate", FrontendAction.LOGIN) - def test_chpass_authn(self): + def test_chpass_authn(self) -> None: self.authn("/authenticate", FrontendAction.CHANGE_PW_AUTHN) - def test_terminate_authn(self): + def test_terminate_authn(self) -> None: self.authn("/authenticate", FrontendAction.TERMINATE_ACCOUNT_AUTHN) - def test_login_assertion_consumer_service(self): - for accr in EduidAuthnContextClass: - if accr == EduidAuthnContextClass.NOT_IMPLEMENTED: + def test_login_assertion_consumer_service(self) -> None: + for context_class in EduidAuthnContextClass: + if context_class == EduidAuthnContextClass.NOT_IMPLEMENTED: accr = None + else: + accr = context_class eppn = "hubba-bubba" res = self.acs("/authenticate", eppn, frontend_action=FrontendAction.LOGIN, accr=accr) assert res.session.common.eppn == "hubba-bubba" @@ -275,7 +276,7 @@ def test_login_assertion_consumer_service(self): if accr: assert authn.asserted_authn_ctx == accr.value - def test_assertion_consumer_service(self): + def test_assertion_consumer_service(self) -> None: actions = [FrontendAction.LOGIN, FrontendAction.CHANGE_PW_AUTHN, FrontendAction.TERMINATE_ACCOUNT_AUTHN] for action in actions: res = self.acs("/authenticate", eppn=self.test_user.eppn, frontend_action=action) @@ -287,11 +288,12 @@ def test_assertion_consumer_service(self): age = utc_now() - authn.authn_instant assert 10 < age.total_seconds() < 15 - def test_frontend_state(self): + def test_frontend_state(self) -> None: eppn = "hubba-bubba" self.acs("/authenticate", eppn, FrontendAction.REMOVE_SECURITY_KEY_AUTHN, frontend_state="key_id_to_remove") - def _signup_authn_user(self, eppn): + # TODO: up for removal since it seems unused + def _signup_authn_user(self, eppn: str) -> ResponseReturnValue: timestamp = utc_now() with self.app.test_client() as c: @@ -308,7 +310,7 @@ def _signup_authn_user(self, eppn): class AuthnTestApp(AuthnBaseApp): - def __init__(self, config: AuthnConfig, **kwargs): + def __init__(self, config: AuthnConfig, **kwargs: Any) -> None: super().__init__(config, **kwargs) self.conf = config @@ -342,12 +344,12 @@ def load_app(self, test_config: Mapping[str, Any]) -> AuthnTestApp: config = load_config(typ=AuthnConfig, app_name="testing", ns="webapp", test_config=test_config) return AuthnTestApp(config) - def test_no_cookie(self): + def test_no_cookie(self) -> None: with self.app.test_client() as c: resp = c.get("/") self.assertEqual(resp.status_code, 401) - def test_cookie(self): + def test_cookie(self) -> None: sessid = "fb1f42420b0109020203325d750185673df252de388932a3957f522a6c43aa47" self.redis_instance.conn.set(sessid, json.dumps({"v1": {"id": "0"}})) @@ -360,20 +362,20 @@ class NoAuthnAPITestCase(EduidAPITestCase): app: AuthnTestApp - def setUp(self): - super().setUp() + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) test_views = Blueprint("testing", __name__) @test_views.route("/test") - def test(): + def test() -> str: return "OK" @test_views.route("/test2") - def test2(): + def test2() -> str: return "OK" @test_views.route("/test3") - def test3(): + def test3() -> str: return "OK" self.app.register_blueprint(test_views) @@ -403,17 +405,17 @@ def load_app(self, test_config: Mapping[str, Any]) -> AuthnTestApp: config = load_config(typ=AuthnConfig, app_name="testing", ns="webapp", test_config=test_config) return AuthnTestApp(config) - def test_no_authn(self): + def test_no_authn(self) -> None: with self.app.test_client() as c: resp = c.get("/test") self.assertEqual(resp.status_code, 200) - def test_authn(self): + def test_authn(self) -> None: with self.app.test_client() as c: resp = c.get("/test2") self.assertEqual(resp.status_code, 401) - def test_no_authn_util(self): + def test_no_authn_util(self) -> None: no_authn_urls_before = [path for path in self.app.conf.no_authn_urls] no_authn_path = "/test3" no_authn_views(self.app.conf, [no_authn_path]) @@ -425,33 +427,35 @@ def test_no_authn_util(self): class LogoutRequestTests(AuthnAPITestBase): - def test_metadataview(self): + def test_metadataview(self) -> None: with self.app.test_client() as c: response = c.get("/saml2-metadata") self.assertEqual(response.status, "200 OK") - def test_logout_nologgedin(self): + def test_logout_nologgedin(self) -> None: eppn = "hubba-bubba" with self.app.test_request_context("/logout", method="GET"): # eppn is set in the IdP session.common.eppn = eppn response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn(self.app.conf.saml2_logout_redirect_url, response.headers["Location"]) - def test_logout_loggedin(self): + def test_logout_loggedin(self) -> None: res = self.acs(url="/authenticate", eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) cookie = self.dump_session_cookie(res.session.meta.cookie_val) with self.app.test_request_context("/logout", method="GET", headers={"Cookie": cookie}): response = self.app.dispatch_request() + assert isinstance(response, Response) logger.debug(f"Test called /logout, response {response}") self.assertEqual(response.status, "302 FOUND") self.assertIn( "https://idp.example.com/simplesaml/saml2/idp/SingleLogoutService.php", response.headers["location"] ) - def test_logout_service_startingSP(self): + def test_logout_service_startingSP(self) -> None: session_id = self.start_authenticate(eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) cookie = self.dump_session_cookie(session_id) @@ -465,11 +469,12 @@ def test_logout_service_startingSP(self): }, ): response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn("testing-relay-state", response.location) - def test_logout_service_startingSP_already_logout(self): + def test_logout_service_startingSP_already_logout(self) -> None: session_id = self.start_authenticate(eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) with self.app.test_request_context( @@ -481,11 +486,12 @@ def test_logout_service_startingSP_already_logout(self): }, ): response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn("testing-relay-state", response.location) - def test_logout_service_startingIDP(self): + def test_logout_service_startingIDP(self) -> None: res = self.acs("/authenticate", eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) cookie = self.dump_session_cookie(res.session.meta.cookie_val) @@ -499,6 +505,7 @@ def test_logout_service_startingIDP(self): }, ): response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") assert ( @@ -506,7 +513,7 @@ def test_logout_service_startingIDP(self): in response.location ) - def test_logout_service_startingIDP_no_subject_id(self): + def test_logout_service_startingIDP_no_subject_id(self) -> None: eppn = "hubba-bubba" res = self.acs("/authenticate", eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) session_id = res.session.meta.cookie_val @@ -539,6 +546,7 @@ def test_logout_service_startingIDP_no_subject_id(self): session.authn.name_id = None session.persist() # Explicit session.persist is needed when working within a test_request_context response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn("testing-relay-state", response.location) diff --git a/src/eduid/webapp/bankid/app.py b/src/eduid/webapp/bankid/app.py index f03ca21c1..bc252036e 100644 --- a/src/eduid/webapp/bankid/app.py +++ b/src/eduid/webapp/bankid/app.py @@ -15,7 +15,7 @@ class BankIDApp(AuthnBaseApp): - def __init__(self, config: BankIDConfig, **kwargs: Any): + def __init__(self, config: BankIDConfig, **kwargs: Any) -> None: super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/bankid/tests/test_app.py b/src/eduid/webapp/bankid/tests/test_app.py index a4a7f7981..f59736aca 100644 --- a/src/eduid/webapp/bankid/tests/test_app.py +++ b/src/eduid/webapp/bankid/tests/test_app.py @@ -13,6 +13,7 @@ from eduid.userdb.credentials.external import BankIDCredential, SwedenConnectCredential from eduid.userdb.element import ElementKey from eduid.userdb.identity import IdentityProofingMethod +from eduid.userdb.testing import SetupConfig from eduid.webapp.bankid.app import BankIDApp, init_bankid_app from eduid.webapp.bankid.helpers import BankIDMsg from eduid.webapp.common.api.messages import AuthnStatusMsg, TranslatableMsg @@ -34,7 +35,7 @@ class BankIDTests(ProofingTests[BankIDApp]): """Base TestCase for those tests that need a full environment setup""" - def setUp(self, *args: Any, **kwargs: Any) -> None: + def setUp(self, config: SetupConfig | None = None) -> None: self.test_user_eppn = "hubba-bubba" self.test_unverified_user_eppn = "hubba-baar" self.test_user_nin = NinIdentity( @@ -161,7 +162,10 @@ def setUp(self, *args: Any, **kwargs: Any) -> None: """ # noqa: E501 - super().setUp(users=["hubba-bubba", "hubba-baar"]) + if config is None: + config = SetupConfig() + config.users = ["hubba-bubba", "hubba-baar"] + super().setUp(config=config) def load_app(self, config: Mapping[str, Any]) -> BankIDApp: """ @@ -463,7 +467,7 @@ def _call_endpoint_and_saml_acs( method=method, ) - def test_authenticate(self): + def test_authenticate(self) -> None: response = self.browser.get("/") self.assertEqual(response.status_code, 401) with self.session_cookie(self.browser, self.test_user.eppn) as browser: @@ -471,7 +475,7 @@ def test_authenticate(self): self._check_success_response(response, type_="GET_BANKID_SUCCESS") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_u2f_token_verify(self, mock_request_user_sync: MagicMock): + def test_u2f_token_verify(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_user.eppn @@ -491,7 +495,7 @@ def test_u2f_token_verify(self, mock_request_user_sync: MagicMock): self._verify_user_parameters(eppn, token_verified=True, num_proofings=1) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_webauthn_token_verify(self, mock_request_user_sync: MagicMock): + def test_webauthn_token_verify(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_user.eppn @@ -511,7 +515,7 @@ def test_webauthn_token_verify(self, mock_request_user_sync: MagicMock): self._verify_user_parameters(eppn, token_verified=True, num_proofings=1) - def test_mfa_token_verify_wrong_verified_nin(self): + def test_mfa_token_verify_wrong_verified_nin(self) -> None: eppn = self.test_user.eppn nin = self.test_user_wrong_nin credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -532,7 +536,7 @@ def test_mfa_token_verify_wrong_verified_nin(self): self._verify_user_parameters(eppn, identity=nin, identity_present=False) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_token_verify_no_verified_nin(self, mock_request_user_sync: MagicMock): + def test_mfa_token_verify_no_verified_nin(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -556,7 +560,7 @@ def test_mfa_token_verify_no_verified_nin(self, mock_request_user_sync: MagicMoc eppn, token_verified=True, num_proofings=2, identity_present=True, identity=nin, identity_verified=True ) - def test_mfa_token_verify_no_mfa_login(self): + def test_mfa_token_verify_no_mfa_login(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -581,7 +585,7 @@ def test_mfa_token_verify_no_mfa_login(self): ) self._verify_user_parameters(eppn) - def test_mfa_token_verify_no_mfa_token_in_session(self): + def test_mfa_token_verify_no_mfa_token_in_session(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "webauthn") @@ -600,7 +604,7 @@ def test_mfa_token_verify_no_mfa_token_in_session(self): self._verify_user_parameters(eppn) - def test_mfa_token_verify_aborted_auth(self): + def test_mfa_token_verify_aborted_auth(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -619,7 +623,7 @@ def test_mfa_token_verify_aborted_auth(self): self._verify_user_parameters(eppn) - def test_mfa_token_verify_cancel_auth(self): + def test_mfa_token_verify_cancel_auth(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "webauthn") @@ -640,7 +644,7 @@ def test_mfa_token_verify_cancel_auth(self): self._verify_user_parameters(eppn) - def test_mfa_token_verify_auth_fail(self): + def test_mfa_token_verify_auth_fail(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -663,7 +667,7 @@ def test_mfa_token_verify_auth_fail(self): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_webauthn_token_verify_backdoor(self, mock_request_user_sync: MagicMock): + def test_webauthn_token_verify_backdoor(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -688,7 +692,7 @@ def test_webauthn_token_verify_backdoor(self, mock_request_user_sync: MagicMock) self._verify_user_parameters(eppn, identity=nin, identity_verified=True, token_verified=True, num_proofings=2) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify(self, mock_request_user_sync: MagicMock): + def test_nin_verify(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -719,7 +723,7 @@ def test_nin_verify(self, mock_request_user_sync: MagicMock): assert doc["surname"] == "Älm" @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_login(self, mock_request_user_sync: MagicMock): + def test_mfa_login(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_user.eppn @@ -735,7 +739,7 @@ def test_mfa_login(self, mock_request_user_sync: MagicMock): self._verify_user_parameters(eppn, num_mfa_tokens=0, identity_verified=True, num_proofings=0) - def test_mfa_login_no_nin(self): + def test_mfa_login_no_nin(self) -> None: eppn = self.test_unverified_user_eppn self._verify_user_parameters(eppn, num_mfa_tokens=0, identity_verified=False, token_verified=False) @@ -751,7 +755,7 @@ def test_mfa_login_no_nin(self): self._verify_user_parameters(eppn, num_mfa_tokens=0, identity_verified=False, num_proofings=0) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_login_unverified_nin(self, mock_request_user_sync: MagicMock): + def test_mfa_login_unverified_nin(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -777,7 +781,7 @@ def test_mfa_login_unverified_nin(self, mock_request_user_sync: MagicMock): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_login_backdoor(self, mock_request_user_sync: MagicMock): + def test_mfa_login_backdoor(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -805,7 +809,7 @@ def test_mfa_login_backdoor(self, mock_request_user_sync: MagicMock): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify_backdoor(self, mock_request_user_sync: Any): + def test_nin_verify_backdoor(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -828,7 +832,7 @@ def test_nin_verify_backdoor(self, mock_request_user_sync: Any): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify_no_backdoor_in_pro(self, mock_request_user_sync: MagicMock): + def test_nin_verify_no_backdoor_in_pro(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -856,7 +860,7 @@ def test_nin_verify_no_backdoor_in_pro(self, mock_request_user_sync: MagicMock): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify_no_backdoor_misconfigured(self, mock_request_user_sync: MagicMock): + def test_nin_verify_no_backdoor_misconfigured(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -883,7 +887,7 @@ def test_nin_verify_no_backdoor_misconfigured(self, mock_request_user_sync: Magi eppn, identity=self.test_user_nin, num_mfa_tokens=0, num_proofings=1, identity_verified=True ) - def test_nin_verify_already_verified(self): + def test_nin_verify_already_verified(self) -> None: # Verify that the test user has a verified NIN in the database already eppn = self.test_user.eppn nin = self.test_user_nin @@ -902,7 +906,7 @@ def test_nin_verify_already_verified(self): ) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_authentication_verified_user(self, mock_request_user_sync: MagicMock): + def test_mfa_authentication_verified_user(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync user = self.app.central_userdb.get_user_by_eppn(self.test_user.eppn) @@ -927,7 +931,7 @@ def test_mfa_authentication_verified_user(self, mock_request_user_sync: MagicMoc cred = _creds[0] assert cred.level in self.app.conf.bankid_required_loa - def test_mfa_authentication_too_old_authn_instant(self): + def test_mfa_authentication_too_old_authn_instant(self) -> None: self.reauthn( endpoint="/mfa-authenticate", frontend_action=FrontendAction.LOGIN_MFA_AUTHN, @@ -936,7 +940,7 @@ def test_mfa_authentication_too_old_authn_instant(self): expect_error=True, ) - def test_mfa_authentication_wrong_nin(self): + def test_mfa_authentication_wrong_nin(self) -> None: user = self.app.central_userdb.get_user_by_eppn(self.test_user_eppn) assert user.identities.nin is not None assert user.identities.nin.is_verified is True, "User was expected to have a verified NIN" diff --git a/src/eduid/webapp/common/api/app.py b/src/eduid/webapp/common/api/app.py index 6e361ffa3..74101906f 100644 --- a/src/eduid/webapp/common/api/app.py +++ b/src/eduid/webapp/common/api/app.py @@ -40,6 +40,7 @@ if TYPE_CHECKING: from _typeshed.wsgi import WSGIApplication + from werkzeug.middleware.profiler import ProfilerMiddleware DEBUG = os.environ.get("EDUID_APP_DEBUG", False) if DEBUG: @@ -61,7 +62,7 @@ def __init__( init_central_userdb: bool = True, handle_exceptions: bool = True, **kwargs: Any, - ): + ) -> None: """ :param config: EduID Flask app configuration subclass :param init_central_userdb: Whether the app requires access to the central user db. @@ -192,7 +193,7 @@ def init_status_views(app: EduIDBaseApp, config: EduIDBaseAppConfig) -> None: return None -def init_app_profiling(app: WSGIApplication, config: EduIDBaseAppConfig): +def init_app_profiling(app: WSGIApplication, config: EduIDBaseAppConfig) -> ProfilerMiddleware: """ Setup profiling middleware for any app. """ diff --git a/src/eduid/webapp/common/api/captcha.py b/src/eduid/webapp/common/api/captcha.py index 09765614f..16b8e1e18 100644 --- a/src/eduid/webapp/common/api/captcha.py +++ b/src/eduid/webapp/common/api/captcha.py @@ -11,7 +11,7 @@ class InternalCaptcha: - def __init__(self, config: CaptchaConfigMixin): + def __init__(self, config: CaptchaConfigMixin) -> None: self.image_generator = ImageCaptcha( height=config.captcha_height, width=config.captcha_width, diff --git a/src/eduid/webapp/common/api/checks.py b/src/eduid/webapp/common/api/checks.py index 14d61a71d..5af238187 100644 --- a/src/eduid/webapp/common/api/checks.py +++ b/src/eduid/webapp/common/api/checks.py @@ -47,7 +47,7 @@ class FailCountItem: exit_at: datetime | None = None count: int = 0 - def __str__(self): + def __str__(self) -> str: return f"(first_failure: {self.first_failure.isoformat()}, fail count: {self.count})" diff --git a/src/eduid/webapp/common/api/debug.py b/src/eduid/webapp/common/api/debug.py index a881e7d88..c8333b01f 100644 --- a/src/eduid/webapp/common/api/debug.py +++ b/src/eduid/webapp/common/api/debug.py @@ -1,32 +1,35 @@ import pprint import sys import warnings -from collections.abc import Callable +from collections.abc import Callable, Iterable from dataclasses import asdict from typing import Any from urllib import parse from flask import Flask, url_for +# TODO: in python >= 3.11 import from wsgiref.types +from eduid.webapp.common.wsgi import StartResponse, WSGIEnvironment + __author__ = "lundberg" class LoggingMiddleware: - def __init__(self, app: Callable[..., Any]): + def __init__(self, app: Callable[..., Any]) -> None: self._app = app - def __call__(self, environ: dict[Any, Any], resp: Callable[..., Any]): + def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> Iterable[bytes]: errorlog = environ["wsgi.errors"] pprint.pprint(("REQUEST", environ), stream=errorlog) - def log_response(status, headers, *args): + def log_response(status: str, headers: list[tuple[str, str]], *args: Any) -> Callable[[bytes], object]: pprint.pprint(("RESPONSE", status, headers), stream=errorlog) - return resp(status, headers, *args) + return start_response(status, headers, *args) return self._app(environ, log_response) -def log_endpoints(app: Flask): +def log_endpoints(app: Flask) -> None: output: list[str] = [] with app.app_context(): for rule in app.url_map.iter_rules(): @@ -43,7 +46,7 @@ def log_endpoints(app: Flask): pprint.pprint(("ENDPOINT", line), stream=sys.stderr) -def dump_config(app: Flask): +def dump_config(app: Flask) -> None: pprint.pprint(("CONFIGURATION", "app.config"), stream=sys.stderr) try: config_items = asdict(app.config).items() # type: ignore[call-overload] @@ -54,8 +57,8 @@ def dump_config(app: Flask): pprint.pprint((key, value), stream=sys.stderr) -def init_app_debug(app: Flask): - app.wsgi_app = LoggingMiddleware(app.wsgi_app) # type: ignore[assignment] +def init_app_debug(app: Flask) -> Flask: + app.wsgi_app = LoggingMiddleware(app.wsgi_app) # type: ignore[method-assign] dump_config(app) log_endpoints(app) pprint.pprint(("view_functions", app.view_functions), stream=sys.stderr) diff --git a/src/eduid/webapp/common/api/decorators.py b/src/eduid/webapp/common/api/decorators.py index f3840a631..60a97d7cb 100644 --- a/src/eduid/webapp/common/api/decorators.py +++ b/src/eduid/webapp/common/api/decorators.py @@ -135,10 +135,10 @@ class MarshalWith: on-the-wire format of these Flux Standard Actions. """ - def __init__(self, schema: type[Schema]): + def __init__(self, schema: type[Schema]) -> None: self.schema = schema - def __call__(self, f: EduidRouteCallable): + def __call__(self, f: EduidRouteCallable) -> Callable: @wraps(f) def marshal_decorator(*args: Any, **kwargs: Any) -> WerkzeugResponse: # Call the Flask view, which is expected to return a FluxData instance, @@ -184,10 +184,10 @@ class UnmarshalWith: not a FluxData instance. """ - def __init__(self, schema: type[Schema]): + def __init__(self, schema: type[Schema]) -> None: self.schema = schema - def __call__(self, f: EduidRouteCallable): + def __call__(self, f: EduidRouteCallable) -> Callable: @wraps(f) def unmarshal_decorator( *args: Any, **kwargs: Any diff --git a/src/eduid/webapp/common/api/exceptions.py b/src/eduid/webapp/common/api/exceptions.py index 5e0d11b2a..07321ba43 100644 --- a/src/eduid/webapp/common/api/exceptions.py +++ b/src/eduid/webapp/common/api/exceptions.py @@ -1,7 +1,8 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from flask import jsonify +from flask import Flask, Response, jsonify +from werkzeug.exceptions import HTTPException __author__ = "lundberg" @@ -17,7 +18,7 @@ def __init__( message: str = "ApiException", status_code: int | None = None, payload: Mapping[str, Any] | None = None, - ): + ) -> None: """ :param message: Error message :param status_code: Http status code @@ -29,13 +30,13 @@ def __init__( self.status_code = status_code self.payload = payload - def __repr__(self): + def __repr__(self) -> str: return f"ApiException (message={self.message!s}, status_code={self.status_code!s}, payload={self.payload!r})" - def __unicode__(self): + def __unicode__(self) -> str: return self.__str__() - def __str__(self): + def __str__(self) -> str: if self.payload: return f"{self.status_code!s} with message {self.message!s} and payload {self.payload!r}" return f"{self.status_code!s} with message {self.message!s}" @@ -67,15 +68,15 @@ class ProofingLogFailure(Exception): class ThrottledException(Exception): state: "ResetPasswordEmailState" - def __init__(self, state: "ResetPasswordEmailState"): + def __init__(self, state: "ResetPasswordEmailState") -> None: Exception.__init__(self) self.state = state -def init_exception_handlers(app): +def init_exception_handlers(app: Flask) -> Flask: # Init error handler for raised exceptions @app.errorhandler(400) - def _handle_flask_http_exception(error): + def _handle_flask_http_exception(error: HTTPException) -> Response: app.logger.error(f"HttpException {error!s}") e = ApiException(error.name, error.code) if app.config.get("DEBUG"): @@ -87,7 +88,7 @@ def _handle_flask_http_exception(error): return app -def init_sentry(app): +def init_sentry(app: Flask) -> Flask: if app.config.get("SENTRY_DSN"): try: from raven.contrib.flask import Sentry diff --git a/src/eduid/webapp/common/api/helpers.py b/src/eduid/webapp/common/api/helpers.py index ca99e86ae..bc89963a6 100644 --- a/src/eduid/webapp/common/api/helpers.py +++ b/src/eduid/webapp/common/api/helpers.py @@ -291,7 +291,7 @@ def send_mail( app: EduIDBaseApp, context: dict[str, Any] | None = None, reference: str | None = None, -): +) -> None: """ :param subject: subject text :param to_addresses: email addresses for the to field diff --git a/src/eduid/webapp/common/api/messages.py b/src/eduid/webapp/common/api/messages.py index 156792f12..2e3133a7d 100644 --- a/src/eduid/webapp/common/api/messages.py +++ b/src/eduid/webapp/common/api/messages.py @@ -146,7 +146,7 @@ def _make_payload( return res -def make_query_string(msg: TranslatableMsg, error: bool = True): +def make_query_string(msg: TranslatableMsg, error: bool = True) -> str: """ Make a query string to send a translatable message to the front in the URL of a GET request. diff --git a/src/eduid/webapp/common/api/middleware.py b/src/eduid/webapp/common/api/middleware.py index 9929ae01c..5ffd124c4 100644 --- a/src/eduid/webapp/common/api/middleware.py +++ b/src/eduid/webapp/common/api/middleware.py @@ -1,9 +1,15 @@ __author__ = "lundberg" +from collections.abc import Callable, Iterable +from typing import Any + +# TODO: in python >= 3.11 import from wsgiref.types +from eduid.webapp.common.wsgi import StartResponse, WSGIEnvironment + # Copied from https://stackoverflow.com/questions/18967441/add-a-prefix-to-all-flask-routes/36033627#36033627 class PrefixMiddleware: - def __init__(self, app, prefix="", server_name=""): + def __init__(self, app: Callable[..., Any], prefix: str = "", server_name: str = "") -> None: self.app = app if prefix is None: prefix = "" @@ -12,7 +18,7 @@ def __init__(self, app, prefix="", server_name=""): self.prefix = prefix self.server_name = server_name - def __call__(self, environ, start_response): + def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> Iterable[bytes]: # Handle localhost requests for health checks if environ.get("REMOTE_ADDR") == "127.0.0.1": environ["HTTP_HOST"] = self.server_name diff --git a/src/eduid/webapp/common/api/request.py b/src/eduid/webapp/common/api/request.py index 719e4c55e..41964f6c4 100644 --- a/src/eduid/webapp/common/api/request.py +++ b/src/eduid/webapp/common/api/request.py @@ -15,7 +15,8 @@ """ import logging -from typing import Any, AnyStr +from collections.abc import Callable, Iterator +from typing import Any, AnyStr, TypeVar from flask import abort from flask.wrappers import Request as FlaskRequest @@ -37,7 +38,7 @@ def sanitize_input( untrusted_text: AnyStr, content_type: str | None = None, strip_characters: bool = False, - ): + ) -> str: try: return super().sanitize_input( untrusted_text=untrusted_text, content_type=content_type, strip_characters=strip_characters @@ -53,7 +54,7 @@ class SanitizedImmutableMultiDict(ImmutableMultiDict, SanitationMixin): sanitize the extracted data. """ - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: """ Return the first data value for this key; raises KeyError if not found. @@ -64,7 +65,7 @@ def __getitem__(self, key): value = super().__getitem__(key) return self.sanitize_input(value) - def getlist(self, key, type=None): + def getlist(self, key: str, type: Callable[[Any], Any] | None = None) -> list: """ Return the list of items for a given key. If that key is not in the `MultiDict`, the return value will be an empty list. Just as `get` @@ -77,10 +78,11 @@ def getlist(self, key, type=None): by this callable the value will be removed from the list. :return: a :class:`list` of all the values for the key. """ + assert type is not None value_list = super().getlist(key, type=type) return [self.sanitize_input(v) for v in value_list] - def items(self, multi=False): + def items(self, multi: bool = False) -> Iterator[tuple[Any, str]]: # type: ignore[override] """ Return an iterator of ``(key, value)`` pairs. @@ -97,7 +99,7 @@ def items(self, multi=False): else: yield key, values[0] - def lists(self): + def lists(self) -> Iterator[tuple[Any, list[str]]]: """Return a list of ``(key, values)`` pairs, where values is the list of all values associated with the key.""" @@ -105,14 +107,14 @@ def lists(self): values = [self.sanitize_input(v) for v in values] yield key, values - def values(self): + def values(self) -> Iterator[str]: # type: ignore[override] """ Returns an iterator of the first value on every key's value list. """ for values in dict.values(self): yield self.sanitize_input(values[0]) - def listvalues(self): + def listvalues(self) -> Iterator[Iterator[Any]]: # type: ignore[override] """ Return an iterator of all values associated with a key. Zipping :meth:`keys` and this is the same as calling :meth:`lists`: @@ -124,7 +126,7 @@ def listvalues(self): for values in dict.values(self): yield (self.sanitize_input(v) for v in values) - def to_dict(self, flat=True): + def to_dict(self, flat: bool = True) -> dict | dict[Any, list[str]]: """Return the contents as regular dict. If `flat` is `True` the returned dict will only have the first item present, if `flat` is `False` all values will be returned as lists. @@ -143,6 +145,9 @@ def to_dict(self, flat=True): return dict(self.lists()) +T = TypeVar("T") + + class SanitizedTypeConversionDict(ImmutableTypeConversionDict, SanitationMixin): """ See `werkzeug.datastructures.TypeConversionDict`. @@ -150,14 +155,14 @@ class SanitizedTypeConversionDict(ImmutableTypeConversionDict, SanitationMixin): sanitize the extracted data. """ - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: """ Sanitized __getitem__ """ val = super(ImmutableTypeConversionDict, self).__getitem__(key) return self.sanitize_input(str(val)) - def get(self, key, default=None, type=None) -> Any | None: # type: ignore[override] + def get(self, key: str, default: str | None = None, type: Callable[[Any], T] | None = None) -> str | T | None: # type: ignore[override] """ Sanitized, type conversion get. The value identified by `key` is sanitized, and if `type` @@ -173,20 +178,20 @@ def get(self, key, default=None, type=None) -> Any | None: # type: ignore[overr :rtype: object """ try: - val = self.sanitize_input(self[key]) + val: Any = self.sanitize_input(self[key]) if type is not None: val = type(val) except (KeyError, ValueError): val = default return val - def values(self): + def values(self) -> list[str]: # type: ignore[override] """ sanitized values """ return [self.sanitize_input(v) for v in super(ImmutableTypeConversionDict, self).values()] - def items(self): + def items(self) -> list[tuple[str, str]]: # type: ignore[override] """ Sanitized items """ @@ -198,7 +203,7 @@ class SanitizedEnvironHeaders(EnvironHeaders, SanitationMixin): Sanitized and read only version of the headers from a WSGI environment. """ - def __init__(self, environ: dict[str, Any]): + def __init__(self, environ: dict[str, Any]) -> None: # set content type from environ at init so we don't get in to an infinite recursion # when sanitize_input tries to look it up later self.content_type = environ.get("CONTENT_TYPE") @@ -215,7 +220,7 @@ def __getitem__(self, key: str, _get_mode: bool = False) -> str: # type: ignore val = super().__getitem__(key) return self.sanitize_input(untrusted_text=val, content_type=self.content_type) - def __iter__(self): + def __iter__(self) -> Iterator[tuple[str, str]]: # type: ignore[override] """ Sanitized __iter__ """ @@ -231,11 +236,11 @@ class Request(FlaskRequest, SanitationMixin): parameter_storage_class = SanitizedImmutableMultiDict dict_storage_class = SanitizedTypeConversionDict # type: ignore[assignment] - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.headers = SanitizedEnvironHeaders(environ=self.environ) - def get_data(self, *args: Any, **kwargs: Any): + def get_data(self, *args: Any, **kwargs: Any) -> str: # type: ignore[override] text = super().get_data(*args, **kwargs) if text: text = self.sanitize_input(untrusted_text=text, content_type=self.mimetype) diff --git a/src/eduid/webapp/common/api/schemas/csrf.py b/src/eduid/webapp/common/api/schemas/csrf.py index 59452a5e0..524abdfb0 100644 --- a/src/eduid/webapp/common/api/schemas/csrf.py +++ b/src/eduid/webapp/common/api/schemas/csrf.py @@ -1,4 +1,5 @@ import logging +from typing import Any from flask import current_app, request from marshmallow import Schema, ValidationError, fields, post_load, pre_dump, validates @@ -15,7 +16,7 @@ class CSRFRequestMixin(Schema): csrf_token = fields.String(required=True) @validates("csrf_token") - def validate_csrf_token(self, value, **kwargs): + def validate_csrf_token(self, value: str, **kwargs: Any) -> None: custom_header = request.headers.get("X-Requested-With") if custom_header != "XMLHttpRequest": # TODO: move value to config current_app.logger.error("CSRF check: missing custom X-Requested-With header") @@ -25,13 +26,13 @@ def validate_csrf_token(self, value, **kwargs): logger.debug(f"Validated CSRF token in session: {session.get_csrf_token()}") @post_load - def post_processing(self, in_data, **kwargs): + def post_processing(self, in_data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: # Remove token from data forwarded to views in_data = self.remove_csrf_token(in_data) return in_data @staticmethod - def remove_csrf_token(in_data, **kwargs): + def remove_csrf_token(in_data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: del in_data["csrf_token"] return in_data @@ -40,7 +41,7 @@ class CSRFResponseMixin(Schema): csrf_token = fields.String(required=True) @pre_dump - def get_csrf_token(self, out_data, **kwargs): + def get_csrf_token(self, out_data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: # Generate a new csrf token for every response out_data["csrf_token"] = session.new_csrf_token() logger.debug(f'Generated new CSRF token in CSRFResponseMixin: {out_data["csrf_token"]}') diff --git a/src/eduid/webapp/common/api/schemas/email.py b/src/eduid/webapp/common/api/schemas/email.py index 5ef0ba85c..ca5610b4b 100644 --- a/src/eduid/webapp/common/api/schemas/email.py +++ b/src/eduid/webapp/common/api/schemas/email.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from marshmallow.fields import Email __author__ = "lundberg" @@ -8,10 +11,12 @@ class LowercaseEmail(Email): Email field that serializes and deserializes to a lower case string. """ - def _serialize(self, value, attr, obj, **kwargs): - value = super()._serialize(value, attr, obj, **kwargs) - return value.lower() + def _serialize(self, value: str | bytes, attr: str | None, obj: object, **kwargs: Any) -> str | None: + _value = super()._serialize(value, attr, obj, **kwargs) + if _value is None: + return None + return _value.lower() - def _deserialize(self, value, attr, data, **kwargs): - value = super()._deserialize(value, attr, data, **kwargs) - return value.lower() + def _deserialize(self, value: str | bytes, attr: str | None, data: Mapping[str, Any] | None, **kwargs: Any) -> str: + _value: str = super()._deserialize(value, attr, data, **kwargs) + return _value.lower() diff --git a/src/eduid/webapp/common/api/schemas/models.py b/src/eduid/webapp/common/api/schemas/models.py index ca083eea8..a68b21042 100644 --- a/src/eduid/webapp/common/api/schemas/models.py +++ b/src/eduid/webapp/common/api/schemas/models.py @@ -1,6 +1,9 @@ +from collections.abc import Mapping from enum import Enum, unique from typing import Any +from flask import Request + from eduid.webapp.common.api.utils import get_flux_type __author__ = "lundberg" @@ -38,7 +41,13 @@ class FluxResponse: An action MUST NOT include properties other than type, payload, error, and meta. """ - def __init__(self, req, payload=None, error=None, meta=None): + def __init__( + self, + req: Request, + payload: Mapping[str, Any] | None = None, + error: bool | None = None, + meta: Mapping[str, Any] | None = None, + ) -> None: _suffix = "success" if error: _suffix = "fail" @@ -47,17 +56,17 @@ def __init__(self, req, payload=None, error=None, meta=None): self.meta = meta self.error = error - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__!s} ({self.to_dict()!r})>" - def __unicode__(self): + def __unicode__(self) -> str: return self.__str__() - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__!s} ({self.to_dict()!r})" def to_dict(self) -> dict[str, Any]: - rv = dict() + rv = dict[str, Any]() # A Flux Standard Action MUST have a type rv["type"] = self.flux_type # ... and MAY have payload, error, meta (and MUST NOT have anything else) @@ -75,10 +84,10 @@ def to_dict(self) -> dict[str, Any]: class FluxSuccessResponse(FluxResponse): - def __init__(self, req, payload, meta=None): + def __init__(self, req: Request, payload: Mapping[str, Any] | None, meta: Mapping[str, Any] | None = None) -> None: super().__init__(req, payload, meta=meta) class FluxFailResponse(FluxResponse): - def __init__(self, req, payload, meta=None): + def __init__(self, req: Request, payload: Mapping[str, Any] | None, meta: Mapping[str, Any] | None = None) -> None: super().__init__(req, payload, error=True, meta=meta) diff --git a/src/eduid/webapp/common/api/schemas/password.py b/src/eduid/webapp/common/api/schemas/password.py index 3faf75f49..74f1a6c02 100644 --- a/src/eduid/webapp/common/api/schemas/password.py +++ b/src/eduid/webapp/common/api/schemas/password.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow import Schema, ValidationError __author__ = "lundberg" @@ -11,13 +13,13 @@ class Meta: min_entropy: int | None = None min_score: int | None = None - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.Meta.zxcvbn_terms = kwargs.pop("zxcvbn_terms", []) self.Meta.min_entropy = kwargs.pop("min_entropy") self.Meta.min_score = kwargs.pop("min_score") super().__init__(*args, **kwargs) - def validate_password(self, password: str, **kwargs): + def validate_password(self, password: str, **kwargs: Any) -> None: """ :param password: New password diff --git a/src/eduid/webapp/common/api/schemas/validators.py b/src/eduid/webapp/common/api/schemas/validators.py index ec33f5fef..fb36b6f5a 100644 --- a/src/eduid/webapp/common/api/schemas/validators.py +++ b/src/eduid/webapp/common/api/schemas/validators.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow import ValidationError from eduid.webapp.common.api.validation import is_valid_email, is_valid_nin @@ -5,7 +7,7 @@ __author__ = "lundberg" -def validate_nin(nin, **kwargs): +def validate_nin(nin: str, **kwargs: Any) -> bool: """ :param nin: National Identity Number :type nin: string_types @@ -18,7 +20,7 @@ def validate_nin(nin, **kwargs): raise ValidationError("nin needs to be formatted as 18|19|20yymmddxxxx") -def validate_email(email, **kwargs): +def validate_email(email: str, **kwargs: Any) -> bool: """ :param email: E-mail address :type email: string_types diff --git a/src/eduid/webapp/common/api/testing.py b/src/eduid/webapp/common/api/testing.py index b9e484ba3..95b62aede 100644 --- a/src/eduid/webapp/common/api/testing.py +++ b/src/eduid/webapp/common/api/testing.py @@ -26,7 +26,7 @@ from eduid.userdb.fixtures.users import UserFixtures from eduid.userdb.logs.db import ProofingLog from eduid.userdb.proofing.state import NinProofingState -from eduid.userdb.testing import MongoTemporaryInstance +from eduid.userdb.testing import MongoTemporaryInstance, SetupConfig from eduid.userdb.userdb import UserDB from eduid.webapp.common.api.app import EduIDBaseApp from eduid.webapp.common.api.messages import AuthnStatusMsg, TranslatableMsg @@ -97,16 +97,12 @@ class EduidAPITestCase(CommonTestCase, Generic[TTestAppVar]): app: TTestAppVar browser: CSRFTestClient - def setUp( # type: ignore[override] - self, - *args: list[Any], - users: list[str] | None = None, - copy_user_to_private: bool = False, - **kwargs: dict[str, Any], - ) -> None: + def setUp(self, config: SetupConfig | None = None) -> None: + if config is None: + config = SetupConfig() # test users - if users is None: - users = ["hubba-bubba"] + if config.users is None: + config.users = ["hubba-bubba"] _users = UserFixtures() _standard_test_users = { @@ -116,14 +112,15 @@ def setUp( # type: ignore[override] } # Make a list of User object to be saved to the new temporary mongodb instance - am_users = [_standard_test_users[x] for x in users] + am_users = [_standard_test_users[x] for x in config.users] - super().setUp(am_users=am_users, *args, **kwargs) + config.am_users = am_users + super().setUp(config=config) self.user: User | None = None # Load the user from the database so that it can be saved there again in tests - _test_user = self.amdb.get_user_by_eppn(users[0]) + _test_user = self.amdb.get_user_by_eppn(config.users[0]) # Initialize some convenience variables on self based on the first user in `users' self.test_user = _test_user self.test_user_data = self.test_user.to_dict() @@ -132,8 +129,8 @@ def setUp( # type: ignore[override] # Set up Redis for shared sessions self.redis_instance = RedisTemporaryInstance.get_instance() # settings - config = deepcopy(TEST_CONFIG) - self.settings = self.update_config(config) + test_config = deepcopy(TEST_CONFIG) + self.settings: dict[str, Any] = self.update_config(test_config) self.settings["redis_config"] = RedisConfig(host="localhost", port=self.redis_instance.port) assert isinstance(self.tmp_db, MongoTemporaryInstance) # please mypy self.settings["mongo_uri"] = self.tmp_db.uri @@ -147,14 +144,14 @@ def setUp( # type: ignore[override] self.content_type_json = "application/json" self.test_domain = "test.localhost" - if copy_user_to_private: + if config.copy_user_to_private: data = self.test_user.to_dict() _private_userdb = getattr(self.app, "private_userdb") assert isinstance(_private_userdb, UserDB) logging.info(f"Copying test-user {self.test_user} to private_userdb {_private_userdb}") _private_userdb.save(_private_userdb.user_from_dict(data=data)) - def tearDown(self): + def tearDown(self) -> None: try: # Reset anything that looks like a BaseDB, for the next test class. for this in vars(self.app).values(): @@ -323,7 +320,7 @@ def set_authn_action( finish_url: str | None = None, force_mfa: bool = False, credentials_used: list[ElementKey] | None = None, - ): + ) -> None: if not finish_url: finish_url = "https://example.com/ext-return/{app_name}/{authn_id}" @@ -372,11 +369,11 @@ def add_security_key_to_user(self, eppn: str, keyhandle: str, token_type: str = return mfa_token @staticmethod - def _get_all_navet_data(): + def _get_all_navet_data() -> NavetData: return NavetData.model_validate(MessageSender.get_devel_all_navet_data()) @staticmethod - def _get_full_postal_address(): + def _get_full_postal_address() -> FullPostalAddress: return FullPostalAddress.model_validate(MessageSender.get_devel_postal_address()) def _check_must_authenticate_response( @@ -385,7 +382,7 @@ def _check_must_authenticate_response( type_: str | None, frontend_action: FrontendAction, authn_status: AuthnActionStatus, - ): + ) -> None: """Check that a call to the API failed in the authentication stage.""" meta = { "frontend_action": frontend_action.value, @@ -394,7 +391,7 @@ def _check_must_authenticate_response( payload = { "message": AuthnStatusMsg.must_authenticate.value, } - return self._check_api_response(response, status=200, type_=type_, payload=payload, meta=meta) + self._check_api_response(response, status=200, type_=type_, payload=payload, meta=meta) def _check_error_response( self, @@ -403,9 +400,9 @@ def _check_error_response( msg: TranslatableMsg | None = None, error: Mapping[str, Any] | None = None, payload: Mapping[str, Any] | None = None, - ): + ) -> None: """Check that a call to the API failed in the data validation stage.""" - return self._check_api_response(response, 200, type_=type_, message=msg, error=error, payload=payload) + self._check_api_response(response, 200, type_=type_, message=msg, error=error, payload=payload) def _check_success_response( self, @@ -413,13 +410,13 @@ def _check_success_response( type_: str | None, msg: TranslatableMsg | None = None, payload: Mapping[str, Any] | None = None, - ): + ) -> None: """ Check the message returned from an eduID webapp endpoint. """ if response.json and response.json.get("error") is True: assert False is True, f"FluxResponse has error set to True: {response.json}" - return self._check_api_response(response, 200, type_=type_, message=msg, payload=payload) + self._check_api_response(response, 200, type_=type_, message=msg, payload=payload) @staticmethod def get_response_payload(response: TestResponse) -> dict[str, Any]: @@ -445,7 +442,7 @@ def _check_api_response( payload: Mapping[str, Any] | None = None, assure_not_in_payload: Iterable[str] | None = None, meta: Mapping[str, Any] | None = None, - ): + ) -> None: """ Check data returned from an eduID webapp endpoint. @@ -475,7 +472,7 @@ def _check_api_response( :param payload: Data expected to be found in the 'payload' of the response """ - def _assure_not_in_dict(d: Mapping[str, Any], unwanted_key: str): + def _assure_not_in_dict(d: Mapping[str, Any], unwanted_key: str) -> None: assert unwanted_key not in d, f"Key {unwanted_key} should not be in payload, but it is: {payload}" v2: Mapping[str, Any] for v2 in d.values(): @@ -534,7 +531,7 @@ def _check_nin_verified_ok( proofing_state: NinProofingState, number: str | None = None, created_by: str | None = None, - ): + ) -> None: if number is None and (self.test_user is not None and self.test_user.identities.nin): number = self.test_user.identities.nin.number @@ -552,7 +549,7 @@ def _check_nin_verified_ok( assert isinstance(_log, ProofingLog) assert _log.db_count() == 1 - def _check_nin_not_verified(self, user: User, number: str | None = None, created_by: str | None = None): + def _check_nin_not_verified(self, user: User, number: str | None = None, created_by: str | None = None) -> None: if number is None and (self.test_user is not None and self.test_user.identities.nin): number = self.test_user.identities.nin.number diff --git a/src/eduid/webapp/common/api/tests/test_backdoor.py b/src/eduid/webapp/common/api/tests/test_backdoor.py index 45e6b20cf..92b35bcbe 100644 --- a/src/eduid/webapp/common/api/tests/test_backdoor.py +++ b/src/eduid/webapp/common/api/tests/test_backdoor.py @@ -5,6 +5,7 @@ from eduid.common.config.base import EduIDBaseAppConfig, EduidEnvironment, MagicCookieMixin from eduid.common.config.parsers import load_config +from eduid.userdb.testing import SetupConfig from eduid.webapp.common.api.app import EduIDBaseApp from eduid.webapp.common.api.helpers import check_magic_cookie from eduid.webapp.common.api.testing import EduidAPITestCase @@ -14,8 +15,9 @@ @test_views.route("/get-code", methods=["GET"]) -def get_code(): +def get_code() -> str: current_app.logger.info("Endpoint get_code called") + assert isinstance(current_app, BackdoorTestApp) try: if check_magic_cookie(current_app.conf): eppn = request.args.get("eppn") @@ -34,21 +36,15 @@ class BackdoorTestConfig(EduIDBaseAppConfig, MagicCookieMixin): class BackdoorTestApp(EduIDBaseApp): - def __init__(self, config: BackdoorTestConfig): + def __init__(self, config: BackdoorTestConfig) -> None: super().__init__(config) self.conf = config class BackdoorTests(EduidAPITestCase[BackdoorTestApp]): - def setUp( # type: ignore[override] - self, - *args: list[Any], - users: list[str] | None = None, - copy_user_to_private: bool = False, - **kwargs: dict[str, Any], - ) -> None: - super().setUp(*args, users=users, copy_user_to_private=copy_user_to_private, **kwargs) + def setUp(self, config: SetupConfig | None = None) -> None: + super().setUp(config=config) self.test_get_url = "/get-code?eppn=pepin-pepon" self.test_app_domain = "test.localhost" @@ -79,13 +75,13 @@ def load_app(self, config: Mapping[str, Any]) -> BackdoorTestApp: app.session_interface = SessionFactory(app.conf) return app - def test_backdoor_get_code(self): + def test_backdoor_get_code(self) -> None: """""" with self.session_cookie_and_magic_cookie_anon(self.browser) as client: response = client.get(self.test_get_url) assert response.data == b"dummy-code-for-pepin-pepon" - def test_no_backdoor_in_pro(self): + def test_no_backdoor_in_pro(self) -> None: """""" self.app.conf.environment = EduidEnvironment("production") @@ -93,19 +89,19 @@ def test_no_backdoor_in_pro(self): response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_no_backdoor_without_cookie(self): + def test_no_backdoor_without_cookie(self) -> None: """""" with self.session_cookie_anon(self.browser) as client: response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_wrong_cookie_no_backdoor(self): + def test_wrong_cookie_no_backdoor(self) -> None: """""" with self.session_cookie_and_magic_cookie_anon(self.browser, magic_cookie_value="no-magic") as client: response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_no_magic_cookie_no_backdoor(self): + def test_no_magic_cookie_no_backdoor(self) -> None: """""" self.app.conf.magic_cookie = "" @@ -113,7 +109,7 @@ def test_no_magic_cookie_no_backdoor(self): response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_no_magic_cookie_name_no_backdoor(self): + def test_no_magic_cookie_name_no_backdoor(self) -> None: """""" self.app.conf.magic_cookie_name = "" diff --git a/src/eduid/webapp/common/api/tests/test_decorators.py b/src/eduid/webapp/common/api/tests/test_decorators.py index cc9218896..339531224 100644 --- a/src/eduid/webapp/common/api/tests/test_decorators.py +++ b/src/eduid/webapp/common/api/tests/test_decorators.py @@ -20,7 +20,7 @@ class DecoratorTestConfig(EduIDBaseAppConfig): class DecoratorTestApp(EduIDBaseApp): - def __init__(self, config: DecoratorTestConfig): + def __init__(self, config: DecoratorTestConfig) -> None: super().__init__(config) self.conf = config @@ -49,7 +49,7 @@ def load_app(self, config: Mapping[str, Any]) -> DecoratorTestApp: app.session_interface = SessionFactory(app.conf) return app - def test_success_message(self): + def test_success_message(self) -> None: """Test that a simple success_message is turned into a well-formed Flux Standard Action response""" msg = success_response(message=TestsMsg.fst_test_msg) with self.app.test_request_context("/test/foo"): @@ -59,7 +59,7 @@ def test_success_message(self): "payload": {"message": "test.first_msg", "success": True}, } - def test_success_message_with_data(self): + def test_success_message_with_data(self) -> None: """Test that a success_message with data is turned into a well-formed Flux Standard Action response""" msg = success_response(payload={"working": True}, message=TestsMsg.fst_test_msg) with self.app.test_request_context("/test/foo"): @@ -69,7 +69,7 @@ def test_success_message_with_data(self): "payload": {"message": "test.first_msg", "success": True, "working": True}, } - def test_error_message(self): + def test_error_message(self) -> None: """Test that a simple success_message is turned into a well-formed Flux Standard Action response""" msg = error_response(message=TestsMsg.fst_test_msg) with self.app.test_request_context("/test/foo"): diff --git a/src/eduid/webapp/common/api/tests/test_inputs.py b/src/eduid/webapp/common/api/tests/test_inputs.py index 1185d6860..b9eb4b7c3 100644 --- a/src/eduid/webapp/common/api/tests/test_inputs.py +++ b/src/eduid/webapp/common/api/tests/test_inputs.py @@ -1,9 +1,9 @@ import logging from collections.abc import Mapping -from typing import Any +from typing import Any, NoReturn from urllib.parse import unquote -from flask import Blueprint, make_response, request +from flask import Blueprint, Response, make_response, request from marshmallow import ValidationError, fields from werkzeug.http import dump_cookie @@ -21,7 +21,7 @@ __author__ = "lundberg" -def dont_validate(value): +def dont_validate(value: str | bytes) -> NoReturn: raise ValidationError(f"Problem with {value!r}") @@ -35,7 +35,7 @@ class Meta: test_views = Blueprint("test", __name__) -def _make_response(data): +def _make_response(data: str) -> Response: html = f"{data}" response = make_response(html, 200) response.headers["Content-Type"] = "text/html; charset=utf8" @@ -43,50 +43,56 @@ def _make_response(data): @test_views.route("/test-get-param", methods=["GET"]) -def get_param_view(): +def get_param_view() -> Response: param = request.args.get("test-param") + assert param return _make_response(param) @test_views.route("/test-post-param", methods=["POST"]) -def post_param_view(): +def post_param_view() -> Response: param = request.form.get("test-param") + assert param return _make_response(param) -@test_views.route("/test-post-json", methods=["POST"]) +@test_views.route("/test-post-json", methods=["POST"]) # type: ignore[arg-type] @UnmarshalWith(NonValidatingSchema) -def post_json_view(test_data): +def post_json_view(test_data: str) -> None: """never validates""" pass @test_views.route("/test-cookie") -def cookie_view(): +def cookie_view() -> Response: cookie = request.cookies.get("test-cookie") + assert cookie return _make_response(cookie) @test_views.route("/test-empty-session") -def empty_session_view(): +def empty_session_view() -> Response: cookie = request.cookies.get("sessid") + assert cookie is not None return _make_response(cookie) @test_views.route("/test-header") -def header_view(): +def header_view() -> Response: header = request.headers.get("X-TEST") + assert header return _make_response(header) @test_views.route("/test-values", methods=["GET", "POST"]) -def values_view(): +def values_view() -> Response: param = request.values.get("test-param") + assert param return _make_response(param) class InputsTestApp(EduIDBaseApp): - def __init__(self, config: EduIDBaseAppConfig): + def __init__(self, config: EduIDBaseAppConfig) -> None: super().__init__(config) self.conf = config @@ -112,27 +118,27 @@ def load_app(self, test_config: Mapping[str, Any]) -> InputsTestApp: app.register_blueprint(test_views) return app - def test_get_param(self): + def test_get_param(self) -> None: """""" url = "/test-get-param?test-param=test-param" with self.app.test_request_context(url, method="GET"): response = self.app.dispatch_request() self.assertIn(b"test-param", response.data) - def test_get_param_script(self): + def test_get_param_script(self) -> None: """""" url = '/test-get-param?test-param=' with self.app.test_request_context(url, method="GET"): response = self.app.dispatch_request() self.assertNotIn(b"'}): response = self.app.dispatch_request() self.assertNotIn(b"') @@ -205,7 +211,7 @@ def test_cookie_script(self): response = self.app.dispatch_request() self.assertNotIn(b"' @@ -213,21 +219,21 @@ def test_header_script(self): response = self.app.dispatch_request() self.assertNotIn(b"'}): response = self.app.dispatch_request() self.assertNotIn(b"