From 9bd24244e002dc72f73a2ecb27015f7d4ccd8ef3 Mon Sep 17 00:00:00 2001 From: Lasse Yledahl Date: Fri, 27 Sep 2024 12:25:58 +0000 Subject: [PATCH 01/16] add type annotations for function arguments --- ruff.toml | 2 +- .../clients/amapi_client/amapi_client.py | 2 +- src/eduid/common/config/exceptions.py | 2 +- src/eduid/common/config/parsers/decorators.py | 6 +-- src/eduid/common/config/parsers/exceptions.py | 2 +- src/eduid/common/decorators.py | 5 +- src/eduid/common/fastapi/utils.py | 2 +- src/eduid/common/logging.py | 8 +-- src/eduid/common/models/scim_base.py | 2 +- src/eduid/common/stats/__init__.py | 11 ++-- src/eduid/graphdb/groupdb/db.py | 26 ++++++--- src/eduid/graphdb/tests/test_db.py | 4 +- src/eduid/graphdb/tests/test_groupdb.py | 2 +- src/eduid/maccapi/middleware.py | 3 +- src/eduid/queue/db/worker.py | 2 +- src/eduid/queue/decorators.py | 21 +++++--- src/eduid/queue/tests/test_mail_worker.py | 6 +-- src/eduid/queue/workers/base.py | 2 +- src/eduid/satosa/scimapi/serve_static.py | 6 ++- src/eduid/scimapi/middleware.py | 14 ++--- src/eduid/scimapi/tests/test_scimevent.py | 4 +- src/eduid/scimapi/tests/test_scimgroup.py | 5 +- src/eduid/scimapi/tests/test_sciminvite.py | 7 ++- src/eduid/scimapi/tests/test_scimuser.py | 6 +-- src/eduid/userdb/authninfo.py | 2 +- src/eduid/userdb/credentials/external.py | 2 +- src/eduid/userdb/credentials/password.py | 4 +- src/eduid/userdb/element.py | 2 +- src/eduid/userdb/locked_identity.py | 3 +- src/eduid/userdb/mail.py | 4 +- src/eduid/userdb/personal_data/db.py | 2 +- src/eduid/userdb/scimapi/groupdb.py | 4 +- src/eduid/userdb/scimapi/userdb.py | 2 +- src/eduid/userdb/support/models.py | 21 ++++---- src/eduid/userdb/testing/__init__.py | 2 +- src/eduid/userdb/tests/test_credentials.py | 2 +- src/eduid/userdb/tests/test_u2f.py | 2 +- src/eduid/userdb/tests/test_user.py | 2 +- src/eduid/userdb/tests/test_webauthn.py | 2 +- src/eduid/userdb/tou.py | 2 +- src/eduid/userdb/util.py | 6 +-- src/eduid/vccs/client/__init__.py | 16 ++++-- src/eduid/vccs/client/tests/test_client.py | 17 +++--- src/eduid/vccs/server/endpoints/add_creds.py | 4 +- src/eduid/vccs/server/hasher.py | 53 ++++++++++--------- src/eduid/vccs/server/log.py | 3 +- src/eduid/vccs/server/run.py | 4 +- src/eduid/webapp/authn/tests/test_authn.py | 3 +- src/eduid/webapp/common/api/debug.py | 13 +++-- src/eduid/webapp/common/api/exceptions.py | 9 ++-- src/eduid/webapp/common/api/middleware.py | 10 +++- src/eduid/webapp/common/api/request.py | 14 ++--- src/eduid/webapp/common/api/schemas/csrf.py | 9 ++-- src/eduid/webapp/common/api/schemas/email.py | 16 +++--- src/eduid/webapp/common/api/schemas/models.py | 17 ++++-- .../webapp/common/api/schemas/validators.py | 4 +- .../webapp/common/api/tests/test_inputs.py | 6 +-- src/eduid/webapp/common/api/validation.py | 2 +- src/eduid/webapp/common/authn/middleware.py | 6 +-- src/eduid/webapp/common/authn/testing.py | 53 +++++-------------- .../common/authn/tests/test_fido_tokens.py | 6 +-- .../common/authn/tests/test_middleware.py | 8 +-- .../webapp/common/authn/tests/test_vccs.py | 2 +- src/eduid/webapp/common/authn/utils.py | 2 +- src/eduid/webapp/common/authn/vccs.py | 2 +- .../webapp/common/proofing/saml_helpers.py | 3 +- .../webapp/common/session/eduid_session.py | 8 +-- .../webapp/common/session/redis_session.py | 8 +-- .../session/tests/test_eduid_session.py | 4 +- src/eduid/webapp/common/wsgi.py | 15 ++++++ src/eduid/webapp/email/validators.py | 4 +- src/eduid/webapp/email/verifications.py | 6 ++- src/eduid/webapp/email/views.py | 5 +- src/eduid/webapp/freja_eid/views.py | 2 +- src/eduid/webapp/idp/decorators.py | 5 +- src/eduid/webapp/idp/settings/common.py | 5 +- src/eduid/webapp/idp/tests/test_SSO.py | 22 ++++---- src/eduid/webapp/idp/tests/test_api.py | 4 +- src/eduid/webapp/idp/tests/test_idPUserDb.py | 16 ++++-- src/eduid/webapp/letter_proofing/ekopost.py | 22 +++++--- .../webapp/letter_proofing/tests/test_app.py | 2 +- .../webapp/letter_proofing/tests/test_pdf.py | 5 +- src/eduid/webapp/letter_proofing/views.py | 2 +- .../webapp/oidc_proofing/tests/test_app.py | 13 +++-- src/eduid/webapp/personal_data/views.py | 2 +- src/eduid/webapp/phone/schemas.py | 2 +- src/eduid/webapp/phone/validators.py | 10 ++-- src/eduid/webapp/phone/views.py | 2 +- .../webapp/reset_password/tests/test_app.py | 14 ++--- src/eduid/webapp/security/helpers.py | 2 +- src/eduid/webapp/security/schemas.py | 2 +- .../security/tests/test_change_password.py | 14 ++--- .../webapp/security/tests/test_webauthn.py | 10 ++-- src/eduid/webapp/signup/schemas.py | 10 ++-- src/eduid/webapp/support/app.py | 7 +-- src/eduid/webapp/svipe_id/helpers.py | 2 +- src/eduid/webapp/svipe_id/views.py | 2 +- src/eduid/workers/am/tasks.py | 3 +- src/eduid/workers/am/testing.py | 4 +- src/eduid/workers/am/tests/test_am.py | 12 +++-- src/eduid/workers/amapi/middleware.py | 4 +- .../workers/amapi/routers/utils/status.py | 2 +- .../client/mobile_lookup_client.py | 12 +++-- src/eduid/workers/lookup_mobile/decorators.py | 29 ++++++---- .../development/development_search_result.py | 9 ++-- .../development/nin_mobile_db.py | 4 +- src/eduid/workers/lookup_mobile/tasks.py | 2 +- .../lookup_mobile/test/test_decorators.py | 14 ++--- src/eduid/workers/lookup_mobile/testing.py | 5 +- src/eduid/workers/msg/decorators.py | 2 +- src/eduid/workers/msg/exceptions.py | 2 +- src/eduid/workers/msg/tasks.py | 3 +- src/eduid/workers/msg/testing.py | 3 +- src/eduid/workers/msg/tests/__init__.py | 15 ------ .../workers/msg/tests/test_decorators.py | 19 +++++-- src/eduid/workers/msg/tests/test_tasks.py | 11 ++-- 116 files changed, 496 insertions(+), 382 deletions(-) create mode 100644 src/eduid/webapp/common/wsgi.py diff --git a/ruff.toml b/ruff.toml index 3b563fd09..74a132870 100644 --- a/ruff.toml +++ b/ruff.toml @@ -3,6 +3,6 @@ line-length = 120 target-version = "py310" [lint] -select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA"] +select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN001"] ignore = ["E501"] diff --git a/src/eduid/common/clients/amapi_client/amapi_client.py b/src/eduid/common/clients/amapi_client/amapi_client.py index bcf6a1b6f..1d726debe 100644 --- a/src/eduid/common/clients/amapi_client/amapi_client.py +++ b/src/eduid/common/clients/amapi_client/amapi_client.py @@ -20,7 +20,7 @@ class AMAPIClient(GNAPClient): - def __init__(self, amapi_url: str, auth_data=GNAPClientAuthData, verify_tls: bool = True, **kwargs): + def __init__(self, amapi_url: str, auth_data: GNAPClientAuthData, verify_tls: bool = True, **kwargs): super().__init__(auth_data=auth_data, verify=verify_tls, **kwargs) self.amapi_url = amapi_url diff --git a/src/eduid/common/config/exceptions.py b/src/eduid/common/config/exceptions.py index 5c7d7d097..f815dcead 100644 --- a/src/eduid/common/config/exceptions.py +++ b/src/eduid/common/config/exceptions.py @@ -1,5 +1,5 @@ class BadConfiguration(Exception): - def __init__(self, message): + def __init__(self, message: str): Exception.__init__(self) self.value = message diff --git a/src/eduid/common/config/parsers/decorators.py b/src/eduid/common/config/parsers/decorators.py index ce686111f..39caa8fb3 100644 --- a/src/eduid/common/config/parsers/decorators.py +++ b/src/eduid/common/config/parsers/decorators.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Mapping +from collections.abc import Callable, Mapping from functools import wraps from string import Template from typing import Any @@ -11,7 +11,7 @@ from eduid.common.config.parsers.exceptions import SecretKeyException -def decrypt(f): +def decrypt(f: Callable): @wraps(f) def decrypt_decorator(*args, **kwargs): config_dict = f(*args, **kwargs) @@ -83,7 +83,7 @@ def decrypt_config(config_dict: Mapping[str, Any]) -> Mapping[str, Any]: return new_config_dict -def interpolate(f): +def interpolate(f: Callable): @wraps(f) def interpolation_decorator(*args, **kwargs): config_dict = f(*args, **kwargs) diff --git a/src/eduid/common/config/parsers/exceptions.py b/src/eduid/common/config/parsers/exceptions.py index eaecc6f93..a203d8830 100644 --- a/src/eduid/common/config/parsers/exceptions.py +++ b/src/eduid/common/config/parsers/exceptions.py @@ -2,7 +2,7 @@ class ParserException(Exception): - def __init__(self, message): + def __init__(self, message: str): Exception.__init__(self) self.value = message diff --git a/src/eduid/common/decorators.py b/src/eduid/common/decorators.py index 2c0e62e66..661613b01 100644 --- a/src/eduid/common/decorators.py +++ b/src/eduid/common/decorators.py @@ -1,10 +1,11 @@ import inspect import warnings +from collections.abc import Callable from functools import wraps # https://stackoverflow.com/questions/2536307/how-do-i-deprecate-python-functions/40301488#40301488 -def deprecated(reason): +def deprecated(reason: str | type | Callable): """ This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted @@ -20,7 +21,7 @@ def deprecated(reason): # def old_function(x, y): # pass - def decorator(func1): + def decorator(func1: Callable): if inspect.isclass(func1): fmt1 = "Call to deprecated class {name} ({reason})." else: diff --git a/src/eduid/common/fastapi/utils.py b/src/eduid/common/fastapi/utils.py index 820878ab8..6b5695934 100644 --- a/src/eduid/common/fastapi/utils.py +++ b/src/eduid/common/fastapi/utils.py @@ -48,7 +48,7 @@ def reset_failure_info(req: ContextRequest, key: str) -> None: req.app.context.logger.info(f"Check {key} back to normal. Resetting info {info}") -def check_restart(key, restart: int, terminate: int) -> bool: +def check_restart(key: str, restart: int, terminate: int) -> bool: res = False # default to no restart info = FAILURE_INFO.get(key) if not info: diff --git a/src/eduid/common/logging.py b/src/eduid/common/logging.py index 84948ac5b..d1a5582e5 100644 --- a/src/eduid/common/logging.py +++ b/src/eduid/common/logging.py @@ -26,11 +26,11 @@ # Default to RFC3339/ISO 8601 with tz class EduidFormatter(logging.Formatter): - def __init__(self, relative_time: bool = False, fmt=None): + def __init__(self, relative_time: bool = False, fmt: str | None = None): super().__init__(fmt=fmt, style="{") self._relative_time = relative_time - def formatTime(self, record: logging.LogRecord, datefmt=None) -> str: + def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str: if self._relative_time: # Relative time makes much more sense than absolute time when running tests for example _seconds = record.relativeCreated / 1000 @@ -52,7 +52,7 @@ def formatTime(self, record: logging.LogRecord, datefmt=None) -> str: class AppFilter(logging.Filter): """Add `system_hostname`, `hostname` and `app_name` to records being logged.""" - def __init__(self, app_name): + def __init__(self, app_name: str): super().__init__() self.app_name = app_name # TODO: I guess it could be argued that these should be put in the LocalContext and not evaluated at runtime. @@ -139,7 +139,7 @@ def filter(self, record: logging.LogRecord) -> bool: def merge_config(base_config: dict[str, Any], new_config: dict[str, Any]) -> dict[str, Any]: """Recursively merge two dictConfig dicts.""" - def merge(node, key, value): + def merge(node: dict[str, Any], key: str, value: Any): if isinstance(value, dict): for item in value: if key in node: diff --git a/src/eduid/common/models/scim_base.py b/src/eduid/common/models/scim_base.py index 243dcc6be..4ec826b6f 100644 --- a/src/eduid/common/models/scim_base.py +++ b/src/eduid/common/models/scim_base.py @@ -112,7 +112,7 @@ def is_group(self): return self.ref and "/Groups/" in self.ref @classmethod - def from_mapping(cls, data): + def from_mapping(cls, data: Any): return cls.model_validate(data) diff --git a/src/eduid/common/stats/__init__.py b/src/eduid/common/stats/__init__.py index 25e0cfa4e..122501cd0 100644 --- a/src/eduid/common/stats/__init__.py +++ b/src/eduid/common/stats/__init__.py @@ -14,6 +14,7 @@ __author__ = "ft" from abc import ABC, abstractmethod +from logging import Logger from eduid.common.config.base import StatsConfigMixin @@ -23,7 +24,7 @@ class AppStats(ABC): def count(self, name: str, value: int = 1) -> None: pass - def gauge(self, name: str, value: int, rate=1, delta=False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False): pass @@ -35,7 +36,7 @@ class NoOpStats(AppStats): configured allows us to not check if current_app.stats is set everywhere. """ - def __init__(self, logger=None, prefix=None): + def __init__(self, logger: Logger | None = None, prefix: str | None = None): self.logger = logger self.prefix = prefix @@ -45,7 +46,7 @@ def count(self, name: str, value: int = 1) -> None: name = f"{self.prefix!s}.{name!s}" self.logger.info(f"No-op stats count: {name!r} {value!r}") - def gauge(self, name: str, value: int, rate=1, delta=False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False): if self.logger: if self.prefix: name = f"{self.prefix!s}.{name!s}" @@ -53,7 +54,7 @@ def gauge(self, name: str, value: int, rate=1, delta=False): class Statsd(AppStats): - def __init__(self, host, port, prefix=None): + def __init__(self, host: str, port: int, prefix: str | None = None): import statsd self.client = statsd.StatsClient(host, port, prefix=prefix) @@ -64,7 +65,7 @@ def count(self, name: str, value: int = 1) -> None: # for .count self.client.incr(f"{name}.count", count=value) - def gauge(self, name: str, value: int, rate=1, delta=False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False): self.client.gauge(f"{name}.gauge", value=value, rate=rate, delta=delta) diff --git a/src/eduid/graphdb/groupdb/db.py b/src/eduid/graphdb/groupdb/db.py index c38cf69a3..8dd4b5ce7 100644 --- a/src/eduid/graphdb/groupdb/db.py +++ b/src/eduid/graphdb/groupdb/db.py @@ -109,16 +109,28 @@ def _add_or_update_users_and_groups( for user_member in group.member_users: res = self._add_user_to_group(tx, group=group, member=user_member, role=Role.MEMBER) - members.add(User.from_mapping(res.data())) + if res: + members.add(User.from_mapping(res.data())) + else: + logger.info(f"User {user_member.identifier} not added to group {group.identifier}.") for group_member in group.member_groups: res = self._add_group_to_group(tx, group=group, member=group_member, role=Role.MEMBER) - members.add(self._load_group(res.data())) + if res: + members.add(self._load_group(res.data())) + else: + logger.info(f"Group {group_member.identifier} not added to group {group.identifier}.") for user_owner in group.owner_users: res = self._add_user_to_group(tx, group=group, member=user_owner, role=Role.OWNER) - owners.add(User.from_mapping(res.data())) + if res: + owners.add(User.from_mapping(res.data())) + else: + logger.info(f"User {user_owner.identifier} not added to group {group.identifier}.") for group_owner in group.owner_groups: res = self._add_group_to_group(tx, group=group, member=group_owner, role=Role.OWNER) - owners.add(self._load_group(res.data())) + if res: + owners.add(self._load_group(res.data())) + else: + logger.info(f"Group {group_owner.identifier} not added to group {group.identifier}.") return members, owners def _remove_missing_users_and_groups(self, tx: Transaction, group: Group, role: Role) -> None: @@ -161,7 +173,7 @@ def _remove_user_from_group(self, tx: Transaction, group: Group, user_identifier """ tx.run(q, scope=self.scope, identifier=group.identifier, user_identifier=user_identifier) - def _add_group_to_group(self, tx, group: Group, member: Group, role: Role) -> Record: + def _add_group_to_group(self, tx: Transaction, group: Group, member: Group, role: Role) -> Record | None: q = f""" MATCH (g:Group {{scope: $scope, identifier: $group_identifier}}) MERGE (m:Group {{scope: $scope, identifier: $identifier}}) @@ -187,7 +199,7 @@ def _add_group_to_group(self, tx, group: Group, member: Group, role: Role) -> Re display_name=member.display_name, ).single() - def _add_user_to_group(self, tx, group: Group, member: User, role: Role) -> Record: + def _add_user_to_group(self, tx: Transaction, group: Group, member: User, role: Role) -> Record | None: q = f""" MATCH (g:Group {{scope: $scope, identifier: $group_identifier}}) MERGE (m:User {{identifier: $identifier}}) @@ -276,7 +288,7 @@ def remove_group(self, identifier: str) -> None: with self.db.driver.session(default_access_mode=WRITE_ACCESS) as session: session.run(q, scope=self.scope, identifier=identifier) - def get_groups_by_property(self, key: str, value: str, skip=0, limit=100): + def get_groups_by_property(self, key: str, value: str, skip: int = 0, limit: int = 100): res: list[Group] = [] q = f""" MATCH (g: Group {{scope: $scope}}) diff --git a/src/eduid/graphdb/tests/test_db.py b/src/eduid/graphdb/tests/test_db.py index 11f4669b0..b5f4a44a2 100644 --- a/src/eduid/graphdb/tests/test_db.py +++ b/src/eduid/graphdb/tests/test_db.py @@ -1,3 +1,5 @@ +from typing import Any + from neo4j import basic_auth from eduid.graphdb.db import BaseGraphDB @@ -17,7 +19,7 @@ def test_create_db(self): class TestBaseGraphDB(Neo4jTestCase): class TestDB(BaseGraphDB): - def __init__(self, db_uri, config=None): + def __init__(self, db_uri: str, config: dict[str, Any] | None = None): super().__init__(db_uri, config=config) def db_setup(self): diff --git a/src/eduid/graphdb/tests/test_groupdb.py b/src/eduid/graphdb/tests/test_groupdb.py index 1fc713d21..50602fef0 100644 --- a/src/eduid/graphdb/tests/test_groupdb.py +++ b/src/eduid/graphdb/tests/test_groupdb.py @@ -29,7 +29,7 @@ def setUp(self) -> None: self.user2: dict[str, str] = {"identifier": "user2", "display_name": "Namn Namnsson"} @staticmethod - def _assert_group(expected: Group, testing: Group, modified=False): + def _assert_group(expected: Group, testing: Group, modified: bool = False): assert expected.identifier == testing.identifier assert expected.display_name == testing.display_name assert testing.created_ts is not None diff --git a/src/eduid/maccapi/middleware.py b/src/eduid/maccapi/middleware.py index 4595b38c7..cb896897e 100644 --- a/src/eduid/maccapi/middleware.py +++ b/src/eduid/maccapi/middleware.py @@ -7,6 +7,7 @@ from jwcrypto.common import JWException from pydantic import ValidationError from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.types import ASGIApp from eduid.common.fastapi.context_request import ContextRequestMixin from eduid.common.models.bearer_token import AuthnBearerToken, RequestedAccessDenied @@ -22,7 +23,7 @@ def return_error_response(status_code: int, detail: str) -> JSONResponse: class AuthenticationMiddleware(BaseHTTPMiddleware, ContextRequestMixin): - def __init__(self, app, context: Context): + def __init__(self, app: ASGIApp, context: Context): super().__init__(app) self.context = context diff --git a/src/eduid/queue/db/worker.py b/src/eduid/queue/db/worker.py index a4e7357e0..8fd083e7e 100644 --- a/src/eduid/queue/db/worker.py +++ b/src/eduid/queue/db/worker.py @@ -40,7 +40,7 @@ def parse_queue_item(self, doc: Mapping, parse_payload: bool = True): return item return replace(item, payload=self._load_payload(item)) - async def grab_item(self, item_id: str | ObjectId, worker_name: str, regrab=False) -> QueueItem | None: + async def grab_item(self, item_id: str | ObjectId, worker_name: str, regrab: bool = False) -> QueueItem | None: """ :param item_id: document id :param worker_name: current workers name diff --git a/src/eduid/queue/decorators.py b/src/eduid/queue/decorators.py index a2a8eafa6..b9b84ba8d 100644 --- a/src/eduid/queue/decorators.py +++ b/src/eduid/queue/decorators.py @@ -1,22 +1,27 @@ +from collections.abc import Callable from inspect import isclass +from typing import Any + +from pymongo.synchronous.collection import Collection from eduid.userdb.db import MongoDB # TODO: Refactor but keep transaction audit document structure +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.util import utc_now class TransactionAudit: enabled = False - def __init__(self, db_uri, db_name="eduid_queue", collection_name="transaction_audit"): - self._conn = None - self.db_uri = db_uri - self.db_name = db_name - self.collection_name = collection_name - self.collection = None + def __init__(self, db_uri: str, db_name: str = "eduid_queue", collection_name: str = "transaction_audit"): + self._conn: MongoDB | None = None + self.db_uri: str = db_uri + self.db_name: str = db_name + self.collection_name: str = collection_name + self.collection: Collection[TUserDbDocument] | None = None - def __call__(self, f): + def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: if not self.enabled: return f @@ -47,7 +52,7 @@ def disable(cls): cls.enabled = False @staticmethod - def _filter(func, data, *args, **kwargs): + def _filter(func: str, data: Any, *args, **kwargs): if data is False: return data if func == "_get_navet_data": diff --git a/src/eduid/queue/tests/test_mail_worker.py b/src/eduid/queue/tests/test_mail_worker.py index 023d82a87..635751f7c 100644 --- a/src/eduid/queue/tests/test_mail_worker.py +++ b/src/eduid/queue/tests/test_mail_worker.py @@ -3,7 +3,7 @@ import os from datetime import timedelta from os import environ -from unittest.mock import patch +from unittest.mock import MagicMock, patch from aiosmtplib import SMTPResponse @@ -82,7 +82,7 @@ async def test_eduid_signup_mail_from_stream(self): await self._assert_item_gets_processed(queue_item) @patch("aiosmtplib.SMTP.sendmail") - async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail): + async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail: MagicMock): """ Test that saved queue items are handled by the handle_new_item method """ @@ -99,7 +99,7 @@ async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_send await self._assert_item_gets_processed(queue_item) @patch("aiosmtplib.SMTP.sendmail") - async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail): + async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail: MagicMock): """ Test that saved queue items are handled by the handle_new_item method """ diff --git a/src/eduid/queue/workers/base.py b/src/eduid/queue/workers/base.py index 5b50883f7..93415f729 100644 --- a/src/eduid/queue/workers/base.py +++ b/src/eduid/queue/workers/base.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -def cancel_task(signame, task): +def cancel_task(signame: str, task: Task): logger.info(f"got signal {signame}: exit") task.cancel() diff --git a/src/eduid/satosa/scimapi/serve_static.py b/src/eduid/satosa/scimapi/serve_static.py index 6a5e2bbec..a2bf79a9a 100644 --- a/src/eduid/satosa/scimapi/serve_static.py +++ b/src/eduid/satosa/scimapi/serve_static.py @@ -5,8 +5,10 @@ import logging import mimetypes +from satosa.context import Context from satosa.micro_services.base import RequestMicroService from satosa.response import Response +from satosa.satosa_config import SATOSAConfig logger = logging.getLogger("satosa") @@ -27,7 +29,7 @@ class ServeStatic(RequestMicroService): logprefix = "SERVE_STATIC_SERVICE:" - def __init__(self, config, *args, **kwargs): + def __init__(self, config: SATOSAConfig, *args, **kwargs): """ :type config: satosa.satosa_config.SATOSAConfig :param config: The SATOSA proxy config @@ -43,7 +45,7 @@ def register_endpoints(self): url_map.append([f"^{endpoint}/", self._handle]) return url_map - def _handle(self, context): + def _handle(self, context: Context): path = context._path endpoint = path.split("/")[0] target = path[len(endpoint) + 1 :] diff --git a/src/eduid/scimapi/middleware.py b/src/eduid/scimapi/middleware.py index b87277abf..05139729f 100644 --- a/src/eduid/scimapi/middleware.py +++ b/src/eduid/scimapi/middleware.py @@ -7,8 +7,8 @@ from jwcrypto.common import JWException from pydantic import ValidationError from starlette.datastructures import URL -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.types import Message +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.types import ASGIApp, Message from eduid.common.config.base import DataOwnerName from eduid.common.fastapi.context_request import ContextRequestMixin @@ -43,16 +43,16 @@ async def get_body(request: Request) -> bytes: class BaseMiddleware(BaseHTTPMiddleware, ContextRequestMixin): - def __init__(self, app, context: Context): + def __init__(self, app: ASGIApp, context: Context): super().__init__(app) self.context = context - async def dispatch(self, req: Request, call_next) -> Response: + async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(req) class ScimMiddleware(BaseMiddleware): - async def dispatch(self, req: Request, call_next) -> Response: + async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: req = self.make_context_request(request=req, context_class=ScimApiContext) self.context.logger.debug(f"process_request: {req.method} {req.url.path}") resp = await call_next(req) @@ -62,7 +62,7 @@ async def dispatch(self, req: Request, call_next) -> Response: class AuthenticationMiddleware(BaseMiddleware): - def __init__(self, app, context: Context): + def __init__(self, app: ASGIApp, context: Context): super().__init__(app, context) self.no_authn_urls = self.context.config.no_authn_urls self.context.logger.debug(f"No auth allow urls: {self.no_authn_urls}") @@ -78,7 +78,7 @@ def _is_no_auth_path(self, url: URL) -> bool: return True return False - async def dispatch(self, req: Request, call_next) -> Response: + async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: req = self.make_context_request(request=req, context_class=ScimApiContext) if self._is_no_auth_path(req.url): diff --git a/src/eduid/scimapi/tests/test_scimevent.py b/src/eduid/scimapi/tests/test_scimevent.py index 27fcbb12b..931ddc9dc 100644 --- a/src/eduid/scimapi/tests/test_scimevent.py +++ b/src/eduid/scimapi/tests/test_scimevent.py @@ -49,11 +49,11 @@ def _fetch_event(self, event_id: UUID) -> EventApiResult: parsed_response = EventResponse.parse_raw(response.text) return EventApiResult(event=parsed_response.nutid_event_v1, response=response, parsed_response=parsed_response) - def _assertEventUpdateSuccess(self, req: Mapping, response, event: ScimApiEvent): + def _assertEventUpdateSuccess(self, req: Mapping, response: Response, event: ScimApiEvent): """Function to validate successful responses to SCIM calls that update a event according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json()}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json()}") expected_schemas = req.get("schemas", [SCIMSchema.NUTID_EVENT_CORE_V1.value]) if ( diff --git a/src/eduid/scimapi/tests/test_scimgroup.py b/src/eduid/scimapi/tests/test_scimgroup.py index d098d303d..221e86b2c 100644 --- a/src/eduid/scimapi/tests/test_scimgroup.py +++ b/src/eduid/scimapi/tests/test_scimgroup.py @@ -7,6 +7,7 @@ from uuid import UUID, uuid4 from bson import ObjectId +from httpx import Response from eduid.common.config.base import DataOwnerName from eduid.common.models.scim_base import Meta, SCIMResourceType, SCIMSchema, WeakVersion @@ -145,10 +146,10 @@ def _perform_search( return resources - def _assertGroupUpdateSuccess(self, req: Mapping, response, group: ScimApiGroup): + def _assertGroupUpdateSuccess(self, req: Mapping, response: Response, group: ScimApiGroup): """Function to validate successful responses to SCIM calls that update a group according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json}") expected_schemas = req.get("schemas", [SCIMSchema.CORE_20_GROUP.value]) if ( diff --git a/src/eduid/scimapi/tests/test_sciminvite.py b/src/eduid/scimapi/tests/test_sciminvite.py index d049e7d41..71978c0c9 100644 --- a/src/eduid/scimapi/tests/test_sciminvite.py +++ b/src/eduid/scimapi/tests/test_sciminvite.py @@ -8,6 +8,7 @@ from typing import Any from bson import ObjectId +from httpx import Response from eduid.common.misc.timeutil import utc_now from eduid.common.models.scim_base import Email, Meta, Name, PhoneNumber, SCIMResourceType, SCIMSchema @@ -239,10 +240,12 @@ def add_invite(self, data: dict[str, Any] | None = None, update: bool = False) - self.signup_invitedb.save(signup_invite, is_in_database=False) return db_invite - def _assertUpdateSuccess(self, req: Mapping, response, invite: ScimApiInvite, signup_invite: SignupInvite): + def _assertUpdateSuccess( + self, req: Mapping, response: Response, invite: ScimApiInvite, signup_invite: SignupInvite + ): """Function to validate successful responses to SCIM calls that update an invite according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json}") expected_schemas = req.get("schemas", [SCIMSchema.NUTID_INVITE_V1.value, SCIMSchema.NUTID_USER_V1.value]) diff --git a/src/eduid/scimapi/tests/test_scimuser.py b/src/eduid/scimapi/tests/test_scimuser.py index aad4f48c2..17ac5d915 100644 --- a/src/eduid/scimapi/tests/test_scimuser.py +++ b/src/eduid/scimapi/tests/test_scimuser.py @@ -189,11 +189,11 @@ def setUp(self) -> None: attributes={"displayName": "Test User 2"}, data={"another_test_key": "another_test_value"} ) - def _assertUserUpdateSuccess(self, req: Mapping, response, user: ScimApiUser): + def _assertUserUpdateSuccess(self, req: Mapping, response: Response, user: ScimApiUser): """Function to validate successful responses to SCIM calls that update a user according to a request.""" if response.json().get("schemas") == [SCIMSchema.ERROR.value]: - self.fail(f"Got SCIM error parsed_response ({response.status}):\n{response.json}") + self.fail(f"Got SCIM error parsed_response ({response.status_code}):\n{response.json}") expected_schemas = req.get("schemas", [SCIMSchema.CORE_20_USER.value]) if SCIMSchema.NUTID_USER_V1.value in response.json() and SCIMSchema.NUTID_USER_V1.value not in expected_schemas: @@ -226,7 +226,7 @@ def _assertUserUpdateSuccess(self, req: Mapping, response, user: ScimApiUser): resp_nutid, "Unexpected NUTID user data in parsed_response", ) - elif SCIMSchema.NUTID_USER_V1.value in response.json: + elif SCIMSchema.NUTID_USER_V1.value in response.json(): self.fail(f"Unexpected {SCIMSchema.NUTID_USER_V1.value} in the parsed_response") def _create_user(self, req: dict[str, Any], expect_success: bool = True) -> UserApiResult: diff --git a/src/eduid/userdb/authninfo.py b/src/eduid/userdb/authninfo.py index 1b3f92cb3..04d043808 100644 --- a/src/eduid/userdb/authninfo.py +++ b/src/eduid/userdb/authninfo.py @@ -33,7 +33,7 @@ class AuthnInfoDB(BaseDB): TODO: We already have a database class to access this collection, in the IdP. Consolidate the two. """ - def __init__(self, db_uri, db_name="eduid_idp_authninfo", collection="authn_info"): + def __init__(self, db_uri: str, db_name: str = "eduid_idp_authninfo", collection: str = "authn_info"): super().__init__(db_uri, db_name, collection) def get_authn_info(self, user: User) -> Mapping[ElementKey, AuthnInfoElement]: diff --git a/src/eduid/userdb/credentials/external.py b/src/eduid/userdb/credentials/external.py index 044c3bce0..60371c4ca 100644 --- a/src/eduid/userdb/credentials/external.py +++ b/src/eduid/userdb/credentials/external.py @@ -26,7 +26,7 @@ class ExternalCredential(Credential): @field_validator("credential_id", mode="before") @classmethod - def credential_id_objectid(cls, v): + def credential_id_objectid(cls, v: Any): """Turn ObjectId into string""" if isinstance(v, ObjectId): v = str(v) diff --git a/src/eduid/userdb/credentials/password.py b/src/eduid/userdb/credentials/password.py index 291bda4b4..af69481ef 100644 --- a/src/eduid/userdb/credentials/password.py +++ b/src/eduid/userdb/credentials/password.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + from bson import ObjectId from pydantic import Field, field_validator @@ -17,7 +19,7 @@ class Password(Credential): @field_validator("credential_id", mode="before") @classmethod - def credential_id_objectid(cls, v): + def credential_id_objectid(cls, v: Any): """Turn ObjectId into string""" if isinstance(v, ObjectId): v = str(v) diff --git a/src/eduid/userdb/element.py b/src/eduid/userdb/element.py index 758de0024..eb7e5136f 100644 --- a/src/eduid/userdb/element.py +++ b/src/eduid/userdb/element.py @@ -312,7 +312,7 @@ def _validate_elements(cls, values: list[ListElement]) -> list[ListElement]: return values @classmethod - def from_list_of_dicts(cls, items): + def from_list_of_dicts(cls, items: list[dict[str, Any]]): # must be implemented by subclass to get correct type information raise NotImplementedError() diff --git a/src/eduid/userdb/locked_identity.py b/src/eduid/userdb/locked_identity.py index dad1b1293..31e648a8a 100644 --- a/src/eduid/userdb/locked_identity.py +++ b/src/eduid/userdb/locked_identity.py @@ -4,6 +4,7 @@ from pydantic import field_validator +from eduid.userdb.element import ElementKey from eduid.userdb.exceptions import EduIDUserDBError from eduid.userdb.identity import IdentityElement, IdentityList @@ -34,7 +35,7 @@ def replace(self, element: IdentityElement) -> None: self.add(element=element) return None - def remove(self, key): + def remove(self, key: ElementKey): """ Override remove method as an element should be set once, remove never. """ diff --git a/src/eduid/userdb/mail.py b/src/eduid/userdb/mail.py index c35316c40..66b3fa618 100644 --- a/src/eduid/userdb/mail.py +++ b/src/eduid/userdb/mail.py @@ -14,7 +14,7 @@ class MailAddress(PrimaryElement): @field_validator("email", mode="before") @classmethod - def validate_email(cls, v): + def validate_email(cls, v: Any): if not isinstance(v, str): raise ValueError("must be a string") return v.lower() @@ -53,7 +53,7 @@ def from_list_of_dicts(cls: type[MailAddressList], items: list[dict[str, Any]]) return cls(elements=[MailAddress.from_dict(this) for this in items]) -def address_from_dict(data): +def address_from_dict(data: dict[str, Any]): """ Create a MailAddress instance from a dict. diff --git a/src/eduid/userdb/personal_data/db.py b/src/eduid/userdb/personal_data/db.py index 7ea3270db..d39582c68 100644 --- a/src/eduid/userdb/personal_data/db.py +++ b/src/eduid/userdb/personal_data/db.py @@ -10,7 +10,7 @@ class PersonalDataUserDB(UserDB[PersonalDataUser]): - def __init__(self, db_uri, db_name="eduid_personal_data", collection="profiles"): + def __init__(self, db_uri: str, db_name: str = "eduid_personal_data", collection: str = "profiles"): super().__init__(db_uri, db_name, collection=collection) @classmethod diff --git a/src/eduid/userdb/scimapi/groupdb.py b/src/eduid/userdb/scimapi/groupdb.py index 3b65792e9..30a6f968a 100644 --- a/src/eduid/userdb/scimapi/groupdb.py +++ b/src/eduid/userdb/scimapi/groupdb.py @@ -268,7 +268,9 @@ def get_group_by_display_name(self, display_name: str) -> ScimApiGroup | None: return group return None - def get_groups_by_property(self, key: str, value: str | int, skip=0, limit=100) -> tuple[list[ScimApiGroup], int]: + def get_groups_by_property( + self, key: str, value: str | int, skip: int = 0, limit: int = 100 + ) -> tuple[list[ScimApiGroup], int]: docs, count = self._get_documents_and_count_by_filter({key: value}, skip=skip, limit=limit) if not docs: return [], 0 diff --git a/src/eduid/userdb/scimapi/userdb.py b/src/eduid/userdb/scimapi/userdb.py index 396904a85..c41103246 100644 --- a/src/eduid/userdb/scimapi/userdb.py +++ b/src/eduid/userdb/scimapi/userdb.py @@ -72,7 +72,7 @@ def from_dict(cls: type[ScimApiUser], data: Mapping[str, Any]) -> ScimApiUser: class ScimApiUserDB(ScimApiBaseDB): - def __init__(self, db_uri: str, collection: str, db_name="eduid_scimapi", setup_indexes: bool = True): + def __init__(self, db_uri: str, collection: str, db_name: str = "eduid_scimapi", setup_indexes: bool = True): super().__init__(db_uri, db_name, collection=collection) if setup_indexes: # Create an index so that scim_id and external_id is unique per data owner diff --git a/src/eduid/userdb/support/models.py b/src/eduid/userdb/support/models.py index a88f585d7..7861361fb 100644 --- a/src/eduid/userdb/support/models.py +++ b/src/eduid/userdb/support/models.py @@ -1,4 +1,7 @@ from copy import deepcopy +from typing import Any + +from eduid.userdb.db.base import TUserDbDocument __author__ = "lundberg" @@ -8,7 +11,7 @@ class GenericFilterDict(dict): add_keys: list[str] | None = None remove_keys: list[str] | None = None - def __init__(self, data): + def __init__(self, data: Any | None): """ Create a filtered dict with white- or blacklisting of keys @@ -35,7 +38,7 @@ def __init__(self, data): class SupportUserFilter(GenericFilterDict): remove_keys = ["_id", "letter_proofing_data"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument): _data = deepcopy(data) super().__init__(_data) @@ -47,7 +50,7 @@ def __init__(self, data): class SupportSignupUserFilter(GenericFilterDict): remove_keys = ["_id", "letter_proofing_data"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument): _data = deepcopy(data) super().__init__(_data) @@ -89,7 +92,7 @@ class ToU(GenericFilterDict): class UserAuthnInfo(GenericFilterDict): add_keys = ["success_ts", "fail_count", "success_count"] - def __init__(self, data): + def __init__(self, data: dict[str, Any]): _data = deepcopy(data) # Remove months with 0 failures or successes for attrib in ["fail_count", "success_count"]: @@ -110,7 +113,7 @@ class UserActions(GenericFilterDict): class ProofingLogEntry(GenericFilterDict): add_keys = ["verified_data", "created_ts", "proofing_method", "proofing_version", "created_by", "vetting_by"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument): _data = deepcopy(data) # Rename the verified data key to verified_data verified_data_names = ["nin", "mail_address", "phone_number", "orcid"] @@ -129,7 +132,7 @@ class Nin(GenericFilterDict): class ProofingLetter(GenericFilterDict): add_keys = ["sent_ts", "is_sent", "address"] - def __init__(self, data): + def __init__(self, data: dict[str, Any]): _data = deepcopy(data) super().__init__(_data) self["nin"] = self.Nin(self["nin"]) @@ -142,7 +145,7 @@ class UserOidcProofing(GenericFilterDict): class Nin(GenericFilterDict): add_keys = ["created_ts", "number"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument): _data = deepcopy(data) super().__init__(_data) self["nin"] = self.Nin(self["nin"]) @@ -154,7 +157,7 @@ class UserEmailProofing(GenericFilterDict): class Verification(GenericFilterDict): add_keys = ["created_ts", "email"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument): _data = deepcopy(data) super().__init__(_data) self["verification"] = self.Verification(self["verification"]) @@ -166,7 +169,7 @@ class UserPhoneProofing(GenericFilterDict): class Verification(GenericFilterDict): add_keys = ["created_ts", "number"] - def __init__(self, data): + def __init__(self, data: TUserDbDocument): _data = deepcopy(data) super().__init__(_data) self["verification"] = self.Verification(self["verification"]) diff --git a/src/eduid/userdb/testing/__init__.py b/src/eduid/userdb/testing/__init__.py index 7627d0b24..c5689022d 100644 --- a/src/eduid/userdb/testing/__init__.py +++ b/src/eduid/userdb/testing/__init__.py @@ -82,7 +82,7 @@ class MongoTestCase(unittest.TestCase): A test can access the port using the attribute `port` """ - def setUp(self, *args: list[Any], am_users: list[User] | None = None, **kwargs: dict[str, Any]): + def setUp(self, am_users: list[User] | None = None): """ Test case initialization. :return: diff --git a/src/eduid/userdb/tests/test_credentials.py b/src/eduid/userdb/tests/test_credentials.py index 35c01d8b2..28bb015e0 100644 --- a/src/eduid/userdb/tests/test_credentials.py +++ b/src/eduid/userdb/tests/test_credentials.py @@ -46,7 +46,7 @@ } -def _keyid(key): +def _keyid(key: dict[str, str]): return "sha256:" + sha256(key["keyhandle"].encode("utf-8") + key["public_key"].encode("utf-8")).hexdigest() diff --git a/src/eduid/userdb/tests/test_u2f.py b/src/eduid/userdb/tests/test_u2f.py index 084698888..ffad00c48 100644 --- a/src/eduid/userdb/tests/test_u2f.py +++ b/src/eduid/userdb/tests/test_u2f.py @@ -38,7 +38,7 @@ } -def _keyid(key): +def _keyid(key: dict[str, str]): return "sha256:" + sha256(key["keyhandle"].encode("utf-8") + key["public_key"].encode("utf-8")).hexdigest() diff --git a/src/eduid/userdb/tests/test_user.py b/src/eduid/userdb/tests/test_user.py index 975a9b36d..e01ac391d 100644 --- a/src/eduid/userdb/tests/test_user.py +++ b/src/eduid/userdb/tests/test_user.py @@ -22,7 +22,7 @@ __author__ = "ft" -def _keyid(kh): +def _keyid(kh: str): return "sha256:" + sha256(kh.encode("utf-8")).hexdigest() diff --git a/src/eduid/userdb/tests/test_webauthn.py b/src/eduid/userdb/tests/test_webauthn.py index dfb41a39a..3b790be4e 100644 --- a/src/eduid/userdb/tests/test_webauthn.py +++ b/src/eduid/userdb/tests/test_webauthn.py @@ -23,7 +23,7 @@ } -def _keyid(key): +def _keyid(key: dict[str, str]): return "sha256:" + sha256(key["keyhandle"].encode("utf-8") + key["credential_data"].encode("utf-8")).hexdigest() diff --git a/src/eduid/userdb/tou.py b/src/eduid/userdb/tou.py index fd83bd4c9..e3879d48d 100644 --- a/src/eduid/userdb/tou.py +++ b/src/eduid/userdb/tou.py @@ -21,7 +21,7 @@ class ToUEvent(Event): @field_validator("version") @classmethod - def _validate_tou_version(cls, v): + def _validate_tou_version(cls, v: Any): if not v: raise ValueError("ToU must have a version") if not isinstance(v, str): diff --git a/src/eduid/userdb/util.py b/src/eduid/userdb/util.py index df29d27bf..a319a2983 100644 --- a/src/eduid/userdb/util.py +++ b/src/eduid/userdb/util.py @@ -14,13 +14,13 @@ class UTC(datetime.tzinfo): """UTC""" - def utcoffset(self, dt): + def utcoffset(self, dt: datetime.datetime | None): return datetime.timedelta(0) - def tzname(self, dt): + def tzname(self, dt: datetime.datetime | None): return "UTC" - def dst(self, dt): + def dst(self, dt: datetime.datetime | None): return datetime.timedelta(0) diff --git a/src/eduid/vccs/client/__init__.py b/src/eduid/vccs/client/__init__.py index e61aa767a..c249f88bc 100644 --- a/src/eduid/vccs/client/__init__.py +++ b/src/eduid/vccs/client/__init__.py @@ -174,7 +174,7 @@ def _get_random_bytes(self, num_bytes: int) -> bytes: """ return os.urandom(num_bytes) - def to_dict(self, _action): + def to_dict(self, _action: str) -> dict[str, Any]: """ Return factor as dictionary, transmittable to authentiation backends. :param _action: 'auth', 'add_creds' or 'revoke_creds' @@ -193,7 +193,15 @@ class VCCSOathFactor(VCCSFactor): """ def __init__( - self, oath_type, credential_id, user_code=None, nonce=None, aead=None, key_handle=None, digits=6, oath_counter=0 + self, + oath_type: str, + credential_id: int, + user_code: int | None = None, + nonce: str | None = None, + aead: str | None = None, + key_handle: int | None = None, + digits: int = 6, + oath_counter: int = 0, ): """ :param oath_type: 'oath-totp' or 'oath-hotp' (time based or event based OATH) @@ -362,7 +370,7 @@ def revoke_credentials(self, user_id: str, factors: Sequence[VCCSRevokeFactor]) raise TypeError(f"Operation success value type error : {success!r}") return success is True - def _execute(self, data, response_label: str): + def _execute(self, data: str, response_label: str): """ Make a HTTP POST request to the authentication backend, and parse the result. @@ -391,7 +399,7 @@ def _execute(self, data, response_label: str): raise AssertionError(f"Received response of unknown version {resp_ver!r}") return resp[response_label] - def _execute_request_response(self, service: str, values): + def _execute_request_response(self, service: str, values: dict[str, Any]): """ The part of _execute that has actual side effects. In a separate function to make everything else easily testable. diff --git a/src/eduid/vccs/client/tests/test_client.py b/src/eduid/vccs/client/tests/test_client.py index 8cd7138ec..d85154618 100644 --- a/src/eduid/vccs/client/tests/test_client.py +++ b/src/eduid/vccs/client/tests/test_client.py @@ -1,16 +1,15 @@ -#!/usr/bin/python - -""" -Test VCCS client. -""" - import os import unittest +from typing import Any import simplejson as json from eduid.vccs.client import VCCSClient, VCCSOathFactor, VCCSPasswordFactor, VCCSRevokeFactor +""" +Test VCCS client. +""" + class FakeVCCSClient(VCCSClient): """ @@ -18,11 +17,11 @@ class FakeVCCSClient(VCCSClient): in order to fake HTTP communication. """ - def __init__(self, fake_response): + def __init__(self, fake_response: str): self.fake_response = fake_response VCCSClient.__init__(self) - def _execute_request_response(self, service, values): + def _execute_request_response(self, service: str, values: dict[str, Any]): self.last_service = service self.last_values = values return self.fake_response @@ -33,7 +32,7 @@ class FakeVCCSPasswordFactor(VCCSPasswordFactor): Sub-class that overrides the get_random_bytes function to make certain things testable. """ - def _get_random_bytes(self, num_bytes): + def _get_random_bytes(self, num_bytes: int): b = os.urandom(1) if isinstance(b, str): # Python2 diff --git a/src/eduid/vccs/server/endpoints/add_creds.py b/src/eduid/vccs/server/endpoints/add_creds.py index 37a3c7dc0..b7af310c2 100644 --- a/src/eduid/vccs/server/endpoints/add_creds.py +++ b/src/eduid/vccs/server/endpoints/add_creds.py @@ -67,7 +67,9 @@ async def add_creds(req: Request, request: AddCredsRequestV1) -> AddCredsRespons return response -async def _add_password_credential(_config, factor, req, request): +async def _add_password_credential( + _config: VCCSConfig, factor: RequestFactor, req: Request, request: AddCredsRequestV1 +): _salt = (await req.app.state.hasher.safe_random(_config.add_creds_password_salt_bytes)).hex() cred = PasswordCredential( credential_id=factor.credential_id, diff --git a/src/eduid/vccs/server/hasher.py b/src/eduid/vccs/server/hasher.py index 68b460fe7..c87161c4d 100644 --- a/src/eduid/vccs/server/hasher.py +++ b/src/eduid/vccs/server/hasher.py @@ -2,6 +2,7 @@ import os import stat from abc import ABC +from asyncio.locks import Lock from binascii import unhexlify from collections.abc import Mapping from hashlib import sha1 @@ -11,8 +12,23 @@ import yaml +class NoOpLock: + """ + A No-op lock class, to avoid a lot of "if self.lock:" in code using locks. + """ + + def __init__(self): + pass + + async def acquire(self): + pass + + async def release(self): + pass + + class VCCSHasher(ABC): - def __init__(self, lock): + def __init__(self, lock: Lock | NoOpLock): self.lock = lock def unlock(self, password: str) -> None: @@ -21,16 +37,16 @@ def unlock(self, password: str) -> None: def info(self) -> Any: raise NotImplementedError("Subclass should implement info") - def hmac_sha1(self, _key_handle, _data): + async def hmac_sha1(self, key_handle: int | None, data: bytes): raise NotImplementedError("Subclass should implement safe_hmac_sha1") - def unsafe_hmac_sha1(self, _key_handle, _data): + def unsafe_hmac_sha1(self, key_handle: int | None, _data: bytes): raise NotImplementedError("Subclass should implement hmac_sha1") - def load_temp_key(self, _nonce, _key_handle, _aead): + def load_temp_key(self, nonce: str, _key_handle: int, _aead: bytes): raise NotImplementedError("Subclass should implement load_temp_key") - def safe_random(self, _byte_count): + async def safe_random(self, byte_count: int): raise NotImplementedError("Subclass should implement safe_random") async def lock_acquire(self): @@ -41,7 +57,7 @@ async def lock_release(self): class VCCSYHSMHasher(VCCSHasher): - def __init__(self, device, lock, debug=False): + def __init__(self, device: str, lock: Lock | NoOpLock, debug: bool = False): VCCSHasher.__init__(self, lock) self._yhsm = pyhsm.base.YHSM(device, debug) @@ -52,7 +68,7 @@ def unlock(self, password: str) -> None: def info(self) -> Any: return self._yhsm.info() - async def hmac_sha1(self, key_handle: int, data: bytes) -> bytes: + async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: """ Perform HMAC-SHA-1 operation using YubiHSM. @@ -64,12 +80,12 @@ async def hmac_sha1(self, key_handle: int, data: bytes) -> bytes: finally: await self.lock_release() - def unsafe_hmac_sha1(self, key_handle: int, data: bytes) -> bytes: + def unsafe_hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: if key_handle is None: key_handle = pyhsm.defines.YSM_TEMP_KEY_HANDLE return self._yhsm.hmac_sha1(key_handle, data).get_hash() - def load_temp_key(self, nonce, key_handle, aead): + def load_temp_key(self, nonce: str, key_handle: int, aead: bytes): return self._yhsm.load_temp_key(nonce, key_handle, aead) async def safe_random(self, byte_count: int) -> bytes: @@ -94,7 +110,7 @@ class VCCSSoftHasher(VCCSHasher): (except perhaps separating HMAC keys from credential store). """ - def __init__(self, keys: Mapping[int, str], lock, debug=False): + def __init__(self, keys: Mapping[int, str], lock: Lock | NoOpLock, debug: bool = False): super().__init__(lock) self.debug = debug # Covert keys from strings to bytes when loading @@ -142,22 +158,7 @@ async def safe_random(self, byte_count: int) -> bytes: return os.urandom(byte_count) -class NoOpLock: - """ - A No-op lock class, to avoid a lot of "if self.lock:" in code using locks. - """ - - def __init__(self): - pass - - async def acquire(self): - pass - - async def release(self): - pass - - -def hasher_from_string(name: str, lock=None, debug=False): +def hasher_from_string(name: str, lock: Lock | NoOpLock | None = None, debug: bool = False): """ Create a hasher instance from a name. Name can currently only be a name of a YubiHSM device, such as '/dev/ttyACM0'. diff --git a/src/eduid/vccs/server/log.py b/src/eduid/vccs/server/log.py index 6babb08df..dce4acea2 100644 --- a/src/eduid/vccs/server/log.py +++ b/src/eduid/vccs/server/log.py @@ -5,8 +5,9 @@ class InterceptHandler(logging.Handler): - def emit(self, record): + def emit(self, record: logging.LogRecord): # Get corresponding Loguru level if it exists + level: str | int try: level = loguru_logger.level(record.levelname).name except ValueError: diff --git a/src/eduid/vccs/server/run.py b/src/eduid/vccs/server/run.py index 64c3bb935..c1bc04273 100644 --- a/src/eduid/vccs/server/run.py +++ b/src/eduid/vccs/server/run.py @@ -3,7 +3,7 @@ from collections.abc import Mapping from typing import Any -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from ndnkdf import ndnkdf from starlette.responses import JSONResponse @@ -75,7 +75,7 @@ async def startup_event(): @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request, exc): +async def validation_exception_handler(request: Request, exc: RequestValidationError): request.app.logger.warning(f"Failed parsing request: {exc}") return JSONResponse({"errors": exc.errors()}, status_code=HTTP_422_UNPROCESSABLE_ENTITY) diff --git a/src/eduid/webapp/authn/tests/test_authn.py b/src/eduid/webapp/authn/tests/test_authn.py index 3074a4ee1..11b7481f2 100644 --- a/src/eduid/webapp/authn/tests/test_authn.py +++ b/src/eduid/webapp/authn/tests/test_authn.py @@ -291,7 +291,8 @@ def test_frontend_state(self): eppn = "hubba-bubba" self.acs("/authenticate", eppn, FrontendAction.REMOVE_SECURITY_KEY_AUTHN, frontend_state="key_id_to_remove") - def _signup_authn_user(self, eppn): + # TODO: up for removal since it seems unused + def _signup_authn_user(self, eppn: str): timestamp = utc_now() with self.app.test_client() as c: diff --git a/src/eduid/webapp/common/api/debug.py b/src/eduid/webapp/common/api/debug.py index a881e7d88..005cec47b 100644 --- a/src/eduid/webapp/common/api/debug.py +++ b/src/eduid/webapp/common/api/debug.py @@ -1,13 +1,16 @@ import pprint import sys import warnings -from collections.abc import Callable +from collections.abc import Callable, Iterable from dataclasses import asdict from typing import Any from urllib import parse from flask import Flask, url_for +# TODO: in python >= 3.11 import from wsgiref.types +from eduid.webapp.common.wsgi import StartResponse, WSGIEnvironment + __author__ = "lundberg" @@ -15,13 +18,13 @@ class LoggingMiddleware: def __init__(self, app: Callable[..., Any]): self._app = app - def __call__(self, environ: dict[Any, Any], resp: Callable[..., Any]): + def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> Iterable[bytes]: errorlog = environ["wsgi.errors"] pprint.pprint(("REQUEST", environ), stream=errorlog) - def log_response(status, headers, *args): + def log_response(status: str, headers: list[tuple[str, str]], *args): pprint.pprint(("RESPONSE", status, headers), stream=errorlog) - return resp(status, headers, *args) + return start_response(status, headers, *args) return self._app(environ, log_response) @@ -55,7 +58,7 @@ def dump_config(app: Flask): def init_app_debug(app: Flask): - app.wsgi_app = LoggingMiddleware(app.wsgi_app) # type: ignore[assignment] + app.wsgi_app = LoggingMiddleware(app.wsgi_app) # type: ignore[method-assign] dump_config(app) log_endpoints(app) pprint.pprint(("view_functions", app.view_functions), stream=sys.stderr) diff --git a/src/eduid/webapp/common/api/exceptions.py b/src/eduid/webapp/common/api/exceptions.py index 5e0d11b2a..60385395f 100644 --- a/src/eduid/webapp/common/api/exceptions.py +++ b/src/eduid/webapp/common/api/exceptions.py @@ -1,7 +1,8 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from flask import jsonify +from flask import Flask, jsonify +from werkzeug.exceptions import HTTPException __author__ = "lundberg" @@ -72,10 +73,10 @@ def __init__(self, state: "ResetPasswordEmailState"): self.state = state -def init_exception_handlers(app): +def init_exception_handlers(app: Flask): # Init error handler for raised exceptions @app.errorhandler(400) - def _handle_flask_http_exception(error): + def _handle_flask_http_exception(error: HTTPException): app.logger.error(f"HttpException {error!s}") e = ApiException(error.name, error.code) if app.config.get("DEBUG"): @@ -87,7 +88,7 @@ def _handle_flask_http_exception(error): return app -def init_sentry(app): +def init_sentry(app: Flask): if app.config.get("SENTRY_DSN"): try: from raven.contrib.flask import Sentry diff --git a/src/eduid/webapp/common/api/middleware.py b/src/eduid/webapp/common/api/middleware.py index 9929ae01c..98622bb59 100644 --- a/src/eduid/webapp/common/api/middleware.py +++ b/src/eduid/webapp/common/api/middleware.py @@ -1,9 +1,15 @@ __author__ = "lundberg" +from collections.abc import Callable +from typing import Any + +# TODO: in python >= 3.11 import from wsgiref.types +from eduid.webapp.common.wsgi import StartResponse, WSGIEnvironment + # Copied from https://stackoverflow.com/questions/18967441/add-a-prefix-to-all-flask-routes/36033627#36033627 class PrefixMiddleware: - def __init__(self, app, prefix="", server_name=""): + def __init__(self, app: Callable[..., Any], prefix: str = "", server_name: str = ""): self.app = app if prefix is None: prefix = "" @@ -12,7 +18,7 @@ def __init__(self, app, prefix="", server_name=""): self.prefix = prefix self.server_name = server_name - def __call__(self, environ, start_response): + def __call__(self, environ: WSGIEnvironment, start_response: StartResponse): # Handle localhost requests for health checks if environ.get("REMOTE_ADDR") == "127.0.0.1": environ["HTTP_HOST"] = self.server_name diff --git a/src/eduid/webapp/common/api/request.py b/src/eduid/webapp/common/api/request.py index 719e4c55e..91b8f9e16 100644 --- a/src/eduid/webapp/common/api/request.py +++ b/src/eduid/webapp/common/api/request.py @@ -15,6 +15,7 @@ """ import logging +from collections.abc import Callable from typing import Any, AnyStr from flask import abort @@ -53,7 +54,7 @@ class SanitizedImmutableMultiDict(ImmutableMultiDict, SanitationMixin): sanitize the extracted data. """ - def __getitem__(self, key): + def __getitem__(self, key: Any): """ Return the first data value for this key; raises KeyError if not found. @@ -64,7 +65,7 @@ def __getitem__(self, key): value = super().__getitem__(key) return self.sanitize_input(value) - def getlist(self, key, type=None): + def getlist(self, key: Any, type: Callable[[Any], Any] | None = None): """ Return the list of items for a given key. If that key is not in the `MultiDict`, the return value will be an empty list. Just as `get` @@ -77,10 +78,11 @@ def getlist(self, key, type=None): by this callable the value will be removed from the list. :return: a :class:`list` of all the values for the key. """ + assert type is not None value_list = super().getlist(key, type=type) return [self.sanitize_input(v) for v in value_list] - def items(self, multi=False): + def items(self, multi: bool = False): """ Return an iterator of ``(key, value)`` pairs. @@ -124,7 +126,7 @@ def listvalues(self): for values in dict.values(self): yield (self.sanitize_input(v) for v in values) - def to_dict(self, flat=True): + def to_dict(self, flat: bool = True): """Return the contents as regular dict. If `flat` is `True` the returned dict will only have the first item present, if `flat` is `False` all values will be returned as lists. @@ -150,14 +152,14 @@ class SanitizedTypeConversionDict(ImmutableTypeConversionDict, SanitationMixin): sanitize the extracted data. """ - def __getitem__(self, key): + def __getitem__(self, key: Any): """ Sanitized __getitem__ """ val = super(ImmutableTypeConversionDict, self).__getitem__(key) return self.sanitize_input(str(val)) - def get(self, key, default=None, type=None) -> Any | None: # type: ignore[override] + def get(self, key: str, default: str | None = None, type: type | None = None) -> Any | None: # type: ignore[override] """ Sanitized, type conversion get. The value identified by `key` is sanitized, and if `type` diff --git a/src/eduid/webapp/common/api/schemas/csrf.py b/src/eduid/webapp/common/api/schemas/csrf.py index 59452a5e0..d8c1c3436 100644 --- a/src/eduid/webapp/common/api/schemas/csrf.py +++ b/src/eduid/webapp/common/api/schemas/csrf.py @@ -1,4 +1,5 @@ import logging +from typing import Any from flask import current_app, request from marshmallow import Schema, ValidationError, fields, post_load, pre_dump, validates @@ -15,7 +16,7 @@ class CSRFRequestMixin(Schema): csrf_token = fields.String(required=True) @validates("csrf_token") - def validate_csrf_token(self, value, **kwargs): + def validate_csrf_token(self, value: str, **kwargs): custom_header = request.headers.get("X-Requested-With") if custom_header != "XMLHttpRequest": # TODO: move value to config current_app.logger.error("CSRF check: missing custom X-Requested-With header") @@ -25,13 +26,13 @@ def validate_csrf_token(self, value, **kwargs): logger.debug(f"Validated CSRF token in session: {session.get_csrf_token()}") @post_load - def post_processing(self, in_data, **kwargs): + def post_processing(self, in_data: Any, **kwargs): # Remove token from data forwarded to views in_data = self.remove_csrf_token(in_data) return in_data @staticmethod - def remove_csrf_token(in_data, **kwargs): + def remove_csrf_token(in_data: Any, **kwargs): del in_data["csrf_token"] return in_data @@ -40,7 +41,7 @@ class CSRFResponseMixin(Schema): csrf_token = fields.String(required=True) @pre_dump - def get_csrf_token(self, out_data, **kwargs): + def get_csrf_token(self, out_data: Any, **kwargs): # Generate a new csrf token for every response out_data["csrf_token"] = session.new_csrf_token() logger.debug(f'Generated new CSRF token in CSRFResponseMixin: {out_data["csrf_token"]}') diff --git a/src/eduid/webapp/common/api/schemas/email.py b/src/eduid/webapp/common/api/schemas/email.py index 5ef0ba85c..7dcf341ce 100644 --- a/src/eduid/webapp/common/api/schemas/email.py +++ b/src/eduid/webapp/common/api/schemas/email.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow.fields import Email __author__ = "lundberg" @@ -8,10 +10,12 @@ class LowercaseEmail(Email): Email field that serializes and deserializes to a lower case string. """ - def _serialize(self, value, attr, obj, **kwargs): - value = super()._serialize(value, attr, obj, **kwargs) - return value.lower() + def _serialize(self, value: str | bytes, attr: Any, obj: Any, **kwargs): + _value = super()._serialize(value, attr, obj, **kwargs) + if _value is None: + return None + return _value.lower() - def _deserialize(self, value, attr, data, **kwargs): - value = super()._deserialize(value, attr, data, **kwargs) - return value.lower() + def _deserialize(self, value: str | bytes, attr: Any, data: Any, **kwargs): + _value = super()._deserialize(value, attr, data, **kwargs) + return _value.lower() diff --git a/src/eduid/webapp/common/api/schemas/models.py b/src/eduid/webapp/common/api/schemas/models.py index ca083eea8..95f543432 100644 --- a/src/eduid/webapp/common/api/schemas/models.py +++ b/src/eduid/webapp/common/api/schemas/models.py @@ -1,6 +1,9 @@ +from collections.abc import Mapping from enum import Enum, unique from typing import Any +from flask import Request + from eduid.webapp.common.api.utils import get_flux_type __author__ = "lundberg" @@ -38,7 +41,13 @@ class FluxResponse: An action MUST NOT include properties other than type, payload, error, and meta. """ - def __init__(self, req, payload=None, error=None, meta=None): + def __init__( + self, + req: Request, + payload: Mapping[str, Any] | None = None, + error: bool | None = None, + meta: Mapping[str, Any] | None = None, + ): _suffix = "success" if error: _suffix = "fail" @@ -57,7 +66,7 @@ def __str__(self): return f"{self.__class__.__name__!s} ({self.to_dict()!r})" def to_dict(self) -> dict[str, Any]: - rv = dict() + rv = dict[str, Any]() # A Flux Standard Action MUST have a type rv["type"] = self.flux_type # ... and MAY have payload, error, meta (and MUST NOT have anything else) @@ -75,10 +84,10 @@ def to_dict(self) -> dict[str, Any]: class FluxSuccessResponse(FluxResponse): - def __init__(self, req, payload, meta=None): + def __init__(self, req: Request, payload: Mapping[str, Any] | None, meta: Mapping[str, Any] | None = None): super().__init__(req, payload, meta=meta) class FluxFailResponse(FluxResponse): - def __init__(self, req, payload, meta=None): + def __init__(self, req: Request, payload: Mapping[str, Any] | None, meta: Mapping[str, Any] | None = None): super().__init__(req, payload, error=True, meta=meta) diff --git a/src/eduid/webapp/common/api/schemas/validators.py b/src/eduid/webapp/common/api/schemas/validators.py index ec33f5fef..79a124340 100644 --- a/src/eduid/webapp/common/api/schemas/validators.py +++ b/src/eduid/webapp/common/api/schemas/validators.py @@ -5,7 +5,7 @@ __author__ = "lundberg" -def validate_nin(nin, **kwargs): +def validate_nin(nin: str, **kwargs): """ :param nin: National Identity Number :type nin: string_types @@ -18,7 +18,7 @@ def validate_nin(nin, **kwargs): raise ValidationError("nin needs to be formatted as 18|19|20yymmddxxxx") -def validate_email(email, **kwargs): +def validate_email(email: str, **kwargs): """ :param email: E-mail address :type email: string_types diff --git a/src/eduid/webapp/common/api/tests/test_inputs.py b/src/eduid/webapp/common/api/tests/test_inputs.py index 1185d6860..f86feb046 100644 --- a/src/eduid/webapp/common/api/tests/test_inputs.py +++ b/src/eduid/webapp/common/api/tests/test_inputs.py @@ -21,7 +21,7 @@ __author__ = "lundberg" -def dont_validate(value): +def dont_validate(value: Any): raise ValidationError(f"Problem with {value!r}") @@ -35,7 +35,7 @@ class Meta: test_views = Blueprint("test", __name__) -def _make_response(data): +def _make_response(data: str): html = f"{data}" response = make_response(html, 200) response.headers["Content-Type"] = "text/html; charset=utf8" @@ -56,7 +56,7 @@ def post_param_view(): @test_views.route("/test-post-json", methods=["POST"]) @UnmarshalWith(NonValidatingSchema) -def post_json_view(test_data): +def post_json_view(test_data: str): """never validates""" pass diff --git a/src/eduid/webapp/common/api/validation.py b/src/eduid/webapp/common/api/validation.py index ca5a7816f..3be1f7928 100644 --- a/src/eduid/webapp/common/api/validation.py +++ b/src/eduid/webapp/common/api/validation.py @@ -24,7 +24,7 @@ def is_valid_nin(nin: str) -> bool: raise ValueError("nin needs to be formatted as 18|19|20yymmddxxxx") -def is_valid_email(email, **kwargs): +def is_valid_email(email: str, **kwargs): """ :param email: E-mail address :return: True or raises ValueError diff --git a/src/eduid/webapp/common/authn/middleware.py b/src/eduid/webapp/common/authn/middleware.py index 27ce11388..0788e4b0c 100644 --- a/src/eduid/webapp/common/authn/middleware.py +++ b/src/eduid/webapp/common/authn/middleware.py @@ -3,7 +3,7 @@ import re from abc import ABCMeta from collections.abc import Iterable, Mapping -from typing import TYPE_CHECKING, Any, cast +from typing import Any, cast from urllib.parse import urlparse from flask import Request, current_app @@ -17,8 +17,8 @@ from eduid.webapp.common.session import session from eduid.webapp.common.session.redis_session import NoSessionDataFoundException -if TYPE_CHECKING: - from _typeshed.wsgi import StartResponse, WSGIEnvironment +# TODO: in python >= 3.11 import from wsgiref.types +from eduid.webapp.common.wsgi import StartResponse, WSGIEnvironment no_context_logger = logging.getLogger(__name__) diff --git a/src/eduid/webapp/common/authn/testing.py b/src/eduid/webapp/common/authn/testing.py index 10eca49cb..c30a3eaf4 100644 --- a/src/eduid/webapp/common/authn/testing.py +++ b/src/eduid/webapp/common/authn/testing.py @@ -1,38 +1,13 @@ -import json import logging +from collections.abc import Sequence from eduid.common.decorators import deprecated -from eduid.vccs.client import VCCSClient +from eduid.vccs.client import VCCSFactor, VCCSPasswordFactor, VCCSRevokeFactor logger = logging.getLogger() -class FakeVCCSClient(VCCSClient): - def __init__(self, fake_response=None): - super().__init__() - self.fake_response = fake_response - - def _execute_request_response(self, _service, _values): - if self.fake_response is not None: - return json.dumps(self.fake_response) - - fake_response = {} - if _service == "add_creds": - fake_response = { - "add_creds_response": {"version": 1, "success": True}, - } - elif _service == "authenticate": - fake_response = { - "auth_response": {"version": 1, "authenticated": True}, - } - elif _service == "revoke_creds": - fake_response = { - "revoke_creds_response": {"version": 1, "success": True}, - } - return json.dumps(fake_response) - - -class TestVCCSClient: +class MockVCCSClient: """ Mock VCCS client for testing. It stores factors locally, and it only checks for the credential_id to authenticate/revoke. @@ -45,10 +20,11 @@ class TestVCCSClient: def __init__(self): self.factors = {} - def authenticate(self, user_id, factors): + # TODO: check for removal, seems to be unused + def authenticate(self, user_id: str, factors: Sequence[VCCSPasswordFactor]) -> bool: found = False if user_id not in self.factors: - logger.debug(f"User {user_id!r} not found in TestVCCSClient credential store:\n{self.factors}") + logger.debug(f"User {user_id!r} not found in MockVCCSClient credential store:\n{self.factors}") return False for factor in factors: logger.debug(f"Trying to authenticate user {user_id} with factor {factor} (id {factor.credential_id})") @@ -70,17 +46,18 @@ def authenticate(self, user_id, factors): found = True break logger.debug("Hash {} did not match the expected hash {}".format(fdict["H1"], sdict["H1"])) - logger.debug(f"TestVCCSClient authenticate result for user_id {user_id}: {found}") + logger.debug(f"MockVCCSClient authenticate result for user_id {user_id}: {found}") return found - def add_credentials(self, user_id, factors): - user_factors = self.factors.get(str(user_id), []) + def add_credentials(self, user_id: str, factors: Sequence[VCCSFactor]) -> bool: + user_factors: list = self.factors.get(str(user_id), []) user_factors.extend(factors) self.factors[str(user_id)] = user_factors return True - def revoke_credentials(self, user_id, revoked): - stored = self.factors.get(user_id, None) + def revoke_credentials(self, user_id: str, revoked: Sequence[VCCSRevokeFactor]) -> bool: + stored: list = self.factors.get(user_id, None) + removed: bool = False if stored: # Nothing stored in test client yet for rfactor in revoked: rdict = rfactor.to_dict("revoke_creds") @@ -88,8 +65,6 @@ def revoke_credentials(self, user_id, revoked): fdict = factor.to_dict("revoke_creds") if rdict["credential_id"] == fdict["credential_id"]: stored.remove(factor) + removed = True break - - -# new name to import from dependent packages, so we can remove the deprecated TestVCCSClient -MockVCCSClient = TestVCCSClient + return removed diff --git a/src/eduid/webapp/common/authn/tests/test_fido_tokens.py b/src/eduid/webapp/common/authn/tests/test_fido_tokens.py index f002adfd3..230536b74 100644 --- a/src/eduid/webapp/common/authn/tests/test_fido_tokens.py +++ b/src/eduid/webapp/common/authn/tests/test_fido_tokens.py @@ -179,7 +179,7 @@ def test_webauthn_verify(self, mock_verify: MagicMock): self.assertEqual(resp_data["success"], True) @patch("fido2.cose.ES256.verify") - def test_webauthn_verify_wrong_origin(self, mock_verify): + def test_webauthn_verify_wrong_origin(self, mock_verify: MagicMock): self.app.conf.fido2_rp_id = "wrong.rp.id" mock_verify.return_value = True # Add a working U2F credential for this test @@ -204,7 +204,7 @@ def test_webauthn_verify_wrong_origin(self, mock_verify): self.assertEqual(resp_data["success"], False) @patch("fido2.cose.ES256.verify") - def test_webauthn_verify_wrong_challenge(self, mock_verify): + def test_webauthn_verify_wrong_challenge(self, mock_verify: MagicMock): mock_verify.return_value = True # Add a working U2F credential for this test self.test_user.credentials.add(self.webauthn_credential) @@ -226,7 +226,7 @@ def test_webauthn_verify_wrong_challenge(self, mock_verify): self.assertEqual(resp_data["success"], False) @patch("fido2.cose.ES256.verify") - def test_webauthn_verify_wrong_credential(self, mock_verify): + def test_webauthn_verify_wrong_credential(self, mock_verify: MagicMock): req = deepcopy(SAMPLE_WEBAUTHN_REQUEST) req["credentialId"] = req["credentialId"].replace("0", "9") mock_verify.return_value = True diff --git a/src/eduid/webapp/common/authn/tests/test_middleware.py b/src/eduid/webapp/common/authn/tests/test_middleware.py index 99c0266b0..4ee84080d 100644 --- a/src/eduid/webapp/common/authn/tests/test_middleware.py +++ b/src/eduid/webapp/common/authn/tests/test_middleware.py @@ -18,14 +18,14 @@ def __init__(self, name: str, test_config: Mapping[str, Any], **kwargs): class AuthnTests(EduidAPITestCase): - def load_app(self, config): + def load_app(self, config: dict[str, Any]): """ Called from the parent class, so we can provide the appropriate flask app for this test case. """ return AuthnTestApp("testing", config) - def update_config(self, config): + def update_config(self, config: dict[str, Any]): config.update( { "available_languages": {"en": "English", "sv": "Svenska"}, @@ -47,14 +47,14 @@ def test_get_view(self): class UnAuthnTests(EduidAPITestCase): - def load_app(self, config): + def load_app(self, config: dict[str, Any]): """ Called from the parent class, so we can provide the appropriate flask app for this test case. """ return AuthnTestApp("testing", config) - def update_config(self, config): + def update_config(self, config: dict[str, Any]): config.update( { "available_languages": {"en": "English", "sv": "Svenska"}, diff --git a/src/eduid/webapp/common/authn/tests/test_vccs.py b/src/eduid/webapp/common/authn/tests/test_vccs.py index 2b70aaa65..a6807d3a1 100644 --- a/src/eduid/webapp/common/authn/tests/test_vccs.py +++ b/src/eduid/webapp/common/authn/tests/test_vccs.py @@ -28,7 +28,7 @@ def tearDown(self): vccs_module.revoke_passwords(self.user, reason="testing", application="test", vccs=self.vccs_client) super().tearDown() - def _check_credentials(self, creds): + def _check_credentials(self, creds: str): return vccs_module.check_password(creds, self.user, vccs=self.vccs_client) def test_check_good_credentials(self): diff --git a/src/eduid/webapp/common/authn/utils.py b/src/eduid/webapp/common/authn/utils.py index 43c8b887f..968425e2a 100644 --- a/src/eduid/webapp/common/authn/utils.py +++ b/src/eduid/webapp/common/authn/utils.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -def get_saml2_config(module_path: str, name="SAML_CONFIG") -> SPConfig: +def get_saml2_config(module_path: str, name: str = "SAML_CONFIG") -> SPConfig: """Load SAML2 config file, in the form of a Python module.""" spec = importlib.util.spec_from_file_location("saml2_settings", module_path) if spec is None: diff --git a/src/eduid/webapp/common/authn/vccs.py b/src/eduid/webapp/common/authn/vccs.py index 9afd89446..dff1d9b95 100644 --- a/src/eduid/webapp/common/authn/vccs.py +++ b/src/eduid/webapp/common/authn/vccs.py @@ -340,7 +340,7 @@ def revoke_passwords( @deprecated def revoke_all_credentials( - user, source="dashboard", vccs_url: str | None = None, vccs: VCCSClient | None = None + user: User, source: str = "dashboard", vccs_url: str | None = None, vccs: VCCSClient | None = None ) -> None: if vccs is None: vccs = get_vccs_client(vccs_url) diff --git a/src/eduid/webapp/common/proofing/saml_helpers.py b/src/eduid/webapp/common/proofing/saml_helpers.py index a62067102..131d5d4bc 100644 --- a/src/eduid/webapp/common/proofing/saml_helpers.py +++ b/src/eduid/webapp/common/proofing/saml_helpers.py @@ -1,5 +1,6 @@ import logging +from saml2.config import SPConfig from saml2.metadata import entity_descriptor from eduid.common.misc.timeutil import utc_now @@ -65,5 +66,5 @@ def is_valid_authn_instant(session_info: SessionInfo, max_age: int = 60) -> bool return False -def create_metadata(config): +def create_metadata(config: SPConfig): return entity_descriptor(config) diff --git a/src/eduid/webapp/common/session/eduid_session.py b/src/eduid/webapp/common/session/eduid_session.py index 6f635325f..4ea612476 100644 --- a/src/eduid/webapp/common/session/eduid_session.py +++ b/src/eduid/webapp/common/session/eduid_session.py @@ -127,7 +127,7 @@ def __str__(self): f"modified={self.modified}, cookie={self.short_id}>" ) - def __getitem__(self, key): + def __getitem__(self, key: str): return self._session.__getitem__(key) def __setitem__(self, key: str, value: Any): @@ -136,7 +136,7 @@ def __setitem__(self, key: str, value: Any): logger.debug(f"SET {self}[{key}] = {value}") self.modified = True - def __delitem__(self, key): + def __delitem__(self, key: str): if key in self._session: del self._session[key] logger.debug(f"DEL {self}[{key}]") @@ -148,7 +148,7 @@ def __iter__(self): def __len__(self): return len(self._session) - def __contains__(self, key): + def __contains__(self, key: object): return self._session.__contains__(key) @property @@ -161,7 +161,7 @@ def permanent(self): return True @permanent.setter - def permanent(self, value): + def permanent(self, value: bool): # EduidSessions are _always_ permanent pass diff --git a/src/eduid/webapp/common/session/redis_session.py b/src/eduid/webapp/common/session/redis_session.py index 5cf2ec2b5..8efedad9e 100644 --- a/src/eduid/webapp/common/session/redis_session.py +++ b/src/eduid/webapp/common/session/redis_session.py @@ -206,19 +206,19 @@ def __str__(self): # Include hex(id(self)) for now to troubleshoot clobbered sessions return f"<{self.__class__.__name__} at {hex(id(self))}: db_key={self.short_id}>" - def __getitem__(self, key): + def __getitem__(self, key: str): if key in self._data: return self._data[key] raise KeyError(f"Key {repr(key)} not present in session") - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any): if self.whitelist and key not in self.whitelist: if self.raise_on_unknown: raise ValueError(f"Key {repr(key)} not allowed in session") return self._data[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str): del self._data[key] def __iter__(self): @@ -227,7 +227,7 @@ def __iter__(self): def __len__(self): return len(self._data) - def __contains__(self, key): + def __contains__(self, key: object): return self._data.__contains__(key) @property diff --git a/src/eduid/webapp/common/session/tests/test_eduid_session.py b/src/eduid/webapp/common/session/tests/test_eduid_session.py index 59d81c72f..b6e49c76e 100644 --- a/src/eduid/webapp/common/session/tests/test_eduid_session.py +++ b/src/eduid/webapp/common/session/tests/test_eduid_session.py @@ -23,7 +23,7 @@ def __init__(self, config: SessionTestConfig, **kwargs): self.conf = config -def session_init_app(name, test_config: Mapping[str, Any]) -> SessionTestApp: +def session_init_app(name: str, test_config: Mapping[str, Any]) -> SessionTestApp: config = load_config(typ=SessionTestConfig, app_name=name, ns="webapp", test_config=test_config) app = SessionTestApp(config, init_central_userdb=False) no_authn_views(config, ["/unauthenticated"]) @@ -195,7 +195,7 @@ def test_remove_cookie_on_invalidated_session_save(self): elif value == "expires": self.assertEqual("Thu, 01-Jan-1970 00:00:00 GMT", value) - def _test_bad_session_cookie(self, bad_cookie_value): + def _test_bad_session_cookie(self, bad_cookie_value: str): with self.browser as browser: browser.set_cookie(domain=".test.localhost", key="sessid", value=bad_cookie_value) response = browser.get("/unauthenticated") diff --git a/src/eduid/webapp/common/wsgi.py b/src/eduid/webapp/common/wsgi.py new file mode 100644 index 000000000..91e1de5f7 --- /dev/null +++ b/src/eduid/webapp/common/wsgi.py @@ -0,0 +1,15 @@ +from collections.abc import Callable +from types import TracebackType +from typing import Any, Protocol, TypeAlias + +ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType] +OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None] + + +class StartResponse(Protocol): + def __call__( + self, status: str, headers: list[tuple[str, str]], exc_info: OptExcInfo | None = ..., / + ) -> Callable[[bytes], object]: ... + + +WSGIEnvironment: TypeAlias = dict[str, Any] # stable diff --git a/src/eduid/webapp/email/validators.py b/src/eduid/webapp/email/validators.py index 6d3c201ff..2e8eee839 100644 --- a/src/eduid/webapp/email/validators.py +++ b/src/eduid/webapp/email/validators.py @@ -4,14 +4,14 @@ from eduid.webapp.email.helpers import EmailMsg -def email_exists(email): +def email_exists(email: str): user = get_user() user_emails = [e.email for e in user.mail_addresses.to_list()] if email not in user_emails: raise ValidationError(EmailMsg.missing.value) -def email_does_not_exist(email): +def email_does_not_exist(email: str): user = get_user() user_emails = [e.email for e in user.mail_addresses.to_list()] if email in user_emails: diff --git a/src/eduid/webapp/email/verifications.py b/src/eduid/webapp/email/verifications.py index 08d6ff95d..035959c4c 100644 --- a/src/eduid/webapp/email/verifications.py +++ b/src/eduid/webapp/email/verifications.py @@ -6,6 +6,7 @@ from eduid.userdb.logs import MailAddressProofing from eduid.userdb.mail import MailAddress from eduid.userdb.proofing import EmailProofingElement, EmailProofingState +from eduid.userdb.proofing.user import ProofingUser from eduid.webapp.common.api.translation import get_user_locale from eduid.webapp.common.api.utils import save_and_sync_user from eduid.webapp.email.app import current_email_app as current_app @@ -63,7 +64,7 @@ def send_verification_code(email: str, user: User) -> bool: return True -def verify_mail_address(state, proofing_user): +def verify_mail_address(state: EmailProofingState, proofing_user: ProofingUser): """ :param proofing_user: ProofingUser :param state: E-mail proofing state @@ -81,6 +82,9 @@ def verify_mail_address(state, proofing_user): # Adding the phone to the list creates a copy of the element, so we have to 'find' it again email = proofing_user.mail_addresses.find(state.verification.email) + # please mypy, email should be set now + assert email + email.is_verified = True if not proofing_user.mail_addresses.primary: email.is_primary = True diff --git a/src/eduid/webapp/email/views.py b/src/eduid/webapp/email/views.py index acb77fa43..3bdcf2899 100644 --- a/src/eduid/webapp/email/views.py +++ b/src/eduid/webapp/email/views.py @@ -1,6 +1,7 @@ from flask import Blueprint, abort, request from marshmallow import ValidationError +from eduid.userdb.element import ElementKey from eduid.userdb.exceptions import UserOutOfSync from eduid.userdb.mail import MailAddress from eduid.userdb.proofing import ProofingUser @@ -38,7 +39,7 @@ def get_all_emails(user: User) -> FluxData: @UnmarshalWith(AddEmailSchema) @MarshalWith(EmailResponseSchema) @require_user -def post_email(user: User, email: str, verified, primary) -> FluxData: +def post_email(user: User, email: str, verified: bool, primary: bool) -> FluxData: proofing_user = ProofingUser.from_user(user, current_app.private_userdb) current_app.logger.debug(f"Trying to save unconfirmed email {repr(email)} for user {proofing_user}") @@ -148,7 +149,7 @@ def verify(user: User, code: str, email: str) -> FluxData: @UnmarshalWith(ChangeEmailSchema) @MarshalWith(EmailResponseSchema) @require_user -def post_remove(user, email): +def post_remove(user: User, email: ElementKey): proofing_user = ProofingUser.from_user(user, current_app.private_userdb) current_app.logger.debug(f"Trying to remove email address {email!r} from user {proofing_user}") diff --git a/src/eduid/webapp/freja_eid/views.py b/src/eduid/webapp/freja_eid/views.py index de9f23122..f39f09ba5 100644 --- a/src/eduid/webapp/freja_eid/views.py +++ b/src/eduid/webapp/freja_eid/views.py @@ -129,7 +129,7 @@ def _authn( @freja_eid_views.route("/authn-callback", methods=["GET"]) @require_user -def authn_callback(user) -> WerkzeugResponse: +def authn_callback(user: User) -> WerkzeugResponse: """ This is the callback endpoint for the Svipe ID OIDC flow. """ diff --git a/src/eduid/webapp/idp/decorators.py b/src/eduid/webapp/idp/decorators.py index b07e82bbe..8c6736ba9 100644 --- a/src/eduid/webapp/idp/decorators.py +++ b/src/eduid/webapp/idp/decorators.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Callable from functools import wraps from flask import jsonify, request @@ -15,7 +16,7 @@ logger = logging.getLogger(__name__) -def require_ticket(f): +def require_ticket(f: Callable): @wraps(f) def require_ticket_decorator(*args, **kwargs): """Decorator to turn the 'ref' parameter sent by the frontend into a ticket (LoginContext)""" @@ -54,7 +55,7 @@ def require_ticket_decorator(*args, **kwargs): return require_ticket_decorator -def uses_sso_session(f): +def uses_sso_session(f: Callable): @wraps(f) def uses_sso_session_decorator(*args, **kwargs): """Decorator to supply the current SSO session, if one is found and still valid""" diff --git a/src/eduid/webapp/idp/settings/common.py b/src/eduid/webapp/idp/settings/common.py index 00cf60c9d..24550b2f9 100644 --- a/src/eduid/webapp/idp/settings/common.py +++ b/src/eduid/webapp/idp/settings/common.py @@ -3,6 +3,7 @@ """ from datetime import timedelta +from typing import Any from pydantic import Field, HttpUrl, field_validator from pydantic_core.core_schema import ValidationInfo @@ -157,7 +158,7 @@ class IdPConfig(EduIDBaseAppConfig, TouConfigMixin, WebauthnConfigMixin2, AmConf @field_validator("sso_cookie") @classmethod - def make_sso_cookie(cls, v, info: ValidationInfo) -> CookieConfig: + def make_sso_cookie(cls, v: Any, info: ValidationInfo) -> CookieConfig: # Convert sso_cookie from dict to the proper dataclass if isinstance(v, dict): return CookieConfig(**v) @@ -170,7 +171,7 @@ def make_sso_cookie(cls, v, info: ValidationInfo) -> CookieConfig: @field_validator("sso_session_lifetime", mode="before") @classmethod - def validate_sso_session_lifetime(cls, v): + def validate_sso_session_lifetime(cls, v: Any): if isinstance(v, int): # legacy format for this was number of minutes v = v * 60 diff --git a/src/eduid/webapp/idp/tests/test_SSO.py b/src/eduid/webapp/idp/tests/test_SSO.py index d09f66b1b..6b0d3879f 100644 --- a/src/eduid/webapp/idp/tests/test_SSO.py +++ b/src/eduid/webapp/idp/tests/test_SSO.py @@ -7,8 +7,6 @@ import saml2.server import saml2.time_util from saml2 import BINDING_HTTP_POST -from saml2.s_utils import UnravelError -from werkzeug.exceptions import BadRequest from eduid.common.misc.timeutil import utc_now from eduid.common.models.saml2 import EduidAuthnContextClass @@ -72,11 +70,15 @@ def make_SAML_request(class_ref: EduidAuthnContextClass | str | None = None): return _transport_encode(xml) -def _transport_encode(data): +def _transport_encode(data: str): # encode('base64') only works for POST bindings, redirect uses zlib compression too. return b64encode("".join(data.split("\n"))) +class SAMLError(BaseException): + pass + + class SSOIdPTests(IdPAPITests): def _make_login_ticket( self, @@ -103,11 +105,9 @@ def _parse_SAMLRequest( self, info: Mapping, binding: str, - logger: logging.Logger, idp: saml2.server.Server, - bad_request, debug: bool = False, - verify_request_signatures=True, + verify_request_signatures: bool = True, ) -> IdP_SAMLRequest: """ Parse a SAMLRequest query parameter (base64 encoded) into an AuthnRequest @@ -123,15 +123,13 @@ def _parse_SAMLRequest( """ try: saml_req = IdP_SAMLRequest(info["SAMLRequest"], binding, idp, debug=debug) - except UnravelError: - raise bad_request("No valid SAMLRequest found", logger=logger) - except ValueError: - raise bad_request("No valid SAMLRequest found", logger=logger) + except Exception: + raise SAMLError("No valid SAMLRequest found") if "SigAlg" in info and "Signature" in info: # Signed request if verify_request_signatures: if not saml_req.verify_signature(info["SigAlg"], info["Signature"]): - raise bad_request("SAML request signature verification failure", logger=logger) + raise SAMLError("SAML request signature verification failure") else: logger.debug("Ignoring existing request signature, verify_request_signature is False") else: @@ -903,8 +901,6 @@ def test_forceauthn_request(self): x = self._parse_SAMLRequest( info, binding=BINDING_HTTP_POST, - bad_request=BadRequest, - logger=logger, idp=self.app.IDP, debug=True, verify_request_signatures=False, diff --git a/src/eduid/webapp/idp/tests/test_api.py b/src/eduid/webapp/idp/tests/test_api.py index 8630e48cc..d5c92cf9b 100644 --- a/src/eduid/webapp/idp/tests/test_api.py +++ b/src/eduid/webapp/idp/tests/test_api.py @@ -347,7 +347,7 @@ def _extract_path_from_info(self, info: Mapping[str, Any]) -> str: loc = _location_headers[0][1] return self._extract_path_from_url(loc) - def _extract_path_from_url(self, url): + def _extract_path_from_url(self, url: str): # It is a complete URL, extract the path from it (8 is to skip over slashes in https://) _idx = url[8:].index("/") path = url[8 + _idx :] @@ -449,7 +449,7 @@ def add_test_user_external_mfa_cred( user.credentials.add(cred) self.request_user_sync(user) - def get_attributes(self, result, saml2_client: Saml2Client | None = None): + def get_attributes(self, result: LoginResultAPI, saml2_client: Saml2Client | None = None): assert result.finished_result is not None authn_response = self.parse_saml_authn_response(result.finished_result, saml2_client=saml2_client) session_info = authn_response.session_info() diff --git a/src/eduid/webapp/idp/tests/test_idPUserDb.py b/src/eduid/webapp/idp/tests/test_idPUserDb.py index c458f7c3a..2f180be58 100644 --- a/src/eduid/webapp/idp/tests/test_idPUserDb.py +++ b/src/eduid/webapp/idp/tests/test_idPUserDb.py @@ -3,12 +3,14 @@ import datetime import logging -from unittest.mock import patch +from unittest.mock import MagicMock, patch from bson import ObjectId import eduid.userdb import eduid.webapp.common.authn +from eduid.userdb.credentials.password import Password +from eduid.userdb.mail import MailAddress from eduid.vccs.client import VCCSClient, VCCSPasswordFactor from eduid.webapp.common.api import exceptions from eduid.webapp.idp.idp_authn import IdPAuthn @@ -49,26 +51,29 @@ def test_authn_unknown_user(self): assert pwauth is None @patch("eduid.vccs.client.VCCSClient.add_credentials") - def test_authn_known_user_wrong_password(self, mock_add_credentials): + def test_authn_known_user_wrong_password(self, mock_add_credentials: MagicMock): mock_add_credentials.return_value = False assert isinstance(self.test_user, eduid.userdb.User) assert isinstance(self.app.authn, IdPAuthn) # help pycharm cred_id = ObjectId() factor = VCCSPasswordFactor("foo", str(cred_id), salt=None) self.app.authn.auth_client.add_credentials(str(self.test_user.user_id), [factor]) + assert isinstance(self.test_user.mail_addresses.primary, MailAddress) pwauth = self.app.authn.password_authn(self.test_user.mail_addresses.primary.email, "bar") assert pwauth is None @patch("eduid.vccs.client.VCCSClient.authenticate") @patch("eduid.vccs.client.VCCSClient.add_credentials") - def test_authn_known_user_right_password(self, mock_add_credentials, mock_authenticate): + def test_authn_known_user_right_password(self, mock_add_credentials: MagicMock, mock_authenticate: MagicMock): mock_add_credentials.return_value = True mock_authenticate.return_value = True assert isinstance(self.test_user, eduid.userdb.User) assert isinstance(self.app.authn, IdPAuthn) # help pycharm passwords = self.test_user.credentials.to_list() + assert isinstance(passwords[0], Password) factor = VCCSPasswordFactor("foo", str(passwords[0].key), salt=passwords[0].salt) self.app.authn.auth_client.add_credentials(str(self.test_user.user_id), [factor]) + assert isinstance(self.test_user.mail_addresses.primary, MailAddress) pwauth = self.app.authn.password_authn(self.test_user.mail_addresses.primary.email, "foo") assert pwauth is not None assert pwauth.user.eppn == self.test_user.eppn @@ -77,21 +82,24 @@ def test_authn_known_user_right_password(self, mock_add_credentials, mock_authen @patch("eduid.vccs.client.VCCSClient.authenticate") @patch("eduid.vccs.client.VCCSClient.add_credentials") - def test_authn_expired_credential(self, mock_add_credentials, mock_authenticate): + def test_authn_expired_credential(self, mock_add_credentials: MagicMock, mock_authenticate: MagicMock): mock_add_credentials.return_value = False mock_authenticate.return_value = True assert isinstance(self.test_user, eduid.userdb.User) assert isinstance(self.app.authn, IdPAuthn) # help pycharm passwords = self.test_user.credentials.to_list() + assert isinstance(passwords[0], Password) factor = VCCSPasswordFactor("foo", str(passwords[0].key), salt=passwords[0].salt) self.app.authn.auth_client.add_credentials(str(self.test_user.user_id), [factor]) # Store a successful authentication using this credential three year ago three_years_ago = datetime.datetime.now() - datetime.timedelta(days=3 * 365) self.app.authn.authn_store.credential_success([passwords[0].key], three_years_ago) with self.assertRaises(exceptions.EduidForbidden): + assert isinstance(self.test_user.mail_addresses.primary, MailAddress) self.app.authn.password_authn(self.test_user.mail_addresses.primary.email, "foo") # Do the same thing again to make sure we didn't accidentally update the # 'last successful login' timestamp when it was a successful login with an # expired credential. with self.assertRaises(exceptions.EduidForbidden): + assert isinstance(self.test_user.mail_addresses.primary, MailAddress) self.app.authn.password_authn(self.test_user.mail_addresses.primary.email, "foo") diff --git a/src/eduid/webapp/letter_proofing/ekopost.py b/src/eduid/webapp/letter_proofing/ekopost.py index be7d0c000..ee541bbf1 100644 --- a/src/eduid/webapp/letter_proofing/ekopost.py +++ b/src/eduid/webapp/letter_proofing/ekopost.py @@ -1,6 +1,7 @@ import base64 import json from datetime import datetime +from io import BytesIO from hammock import Hammock @@ -23,7 +24,7 @@ def __init__(self, config: LetterProofingConfig): self.ekopost_api = Hammock(config.ekopost_api_uri, auth=auth, verify=config.ekopost_api_verify_ssl) - def send(self, eppn, document): + def send(self, eppn: str, document: BytesIO): """ Send a letter containing a PDF-document to the recipient specified in the document. @@ -63,7 +64,7 @@ def send(self, eppn, document): return closed_campaign["id"] - def _create_campaign(self, name, output_date, cost_center): + def _create_campaign(self, name: str, output_date: str, cost_center: str): """ Create a new campaign @@ -81,7 +82,9 @@ def _create_campaign(self, name, output_date, cost_center): raise EkopostException(f"Ekopost exception: {response.status_code!s} {response.text!s}") - def _create_envelope(self, campaign_id, name, postage="priority", plex="simplex", color="false"): + def _create_envelope( + self, campaign_id: str, name: str, postage: str = "priority", plex: str = "simplex", color: str = "false" + ): """ Create an envelope for a specified campaign @@ -105,7 +108,14 @@ def _create_envelope(self, campaign_id, name, postage="priority", plex="simplex" raise EkopostException(f"Ekopost exception: {response.status_code!s} {response.text!s}") - def _create_content(self, campaign_id, envelope_id, data, mime="application/pdf", content_type="document"): + def _create_content( + self, + campaign_id: str, + envelope_id: str, + data: bytes, + mime: str = "application/pdf", + content_type: str = "document", + ): """ Create the content that should be linked to an envelope @@ -137,7 +147,7 @@ def _create_content(self, campaign_id, envelope_id, data, mime="application/pdf" raise EkopostException(f"Ekopost exception: {response.status_code!s} {response.text!s}") - def _close_envelope(self, campaign_id, envelope_id): + def _close_envelope(self, campaign_id: str, envelope_id: str): """ Change an envelope state to closed and mark it as ready for print & distribution. :param campaign_id: Unique id of a campaign within which the envelope exists @@ -154,7 +164,7 @@ def _close_envelope(self, campaign_id, envelope_id): raise EkopostException(f"Ekopost exception: {response.status_code!s} {response.text!s}") - def _close_campaign(self, campaign_id): + def _close_campaign(self, campaign_id: str): """ Change a campains state to closed and mark it and all its envelopes as ready for print & distribution. diff --git a/src/eduid/webapp/letter_proofing/tests/test_app.py b/src/eduid/webapp/letter_proofing/tests/test_app.py index 53e7bceaf..7bc73a217 100644 --- a/src/eduid/webapp/letter_proofing/tests/test_app.py +++ b/src/eduid/webapp/letter_proofing/tests/test_app.py @@ -135,7 +135,7 @@ def verify_code(self, code: str, csrf_token: str | None = None, validate_respons @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") @patch("eduid.common.rpc.msg_relay.MsgRelay.get_postal_address") def _verify_code2( - self, code: str, csrf_token: str | None, mock_get_postal_address, mock_request_user_sync: MagicMock + self, code: str, csrf_token: str | None, mock_get_postal_address: MagicMock, mock_request_user_sync: MagicMock ): if csrf_token is None: _state = self.get_state() diff --git a/src/eduid/webapp/letter_proofing/tests/test_pdf.py b/src/eduid/webapp/letter_proofing/tests/test_pdf.py index a6704aa29..2d76c2b6a 100644 --- a/src/eduid/webapp/letter_proofing/tests/test_pdf.py +++ b/src/eduid/webapp/letter_proofing/tests/test_pdf.py @@ -2,6 +2,7 @@ from collections import OrderedDict from datetime import datetime from io import BytesIO, StringIO +from typing import Any from pypdf import PdfReader @@ -131,14 +132,14 @@ def test_failing_format(self): class CreatePDFTest(EduidAPITestCase): - def load_app(self, config): + def load_app(self, config: dict[str, Any]): """ Called from the parent class, so we can provide the appropriate flask app for this test case. """ return init_letter_proofing_app("testing", config) - def update_config(self, app_config): + def update_config(self, app_config: dict[str, Any]): app_config.update( { "letter_wait_time_hours": 336, diff --git a/src/eduid/webapp/letter_proofing/views.py b/src/eduid/webapp/letter_proofing/views.py index 73b475e7b..788c74984 100644 --- a/src/eduid/webapp/letter_proofing/views.py +++ b/src/eduid/webapp/letter_proofing/views.py @@ -23,7 +23,7 @@ @letter_proofing_views.route("/proofing", methods=["GET"]) @MarshalWith(schemas.LetterProofingResponseSchema) @require_user -def get_state(user) -> FluxData: +def get_state(user: User) -> FluxData: current_app.logger.info(f"Getting proofing state for user {user}") proofing_state = current_app.proofing_statedb.get_state_by_eppn(user.eppn) diff --git a/src/eduid/webapp/oidc_proofing/tests/test_app.py b/src/eduid/webapp/oidc_proofing/tests/test_app.py index e8a2d97ca..02784f4f7 100644 --- a/src/eduid/webapp/oidc_proofing/tests/test_app.py +++ b/src/eduid/webapp/oidc_proofing/tests/test_app.py @@ -8,6 +8,7 @@ from jose import jws as jose from eduid.userdb import NinIdentity +from eduid.userdb.proofing.state import OidcProofingState from eduid.webapp.common.api.testing import EduidAPITestCase from eduid.webapp.oidc_proofing.app import OIDCProofingApp, init_oidc_proofing_app from eduid.webapp.oidc_proofing.helpers import create_proofing_state, handle_freja_eid_userinfo @@ -55,7 +56,7 @@ def setUp(self, *args, **kwargs): } class MockResponse: - def __init__(self, status_code, text): + def __init__(self, status_code: int, text: str): self.status_code = status_code self.text = text @@ -63,7 +64,7 @@ def __init__(self, status_code, text): super().setUp(users=["hubba-baar"], *args, **kwargs) - def load_app(self, config) -> OIDCProofingApp: + def load_app(self, config: dict[str, Any]) -> OIDCProofingApp: """ Called from the parent class, so we can provide the appropriate flask app for this test case. @@ -95,7 +96,13 @@ def update_config(self, config: dict[str, Any]) -> dict[str, Any]: @patch("oic.oic.Client.do_user_info_request") @patch("oic.oic.Client.do_access_token_request") def mock_authorization_response( - self, qrdata, proofing_state, userinfo, mock_token_request, mock_userinfo_request, mock_auth_response + self, + qrdata: dict, + proofing_state: OidcProofingState, + userinfo: dict, + mock_token_request: MagicMock, + mock_userinfo_request: MagicMock, + mock_auth_response: MagicMock, ): mock_auth_response.return_value = { "id_token": "id_token", diff --git a/src/eduid/webapp/personal_data/views.py b/src/eduid/webapp/personal_data/views.py index 96d46716b..9a0bdd2bd 100644 --- a/src/eduid/webapp/personal_data/views.py +++ b/src/eduid/webapp/personal_data/views.py @@ -38,7 +38,7 @@ def get_all_data(user: User) -> FluxData: @pd_views.route("/identities", methods=["GET"]) @MarshalWith(IdentitiesResponseSchema) @require_user -def get_identities(user) -> FluxData: +def get_identities(user: User) -> FluxData: return success_response(payload={"identities": user.identities.to_frontend_format()}) diff --git a/src/eduid/webapp/phone/schemas.py b/src/eduid/webapp/phone/schemas.py index 65f32a91d..3235a065b 100644 --- a/src/eduid/webapp/phone/schemas.py +++ b/src/eduid/webapp/phone/schemas.py @@ -18,7 +18,7 @@ class PhoneSchema(EduidSchema, CSRFRequestMixin): primary = fields.Boolean(attribute="primary") @pre_load - def normalize_phone_number(self, in_data, **kwargs): + def normalize_phone_number(self, in_data: dict, **kwargs): if in_data.get("number"): in_data["number"] = normalize_to_e_164(in_data["number"]) return in_data diff --git a/src/eduid/webapp/phone/validators.py b/src/eduid/webapp/phone/validators.py index 8a82e243b..6a96f57f4 100644 --- a/src/eduid/webapp/phone/validators.py +++ b/src/eduid/webapp/phone/validators.py @@ -6,7 +6,7 @@ from eduid.webapp.phone.app import current_phone_app as current_app -def normalize_to_e_164(number): +def normalize_to_e_164(number: str): number = "".join(number.split()) # Remove white space if number.startswith("00"): raise ValidationError("phone.e164_format") @@ -16,24 +16,24 @@ def normalize_to_e_164(number): return number -def validate_phone(number): +def validate_phone(number: str): validate_format_phone(number) validate_swedish_mobile(number) validate_unique_phone(number) -def validate_format_phone(number): +def validate_format_phone(number: str): if not re.match(r"^\+[1-9]\d{6,20}$", number): raise ValidationError("phone.phone_format") -def validate_swedish_mobile(number): +def validate_swedish_mobile(number: str): if number.startswith("+467"): if not re.match(r"^\+467[02369]\d{7}$", number): raise ValidationError("phone.swedish_mobile_format") -def validate_unique_phone(number): +def validate_unique_phone(number: str): user = get_user() if user.phone_numbers.find(number): raise ValidationError("phone.phone_duplicated") diff --git a/src/eduid/webapp/phone/views.py b/src/eduid/webapp/phone/views.py index ad65591ff..891a29c7c 100644 --- a/src/eduid/webapp/phone/views.py +++ b/src/eduid/webapp/phone/views.py @@ -42,7 +42,7 @@ def get_all_phones(user: User) -> FluxData: @UnmarshalWith(PhoneSchema) @MarshalWith(PhoneResponseSchema) @require_user -def post_phone(user: User, number: str, verified=None, primary=None) -> FluxData: +def post_phone(user: User, number: str, verified: bool | None = None, primary: bool | None = None) -> FluxData: """ view to add a new phone to the user data of the currently logged in user. diff --git a/src/eduid/webapp/reset_password/tests/test_app.py b/src/eduid/webapp/reset_password/tests/test_app.py index 3d71a7499..df595866f 100644 --- a/src/eduid/webapp/reset_password/tests/test_app.py +++ b/src/eduid/webapp/reset_password/tests/test_app.py @@ -18,7 +18,7 @@ from eduid.userdb.reset_password import ResetPasswordEmailAndPhoneState, ResetPasswordEmailState from eduid.webapp.common.api.testing import EduidAPITestCase from eduid.webapp.common.api.utils import get_zxcvbn_terms, hash_password -from eduid.webapp.common.authn.testing import TestVCCSClient +from eduid.webapp.common.authn.testing import MockVCCSClient from eduid.webapp.common.authn.tests.test_fido_tokens import ( SAMPLE_WEBAUTHN_APP_CONFIG, SAMPLE_WEBAUTHN_FIDO2STATE, @@ -137,7 +137,7 @@ def _post_reset_password( :param data2: control the data sent to actually reset the password. """ mock_request_user_sync.side_effect = self.request_user_sync - mock_get_vccs_client.return_value = TestVCCSClient() + mock_get_vccs_client.return_value = MockVCCSClient() # check that the user has verified data user = self.app.central_userdb.get_user_by_eppn(self.test_user.eppn) @@ -198,7 +198,7 @@ def _post_choose_extra_sec( :param repeat: if True, try to trigger sending the SMS twice. """ mock_request_user_sync.side_effect = self.request_user_sync - mock_get_vccs_client.return_value = TestVCCSClient() + mock_get_vccs_client.return_value = MockVCCSClient() mock_sendsms.return_value = True if sendsms_side_effect: mock_sendsms.side_effect = sendsms_side_effect @@ -257,7 +257,7 @@ def _post_reset_password_secure_phone( :param data2: To control the data sent to actually finally reset the password. """ mock_request_user_sync.side_effect = self.request_user_sync - mock_get_vccs_client.return_value = TestVCCSClient() + mock_get_vccs_client.return_value = MockVCCSClient() mock_sendsms.return_value = True response = self._post_email_address(data1=data1) @@ -320,7 +320,7 @@ def _post_reset_password_secure_token( :param fido2state: to control the fido state kept in the session """ mock_request_user_sync.side_effect = self.request_user_sync - mock_get_vccs_client.return_value = TestVCCSClient() + mock_get_vccs_client.return_value = MockVCCSClient() mock_verify.return_value = True credential = sample_credential.to_dict() @@ -383,7 +383,7 @@ def _post_reset_password_secure_external_mfa( :param external_mfa_state: to control the external mfa state kept in the session """ mock_request_user_sync.side_effect = self.request_user_sync - mock_get_vccs_client.return_value = TestVCCSClient() + mock_get_vccs_client.return_value = MockVCCSClient() user = self.app.central_userdb.get_user_by_eppn(self.test_user.eppn) @@ -457,7 +457,7 @@ def _get_phone_code_backdoor( and getting the generated phone verification code through the backdoor """ mock_request_user_sync.side_effect = self.request_user_sync - mock_get_vccs_client.return_value = TestVCCSClient() + mock_get_vccs_client.return_value = MockVCCSClient() mock_sendsms.return_value = True if sendsms_side_effect: mock_sendsms.side_effect = sendsms_side_effect diff --git a/src/eduid/webapp/security/helpers.py b/src/eduid/webapp/security/helpers.py index 468c1b024..4ca271bb9 100644 --- a/src/eduid/webapp/security/helpers.py +++ b/src/eduid/webapp/security/helpers.py @@ -167,7 +167,7 @@ def generate_suggested_password() -> str: return password -def send_termination_mail(user): +def send_termination_mail(user: User): """ :param user: User object :type user: User diff --git a/src/eduid/webapp/security/schemas.py b/src/eduid/webapp/security/schemas.py index 2a4637837..bab2760c2 100644 --- a/src/eduid/webapp/security/schemas.py +++ b/src/eduid/webapp/security/schemas.py @@ -59,7 +59,7 @@ class ChangePasswordSchema(PasswordSchema): authn_id = fields.String(required=False) @validates("new_password") - def validate_custom_password(self, value, **kwargs): + def validate_custom_password(self, value: str, **kwargs): # Set a new error message try: self.validate_password(value) diff --git a/src/eduid/webapp/security/tests/test_change_password.py b/src/eduid/webapp/security/tests/test_change_password.py index 46f116113..a3a1fcce0 100644 --- a/src/eduid/webapp/security/tests/test_change_password.py +++ b/src/eduid/webapp/security/tests/test_change_password.py @@ -1,7 +1,7 @@ import json from collections.abc import Mapping from typing import Any -from unittest.mock import patch +from unittest.mock import MagicMock, patch from eduid.common.config.base import FrontendAction from eduid.userdb.credentials import Password @@ -145,7 +145,7 @@ def test_get_suggested_not_logged_in(self): self.assertEqual(response.status_code, 401) @patch("eduid.webapp.security.views.change_password.generate_suggested_password") - def test_get_suggested(self, mock_generate_password): + def test_get_suggested(self, mock_generate_password: MagicMock): mock_generate_password.return_value = "test-password" self.set_authn_action( @@ -161,7 +161,7 @@ def test_get_suggested(self, mock_generate_password): ) @patch("eduid.webapp.security.views.change_password.change_password") - def test_change_passwd(self, mock_change_password): + def test_change_passwd(self, mock_change_password: MagicMock): mock_change_password.return_value = True self.set_authn_action( @@ -177,7 +177,7 @@ def test_change_passwd(self, mock_change_password): ) @patch("eduid.webapp.security.views.change_password.change_password") - def test_change_passwd_with_login_auth(self, mock_change_password): + def test_change_passwd_with_login_auth(self, mock_change_password: MagicMock): mock_change_password.return_value = True self.set_authn_action( @@ -213,7 +213,7 @@ def test_change_passwd_empty_data(self): ) @patch("eduid.webapp.security.views.change_password.change_password") - def test_change_passwd_no_csrf(self, mock_change_password): + def test_change_passwd_no_csrf(self, mock_change_password: MagicMock): mock_change_password.return_value = True data1 = {"csrf_token": ""} @@ -225,7 +225,7 @@ def test_change_passwd_no_csrf(self, mock_change_password): ) @patch("eduid.webapp.security.views.change_password.change_password") - def test_change_passwd_wrong_csrf(self, mock_change_password): + def test_change_passwd_wrong_csrf(self, mock_change_password: MagicMock): mock_change_password.return_value = True data1 = {"csrf_token": "wrong-token"} @@ -237,7 +237,7 @@ def test_change_passwd_wrong_csrf(self, mock_change_password): ) @patch("eduid.webapp.security.views.change_password.change_password") - def test_change_passwd_weak(self, mock_change_password): + def test_change_passwd_weak(self, mock_change_password: MagicMock): mock_change_password.return_value = True self.set_authn_action( diff --git a/src/eduid/webapp/security/tests/test_webauthn.py b/src/eduid/webapp/security/tests/test_webauthn.py index bcdc101b3..374db3a4d 100644 --- a/src/eduid/webapp/security/tests/test_webauthn.py +++ b/src/eduid/webapp/security/tests/test_webauthn.py @@ -10,7 +10,7 @@ from eduid.common.config.base import EduidEnvironment, FrontendAction from eduid.userdb.credentials import U2F, FidoCredential, Webauthn -from eduid.webapp.common.api.testing import EduidAPITestCase +from eduid.webapp.common.api.testing import CSRFTestClient, EduidAPITestCase from eduid.webapp.common.session import EduidSession from eduid.webapp.common.session.namespaces import WebauthnRegistration, WebauthnState from eduid.webapp.security.app import SecurityApp, security_init_app @@ -166,7 +166,7 @@ def _add_u2f_token_to_user(self, eppn: str) -> U2F: self.app.central_userdb.save(user) return u2f_token - def _check_session_state(self, client): + def _check_session_state(self, client: CSRFTestClient): with client.session_transaction() as sess: assert isinstance(sess, EduidSession) assert sess.security.webauthn_registration is not None @@ -174,17 +174,17 @@ def _check_session_state(self, client): assert webauthn_state["user_verification"] == "discouraged" assert "challenge" in webauthn_state - def _check_registration_begun(self, data): + def _check_registration_begun(self, data: dict): self.assertEqual(data["type"], "POST_WEBAUTHN_WEBAUTHN_REGISTER_BEGIN_SUCCESS") self.assertIn("registration_data", data["payload"]) self.assertIn("csrf_token", data["payload"]) - def _check_registration_complete(self, data): + def _check_registration_complete(self, data: dict): self.assertEqual(data["type"], "POST_WEBAUTHN_WEBAUTHN_REGISTER_COMPLETE_SUCCESS") self.assertTrue(len(data["payload"]["credentials"]) > 0) self.assertEqual(data["payload"]["message"], "security.webauthn_register_success") - def _check_removal(self, data, user_token): + def _check_removal(self, data: dict, user_token: Webauthn): self.assertEqual(data["type"], "POST_WEBAUTHN_WEBAUTHN_REMOVE_SUCCESS") self.assertIsNotNone(data["payload"]["credentials"]) for credential in data["payload"]["credentials"]: diff --git a/src/eduid/webapp/signup/schemas.py b/src/eduid/webapp/signup/schemas.py index 7eb947d7c..4ec88b7df 100644 --- a/src/eduid/webapp/signup/schemas.py +++ b/src/eduid/webapp/signup/schemas.py @@ -62,19 +62,19 @@ class Credentials(EduidSchema): payload = fields.Nested(StatusSchema) @pre_dump - def set_already_signed_up(self, data, **kwargs): + def set_already_signed_up(self, data: dict, **kwargs): if data["payload"].get("state"): data["payload"]["state"]["already_signed_up"] = bool(session.common.eppn) return data @pre_dump - def set_tou_version(self, data, **kwargs): + def set_tou_version(self, data: dict, **kwargs): if data["payload"].get("state", {}).get("tou") and data["payload"]["state"]["tou"].get("version") is None: data["payload"]["state"]["tou"]["version"] = current_app.conf.tou_version return data @pre_dump - def throttle_delta_to_seconds(self, out_data, **kwargs): + def throttle_delta_to_seconds(self, out_data: dict, **kwargs): if out_data["payload"].get("state", {}).get("email", {}).get("sent_at"): sent_at = out_data["payload"]["state"]["email"]["sent_at"] throttle_time_left = time_left(sent_at, current_app.conf.throttle_resend).total_seconds() @@ -86,7 +86,7 @@ def throttle_delta_to_seconds(self, out_data, **kwargs): return out_data @pre_dump - def email_verification_timeout_delta_to_seconds(self, out_data, **kwargs): + def email_verification_timeout_delta_to_seconds(self, out_data: dict, **kwargs): if out_data["payload"].get("state", {}).get("email", {}).get("sent_at"): sent_at = out_data["payload"]["state"]["email"]["sent_at"] verification_time_left = time_left(sent_at, current_app.conf.email_verification_timeout).total_seconds() @@ -98,7 +98,7 @@ def email_verification_timeout_delta_to_seconds(self, out_data, **kwargs): return out_data @pre_dump - def bad_attempts_max(self, out_data, **kwargs): + def bad_attempts_max(self, out_data: dict, **kwargs): if out_data["payload"].get("state", {}).get("email"): out_data["payload"]["state"]["email"]["bad_attempts_max"] = ( current_app.conf.email_verification_max_bad_attempts diff --git a/src/eduid/webapp/support/app.py b/src/eduid/webapp/support/app.py index 0c9ed3b8a..be1afe9af 100644 --- a/src/eduid/webapp/support/app.py +++ b/src/eduid/webapp/support/app.py @@ -1,5 +1,6 @@ import operator from collections.abc import Mapping +from datetime import datetime from typing import Any, cast from flask import current_app @@ -33,19 +34,19 @@ def __init__(self, config: SupportConfig, **kwargs): def register_template_funcs(app: SupportApp) -> None: @app.template_filter("datetimeformat") - def datetimeformat(value, format="%Y-%m-%d %H:%M %Z"): + def datetimeformat(value: datetime | None, format: str = "%Y-%m-%d %H:%M %Z"): if not value: return "" return value.strftime(format) @app.template_filter("dateformat") - def dateformat(value, format="%Y-%m-%d"): + def dateformat(value: datetime | None, format: str = "%Y-%m-%d"): if not value: return "" return value.strftime(format) @app.template_filter("multisort") - def sort_multi(items, *operators, **kwargs): + def sort_multi(items: list, *operators, **kwargs): # Don't try to sort on missing keys keys = list(operators) # operators is immutable for key in operators: diff --git a/src/eduid/webapp/svipe_id/helpers.py b/src/eduid/webapp/svipe_id/helpers.py index f7ea2f600..39965a7c0 100644 --- a/src/eduid/webapp/svipe_id/helpers.py +++ b/src/eduid/webapp/svipe_id/helpers.py @@ -83,7 +83,7 @@ class SvipeDocumentUserInfo(UserInfoBase): @field_validator("document_nationality") @classmethod - def iso_3166_1_alpha_3_to_alpha2(cls, v): + def iso_3166_1_alpha_3_to_alpha2(cls, v: Any): # translate ISO 3166-1 alpha-3 to alpha-2 to match the format used in eduid-userdb try: country = countries.get(v) diff --git a/src/eduid/webapp/svipe_id/views.py b/src/eduid/webapp/svipe_id/views.py index 4b249e93f..ecbdb625a 100644 --- a/src/eduid/webapp/svipe_id/views.py +++ b/src/eduid/webapp/svipe_id/views.py @@ -132,7 +132,7 @@ def _authn( @svipe_id_views.route("/authn-callback", methods=["GET"]) @require_user -def authn_callback(user) -> WerkzeugResponse: +def authn_callback(user: User) -> WerkzeugResponse: """ This is the callback endpoint for the Svipe ID OIDC flow. """ diff --git a/src/eduid/workers/am/tasks.py b/src/eduid/workers/am/tasks.py index beef6a0d8..a30842cfd 100644 --- a/src/eduid/workers/am/tasks.py +++ b/src/eduid/workers/am/tasks.py @@ -1,4 +1,5 @@ import bson +from billiard.einfo import ExceptionInfo from celery import Task from celery.utils.log import get_task_logger @@ -30,7 +31,7 @@ def userdb(self) -> AmDB | None: self._userdb = AmDB(AmCelerySingleton.worker_config.mongo_uri, "eduid_am") return self._userdb - def on_failure(self, exc, task_id, args, kwargs, einfo): + def on_failure(self, exc: Exception, task_id: str, args: tuple, kwargs: dict, einfo: ExceptionInfo): # The most common problem when tasks raise exceptions is that mongodb has switched master, # but it is hard to accurately trap the right exception without importing pymongo here so # let's just reload all databases (self.userdb here and the plugins databases) when we diff --git a/src/eduid/workers/am/testing.py b/src/eduid/workers/am/testing.py index 79df57ed9..82b662d98 100644 --- a/src/eduid/workers/am/testing.py +++ b/src/eduid/workers/am/testing.py @@ -106,9 +106,7 @@ class WorkerTestCase(CommonTestCase): Base Test case for eduID celery workers """ - def setUp( # type: ignore[override] - self, *args: Any, am_settings: dict[str, Any] | None = None, want_mongo_uri: bool = True, **kwargs: Any - ): + def setUp(self, *args: Any, am_settings: dict[str, Any] | None = None, want_mongo_uri: bool = True, **kwargs: Any): """ set up tests """ diff --git a/src/eduid/workers/am/tests/test_am.py b/src/eduid/workers/am/tests/test_am.py index b7d9f47c1..7a9071314 100644 --- a/src/eduid/workers/am/tests/test_am.py +++ b/src/eduid/workers/am/tests/test_am.py @@ -4,6 +4,7 @@ from eduid.common.config.workers import AmConfig from eduid.userdb.db import TUserDbDocument from eduid.userdb.exceptions import UserDoesNotExist +from eduid.userdb.user import User from eduid.workers.am.ams.common import AttributeFetcher from eduid.workers.am.common import AmCelerySingleton from eduid.workers.am.testing import AMTestCase @@ -43,11 +44,14 @@ class FakeAttributeFetcher(AttributeFetcher): :rtype: dict """ - def get_user_db(self, uri): + @classmethod + def get_user_db(cls, uri: str): return AmTestUserDb(uri, db_name="eduid_am_test") - def fetch_attrs(self, user_id): - user = self.private_db.get_user_by_id(user_id) + def fetch_attrs(self, user_id: ObjectId): + assert self.private_db + user: User | None = self.private_db.get_user_by_id(user_id) + assert isinstance(user, AmTestUser) if user is None: raise UserDoesNotExist(f"No user matching _id={user_id!r}") @@ -66,7 +70,7 @@ class BadAttributeFetcher(FakeAttributeFetcher): Returns a bad operations dict. """ - def fetch_attrs(self, user_id): + def fetch_attrs(self, user_id: ObjectId): res = super().fetch_attrs(user_id) res["notanoperator"] = "test" return res diff --git a/src/eduid/workers/amapi/middleware.py b/src/eduid/workers/amapi/middleware.py index b2a11c3c4..692847152 100644 --- a/src/eduid/workers/amapi/middleware.py +++ b/src/eduid/workers/amapi/middleware.py @@ -5,7 +5,7 @@ from fastapi import Request, Response, status from jwcrypto import jwt from jwcrypto.common import JWException -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.responses import PlainTextResponse from eduid.workers.amapi.config import EndpointRestriction @@ -28,7 +28,7 @@ def return_error_response(status_code: int, detail: str): class AuthenticationMiddleware(BaseHTTPMiddleware, ContextRequestMixin): - async def dispatch(self, req: Request, call_next) -> Response: + async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: req = self.make_context_request(request=req) path = req.url.path.lstrip(req.app.config.application_root) method_path = f"{req.method.lower()}:{path}" diff --git a/src/eduid/workers/amapi/routers/utils/status.py b/src/eduid/workers/amapi/routers/utils/status.py index fa4a933b5..d671bece8 100644 --- a/src/eduid/workers/amapi/routers/utils/status.py +++ b/src/eduid/workers/amapi/routers/utils/status.py @@ -47,7 +47,7 @@ def reset_failure_info(ctx: ContextRequest, key: str) -> None: ctx.app.context.logger.info(f"Check {key} back to normal. Resetting info {info}") -def check_restart(key, restart: int, terminate: int) -> bool: +def check_restart(key: str, restart: int, terminate: int) -> bool: res = False # default to no restart info = FAILURE_INFO.get(key) if not info: diff --git a/src/eduid/workers/lookup_mobile/client/mobile_lookup_client.py b/src/eduid/workers/lookup_mobile/client/mobile_lookup_client.py index 2f1300241..8f7570656 100644 --- a/src/eduid/workers/lookup_mobile/client/mobile_lookup_client.py +++ b/src/eduid/workers/lookup_mobile/client/mobile_lookup_client.py @@ -1,4 +1,8 @@ +from logging import Logger +from typing import Any + from suds.client import Client +from suds.sudsobject import Object from eduid.common.config.base import EduidEnvironment from eduid.common.config.workers import MobConfig @@ -9,7 +13,7 @@ class MobileLookupClient: - def __init__(self, logger, config: MobConfig) -> None: + def __init__(self, logger: Logger, config: MobConfig) -> None: self.conf = config # enable transaction logging if configured @@ -32,7 +36,7 @@ def _get_find_person(self): @TransactionAudit() @deprecated("This task seems unused") - def find_mobiles_by_NIN(self, national_identity_number: str, number_region=None) -> list[str]: + def find_mobiles_by_NIN(self, national_identity_number: str, number_region: str | None = None) -> list[str]: formatted_nin = format_NIN(national_identity_number) if not formatted_nin: self.logger.error(f"Invalid NIN input: {national_identity_number}") @@ -47,7 +51,7 @@ def find_mobiles_by_NIN(self, national_identity_number: str, number_region=None) return format_mobile_number(mobiles, number_region) @TransactionAudit() - def find_NIN_by_mobile(self, mobile_number) -> str | None: + def find_NIN_by_mobile(self, mobile_number: str) -> str | None: nin = self._search_by_mobile(mobile_number) if not nin: self.logger.debug(f"Did not get search result from mobile number: {mobile_number}") @@ -55,7 +59,7 @@ def find_NIN_by_mobile(self, mobile_number) -> str | None: return format_NIN(nin) - def _search(self, param): + def _search(self, param: Any | Object): # Start the search # TODO: remove self.conf.devel_mode, use environment instead if self.conf.testing or self.conf.environment == EduidEnvironment.dev: diff --git a/src/eduid/workers/lookup_mobile/decorators.py b/src/eduid/workers/lookup_mobile/decorators.py index e04110193..533cefd96 100644 --- a/src/eduid/workers/lookup_mobile/decorators.py +++ b/src/eduid/workers/lookup_mobile/decorators.py @@ -5,22 +5,27 @@ # logging module at a later stage. # +from collections.abc import Callable from datetime import datetime from inspect import isclass +from typing import Any + +from pymongo.collection import Collection from eduid.userdb.db import MongoDB +from eduid.userdb.db.base import TUserDbDocument class TransactionAudit: enabled = True db_uri = None - def __init__(self, db_name="eduid_lookup_mobile", collection_name="transaction_audit"): - self.db_name = db_name - self.collection_name = collection_name - self.collection = None + def __init__(self, db_name: str = "eduid_lookup_mobile", collection_name: str = "transaction_audit"): + self.db_name: str = db_name + self.collection_name: str = collection_name + self.collection: Collection[TUserDbDocument] | None = None - def __call__(self, f): + def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: if not self.enabled: return f @@ -38,11 +43,13 @@ def audit(*args, **kwargs): self.collection = db.get_collection(self.collection_name) if not isclass(ret): # we can't save class objects in mongodb date = datetime.utcnow() - doc = { - "function": f.__name__, - "data": self._filter(f.__name__, ret, *args, **kwargs), - "created_at": date, - } + doc = TUserDbDocument( + { + "function": f.__name__, + "data": self._filter(f.__name__, ret, *args, **kwargs), + "created_at": date, + } + ) self.collection.insert_one(doc) return ret @@ -56,7 +63,7 @@ def enable(cls): def disable(cls): cls.enabled = False - def _filter(self, func, data, *args, **kwargs): + def _filter(self, func: str, data: Any, *args, **kwargs): if data is False: return data if func == "find_mobiles_by_NIN": diff --git a/src/eduid/workers/lookup_mobile/development/development_search_result.py b/src/eduid/workers/lookup_mobile/development/development_search_result.py index 0a4d33954..b5a6d31db 100644 --- a/src/eduid/workers/lookup_mobile/development/development_search_result.py +++ b/src/eduid/workers/lookup_mobile/development/development_search_result.py @@ -1,11 +1,14 @@ __author__ = "mathiashedstrom" +from typing import Any + +from suds.sudsobject import Object from eduid.workers.lookup_mobile.development import nin_mobile_db class DevelopResult: class Record: - def __init__(self, nin, mobile): + def __init__(self, nin: str, mobile: str): self.SSNo = nin self.Mobiles = mobile @@ -14,7 +17,7 @@ def __init__(self): self._num_records = 0 self.record = [] - def append_record(self, record): + def append_record(self, record: "DevelopResult.Record"): self.record.append(record) self._num_records = len(self.record) @@ -23,7 +26,7 @@ def __init__(self): self._error_code = 0 -def _get_devel_search_result(search_param): +def _get_devel_search_result(search_param: Any | Object): nin = search_param.QueryParams.FindSSNo mobile = search_param.QueryParams.FindTelephone diff --git a/src/eduid/workers/lookup_mobile/development/nin_mobile_db.py b/src/eduid/workers/lookup_mobile/development/nin_mobile_db.py index d25ee605f..dcab1907b 100644 --- a/src/eduid/workers/lookup_mobile/development/nin_mobile_db.py +++ b/src/eduid/workers/lookup_mobile/development/nin_mobile_db.py @@ -12,11 +12,11 @@ } -def get_mobile(nin): +def get_mobile(nin: str): return _db.get(nin, []) -def get_nin(mobile): +def get_nin(mobile: str): for nin, numbers in _db.items(): if mobile in numbers: return nin diff --git a/src/eduid/workers/lookup_mobile/tasks.py b/src/eduid/workers/lookup_mobile/tasks.py index fa201b292..1727d3d5a 100644 --- a/src/eduid/workers/lookup_mobile/tasks.py +++ b/src/eduid/workers/lookup_mobile/tasks.py @@ -27,7 +27,7 @@ def lookup_client(self) -> MobileLookupClient: @app.task(bind=True, base=MobWorker) @deprecated("This task seems unused") -def find_mobiles_by_NIN(self: MobWorker, national_identity_number: str, number_region=None) -> list[str]: +def find_mobiles_by_NIN(self: MobWorker, national_identity_number: str, number_region: str | None = None) -> list[str]: """ Searches mobile numbers registered to the given nin :param national_identity_number: diff --git a/src/eduid/workers/lookup_mobile/test/test_decorators.py b/src/eduid/workers/lookup_mobile/test/test_decorators.py index b1950b040..b05e3e1b0 100644 --- a/src/eduid/workers/lookup_mobile/test/test_decorators.py +++ b/src/eduid/workers/lookup_mobile/test/test_decorators.py @@ -1,5 +1,7 @@ __author__ = "lundberg" +from typing import Any + from eduid.common.config.workers import MsgConfig from eduid.workers.lookup_mobile.decorators import TransactionAudit from eduid.workers.lookup_mobile.testing import LookupMobileMongoTestCase @@ -16,7 +18,7 @@ def setUp(self): def test_successfull_transaction_audit(self): @TransactionAudit() - def find_mobiles_by_NIN(self, national_identity_number, number_region=None): + def find_mobiles_by_NIN(self: Any, national_identity_number: str, number_region: str | None = None): return ["list", "of", "mobile_numbers"] find_mobiles_by_NIN(self, "200202025678") @@ -29,7 +31,7 @@ def find_mobiles_by_NIN(self, national_identity_number, number_region=None): c.delete_many({}) # Clear database @TransactionAudit() - def find_NIN_by_mobile(self, mobile_number): + def find_NIN_by_mobile(self: Any, mobile_number: str): return "200202025678" find_NIN_by_mobile(self, "+46701740699") @@ -43,7 +45,7 @@ def find_NIN_by_mobile(self, mobile_number): def test_failed_transaction_audit(self): @TransactionAudit() - def find_mobiles_by_NIN(self, national_identity_number, number_region=None): + def find_mobiles_by_NIN(self: Any, national_identity_number: str, number_region: str | None = None): return [] find_mobiles_by_NIN(self, "200202025678") @@ -54,7 +56,7 @@ def find_mobiles_by_NIN(self, national_identity_number, number_region=None): c.delete_many({}) # Clear database @TransactionAudit() - def find_NIN_by_mobile(self, mobile_number): + def find_NIN_by_mobile(self: Any, mobile_number: str): return find_NIN_by_mobile(self, "+46701740699") @@ -70,7 +72,7 @@ def test_transaction_audit_toggle(self): TransactionAudit.disable() @TransactionAudit() - def no_name(self): + def no_name(self: Any): return {"baka": "kaka"} no_name(self) @@ -81,7 +83,7 @@ def no_name(self): TransactionAudit.enable() @TransactionAudit() - def no_name2(self): + def no_name2(self: Any): return {"baka": "kaka"} no_name2(self) diff --git a/src/eduid/workers/lookup_mobile/testing.py b/src/eduid/workers/lookup_mobile/testing.py index 33691f649..fd266bebd 100644 --- a/src/eduid/workers/lookup_mobile/testing.py +++ b/src/eduid/workers/lookup_mobile/testing.py @@ -5,6 +5,7 @@ from eduid.common.config.workers import MobConfig from eduid.common.rpc.lookup_mobile_relay import LookupMobileRelay from eduid.userdb.testing import MongoTestCase +from eduid.userdb.user import User from eduid.workers.lookup_mobile.common import MobCelerySingleton logger = logging.getLogger(__name__) @@ -15,8 +16,8 @@ class MobTestConfig(EduIDBaseAppConfig, CeleryConfigMixin): class LookupMobileMongoTestCase(MongoTestCase): - def setUp(self, init_lookup_mobile=True, **kwargs) -> Any: # type: ignore[override] - super().setUp(**kwargs) + def setUp(self, am_users: list[User] | None = None, init_lookup_mobile: bool = True) -> Any: + super().setUp(am_users=am_users) if init_lookup_mobile: settings = { "app_name": "testing", diff --git a/src/eduid/workers/msg/decorators.py b/src/eduid/workers/msg/decorators.py index 0990dd7cb..012d3e381 100644 --- a/src/eduid/workers/msg/decorators.py +++ b/src/eduid/workers/msg/decorators.py @@ -52,7 +52,7 @@ def disable(cls): cls.enabled = False @staticmethod - def _filter(func, data, *args, **kwargs): + def _filter(func: str, data: Any, *args, **kwargs): if data is False: return data if func == "_get_navet_data": diff --git a/src/eduid/workers/msg/exceptions.py b/src/eduid/workers/msg/exceptions.py index 7406c75d0..ea33d98c0 100644 --- a/src/eduid/workers/msg/exceptions.py +++ b/src/eduid/workers/msg/exceptions.py @@ -3,7 +3,7 @@ class MessageException(Exception): class NavetException(Exception): - def __init__(self, message): + def __init__(self, message: str): self.message = message def __str__(self): diff --git a/src/eduid/workers/msg/tasks.py b/src/eduid/workers/msg/tasks.py index 6e841e51b..c8833312a 100644 --- a/src/eduid/workers/msg/tasks.py +++ b/src/eduid/workers/msg/tasks.py @@ -4,6 +4,7 @@ from collections import OrderedDict from typing import Any +from billiard.einfo import ExceptionInfo from celery import Task from celery.utils.log import get_task_logger from hammock import Hammock @@ -88,7 +89,7 @@ def reload_db(): # Remove initiated cache dbs _CACHE = {} - def on_failure(self, exc, task_id, args, kwargs, einfo): + def on_failure(self, exc: Exception, task_id: str, args: tuple, kwargs: dict, einfo: ExceptionInfo): # Try to reload the db on connection failures (mongodb has probably switched master) if isinstance(exc, ConnectionError): logger.error("Task failed with db exception ConnectionError. Reloading db.") diff --git a/src/eduid/workers/msg/testing.py b/src/eduid/workers/msg/testing.py index 219969ad0..d798ec27a 100644 --- a/src/eduid/workers/msg/testing.py +++ b/src/eduid/workers/msg/testing.py @@ -7,6 +7,7 @@ from eduid.common.rpc.mail_relay import MailRelay from eduid.common.rpc.msg_relay import MsgRelay from eduid.userdb.testing import MongoTestCase +from eduid.userdb.user import User from eduid.workers.msg.common import MsgCelerySingleton logger = logging.getLogger(__name__) @@ -21,7 +22,7 @@ class MailTestConfig(EduIDBaseAppConfig, MailConfigMixin): class MsgMongoTestCase(MongoTestCase): - def setUp(self, init_msg=True) -> Any: # type: ignore[override] + def setUp(self, am_users: list[User] | None = None, init_msg: bool = True) -> Any: super().setUp() data_path = PurePath(__file__).with_name("tests") / "data" if init_msg: diff --git a/src/eduid/workers/msg/tests/__init__.py b/src/eduid/workers/msg/tests/__init__.py index 3d79fad47..e69de29bb 100644 --- a/src/eduid/workers/msg/tests/__init__.py +++ b/src/eduid/workers/msg/tests/__init__.py @@ -1,15 +0,0 @@ -from unittest.mock import MagicMock - - -def mock_get_attribute_manager(celery): - """ - Mocked get function for an attribute manager we don't need here - :return: Mocked am - :rtype: Object - """ - am = MagicMock() - return am - - -# Mocked celery for am that we don't need here -mock_celery = MagicMock() diff --git a/src/eduid/workers/msg/tests/test_decorators.py b/src/eduid/workers/msg/tests/test_decorators.py index 17f412d79..8cf5d25e3 100644 --- a/src/eduid/workers/msg/tests/test_decorators.py +++ b/src/eduid/workers/msg/tests/test_decorators.py @@ -1,10 +1,14 @@ +from typing import Any + +from eduid.userdb.user import User from eduid.workers.msg.decorators import TransactionAudit from eduid.workers.msg.testing import MsgMongoTestCase class TestTransactionAudit(MsgMongoTestCase): - def setUp(self, init_msg=True): + def setUp(self, am_users: list[User] | None = None, init_msg: bool = True): super().setUp(init_msg=init_msg) + assert self.msg_settings.mongo_uri TransactionAudit.enable(self.msg_settings.mongo_uri, db_name="test") def test_transaction_audit(self): @@ -24,7 +28,7 @@ def no_name(): assert result.next()["data"]["baka"] == "kaka" @TransactionAudit() - def _get_navet_data(arg1, arg2): + def _get_navet_data(arg1: str, arg2: str): return {"baka", "kaka"} _get_navet_data("dummy", "1111") @@ -32,7 +36,16 @@ def _get_navet_data(arg1, arg2): self.assertEqual(result["data"]["identity_number"], "1111") @TransactionAudit() - def send_message(_self, message_type, reference, message_dict, recipient, template, language, subject=None): + def send_message( + _self: Any, + message_type: str, + reference: str, + message_dict: str, + recipient: str, + template: str, + language: str, + subject: str | None = None, + ): return "kaka" send_message("dummy", "mm", "reference", "dummy", "2222", "template", "lang") diff --git a/src/eduid/workers/msg/tests/test_tasks.py b/src/eduid/workers/msg/tests/test_tasks.py index 9b8db27c1..84453d2f3 100644 --- a/src/eduid/workers/msg/tests/test_tasks.py +++ b/src/eduid/workers/msg/tests/test_tasks.py @@ -3,6 +3,7 @@ import pytest from celery.exceptions import Retry +from eduid.userdb.user import User from eduid.workers.msg.testing import MsgMongoTestCase @@ -11,7 +12,7 @@ class MockException(Exception): class TestTasks(MsgMongoTestCase): - def setUp(self, init_msg=True): + def setUp(self, am_users: list[User] | None = None, init_msg: bool = True): super().setUp(init_msg=init_msg) self.msg_dict = {"name": "Godiskungen", "admin": "Testadmin"} @@ -66,7 +67,7 @@ def json(self): } @patch("smscom.SMSClient.send") - def test_send_message_sms(self, sms_mock): + def test_send_message_sms(self, sms_mock: MagicMock): sms_mock.return_value = True self.msg_relay.sendsms(recipient="+466666", message="foo", reference="ref") @@ -87,10 +88,8 @@ def test_send_message_invalid_phone_number(self): assert exc_info.value.excs == "ValueError(\"'to' is not a valid phone number\")" - @patch( - "smscom.SMSClient.send", - ) - def test_send_message_sms_exception(self, sms_mock): + @patch("smscom.SMSClient.send") + def test_send_message_sms_exception(self, sms_mock: MagicMock): """Test creating an artificial exception in the SMSClient.send""" sms_mock.side_effect = MockException("Unrecoverable error") with pytest.raises(Retry) as exc_info: From aabe945242913eb40946b0addd897a89da214495 Mon Sep 17 00:00:00 2001 From: Lasse Yledahl Date: Fri, 27 Sep 2024 14:15:09 +0000 Subject: [PATCH 02/16] add annotations for args and kwargs where missing --- ruff.toml | 2 +- .../clients/amapi_client/amapi_client.py | 2 +- .../clients/gnap_client/async_client.py | 3 +- .../common/clients/gnap_client/sync_client.py | 3 +- .../common/clients/scim_client/scim_client.py | 3 +- src/eduid/common/config/parsers/decorators.py | 4 +- src/eduid/common/decorators.py | 5 +- src/eduid/common/fastapi/context_request.py | 3 +- src/eduid/common/fastapi/exceptions.py | 11 +- src/eduid/queue/db/client.py | 3 +- src/eduid/queue/decorators.py | 4 +- src/eduid/satosa/scimapi/serve_static.py | 3 +- src/eduid/scimapi/context.py | 3 +- src/eduid/scimapi/exceptions.py | 13 +- src/eduid/scimapi/utils.py | 4 +- .../userdb/tests/test_group_management.py | 2 +- src/eduid/userdb/tests/test_userdb.py | 161 ++++++++++-------- src/eduid/webapp/authn/app.py | 2 +- src/eduid/webapp/authn/tests/test_authn.py | 4 +- src/eduid/webapp/common/api/debug.py | 2 +- src/eduid/webapp/common/api/schemas/csrf.py | 8 +- src/eduid/webapp/common/api/schemas/email.py | 4 +- .../webapp/common/api/schemas/password.py | 6 +- .../webapp/common/api/schemas/validators.py | 6 +- src/eduid/webapp/common/api/validation.py | 3 +- src/eduid/webapp/common/authn/acs_registry.py | 3 +- .../common/authn/tests/test_middleware.py | 2 +- .../webapp/common/authn/tests/test_vccs.py | 8 +- .../session/tests/test_eduid_session.py | 4 +- src/eduid/webapp/freja_eid/app.py | 2 +- src/eduid/webapp/freja_eid/tests/test_app.py | 2 +- src/eduid/webapp/group_management/app.py | 2 +- src/eduid/webapp/idp/decorators.py | 5 +- src/eduid/webapp/idp/schemas.py | 2 +- src/eduid/webapp/idp/tests/test_api.py | 4 +- src/eduid/webapp/idp/tests/test_login.py | 2 +- src/eduid/webapp/jsconfig/app.py | 2 +- src/eduid/webapp/ladok/app.py | 2 +- src/eduid/webapp/oidc_proofing/app.py | 2 +- .../webapp/oidc_proofing/tests/test_app.py | 2 +- src/eduid/webapp/orcid/app.py | 2 +- src/eduid/webapp/personal_data/app.py | 2 +- .../webapp/personal_data/tests/test_app.py | 2 +- src/eduid/webapp/phone/app.py | 2 +- src/eduid/webapp/phone/schemas.py | 4 +- src/eduid/webapp/reset_password/app.py | 2 +- .../webapp/reset_password/tests/test_app.py | 2 +- src/eduid/webapp/security/app.py | 2 +- src/eduid/webapp/security/schemas.py | 4 +- src/eduid/webapp/signup/app.py | 2 +- src/eduid/webapp/signup/schemas.py | 12 +- src/eduid/webapp/signup/tests/test_app.py | 2 +- src/eduid/webapp/support/app.py | 4 +- src/eduid/webapp/svipe_id/app.py | 2 +- src/eduid/webapp/svipe_id/tests/test_app.py | 2 +- src/eduid/workers/am/tests/test_am.py | 4 +- src/eduid/workers/am/tests/test_tasks.py | 4 +- src/eduid/workers/amapi/context_request.py | 3 +- src/eduid/workers/amapi/testing.py | 2 +- .../workers/amapi/tests/test_middleware.py | 3 +- src/eduid/workers/amapi/tests/test_user.py | 5 +- src/eduid/workers/lookup_mobile/decorators.py | 4 +- src/eduid/workers/msg/decorators.py | 4 +- 63 files changed, 213 insertions(+), 170 deletions(-) diff --git a/ruff.toml b/ruff.toml index 74a132870..3f1bfa947 100644 --- a/ruff.toml +++ b/ruff.toml @@ -3,6 +3,6 @@ line-length = 120 target-version = "py310" [lint] -select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN001"] +select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN0", ] ignore = ["E501"] diff --git a/src/eduid/common/clients/amapi_client/amapi_client.py b/src/eduid/common/clients/amapi_client/amapi_client.py index 1d726debe..134a0f015 100644 --- a/src/eduid/common/clients/amapi_client/amapi_client.py +++ b/src/eduid/common/clients/amapi_client/amapi_client.py @@ -20,7 +20,7 @@ class AMAPIClient(GNAPClient): - def __init__(self, amapi_url: str, auth_data: GNAPClientAuthData, verify_tls: bool = True, **kwargs): + def __init__(self, amapi_url: str, auth_data: GNAPClientAuthData, verify_tls: bool = True, **kwargs: Any): super().__init__(auth_data=auth_data, verify=verify_tls, **kwargs) self.amapi_url = amapi_url diff --git a/src/eduid/common/clients/gnap_client/async_client.py b/src/eduid/common/clients/gnap_client/async_client.py index 24ba8c425..aa2c5a54b 100644 --- a/src/eduid/common/clients/gnap_client/async_client.py +++ b/src/eduid/common/clients/gnap_client/async_client.py @@ -1,4 +1,5 @@ import logging +from typing import Any import httpx @@ -11,7 +12,7 @@ class AsyncGNAPClient(httpx.AsyncClient, GNAPBearerTokenMixin): - def __init__(self, auth_data: GNAPClientAuthData, **kwargs): + def __init__(self, auth_data: GNAPClientAuthData, **kwargs: Any): if "event_hooks" not in kwargs: kwargs["event_hooks"] = {"response": [self.raise_on_4xx_5xx], "request": [self._add_authz_header]} diff --git a/src/eduid/common/clients/gnap_client/sync_client.py b/src/eduid/common/clients/gnap_client/sync_client.py index 79ff1cc1a..cbc2728e4 100644 --- a/src/eduid/common/clients/gnap_client/sync_client.py +++ b/src/eduid/common/clients/gnap_client/sync_client.py @@ -1,4 +1,5 @@ import logging +from typing import Any import httpx @@ -11,7 +12,7 @@ class GNAPClient(httpx.Client, GNAPBearerTokenMixin): - def __init__(self, auth_data: GNAPClientAuthData, **kwargs): + def __init__(self, auth_data: GNAPClientAuthData, **kwargs: Any): if "event_hooks" not in kwargs: kwargs["event_hooks"] = {"response": [self.raise_on_4xx_5xx], "request": [self._add_authz_header]} diff --git a/src/eduid/common/clients/scim_client/scim_client.py b/src/eduid/common/clients/scim_client/scim_client.py index d66eaa2a7..6eee28f96 100644 --- a/src/eduid/common/clients/scim_client/scim_client.py +++ b/src/eduid/common/clients/scim_client/scim_client.py @@ -1,4 +1,5 @@ import logging +from typing import Any from uuid import UUID import httpx @@ -20,7 +21,7 @@ class SCIMError(Exception): class SCIMClient(GNAPClient): - def __init__(self, scim_api_url: str, auth_data: GNAPClientAuthData, **kwargs): + def __init__(self, scim_api_url: str, auth_data: GNAPClientAuthData, **kwargs: Any): super().__init__(auth_data=auth_data, **kwargs) self.event_hooks["request"].append(self._add_accept_header) self.scim_api_url = scim_api_url diff --git a/src/eduid/common/config/parsers/decorators.py b/src/eduid/common/config/parsers/decorators.py index 39caa8fb3..d948e1d1c 100644 --- a/src/eduid/common/config/parsers/decorators.py +++ b/src/eduid/common/config/parsers/decorators.py @@ -13,7 +13,7 @@ def decrypt(f: Callable): @wraps(f) - def decrypt_decorator(*args, **kwargs): + def decrypt_decorator(*args: Any, **kwargs: Any): config_dict = f(*args, **kwargs) decrypted_config_dict = decrypt_config(config_dict) return decrypted_config_dict @@ -85,7 +85,7 @@ def decrypt_config(config_dict: Mapping[str, Any]) -> Mapping[str, Any]: def interpolate(f: Callable): @wraps(f) - def interpolation_decorator(*args, **kwargs): + def interpolation_decorator(*args: Any, **kwargs: Any): config_dict = f(*args, **kwargs) interpolated_config_dict = interpolate_config(config_dict) for key in list(interpolated_config_dict.keys()): diff --git a/src/eduid/common/decorators.py b/src/eduid/common/decorators.py index 661613b01..9f49bcd60 100644 --- a/src/eduid/common/decorators.py +++ b/src/eduid/common/decorators.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Callable from functools import wraps +from typing import Any # https://stackoverflow.com/questions/2536307/how-do-i-deprecate-python-functions/40301488#40301488 @@ -28,7 +29,7 @@ def decorator(func1: Callable): fmt1 = "Call to deprecated function {name} ({reason})." @wraps(func1) - def new_func1(*args, **kwargs): + def new_func1(*args: Any, **kwargs: Any): warnings.simplefilter("always", DeprecationWarning) warnings.warn( fmt1.format(name=func1.__name__, reason=reason), category=DeprecationWarning, stacklevel=2 @@ -57,7 +58,7 @@ def new_func1(*args, **kwargs): fmt2 = "Call to deprecated function {name}." @wraps(func2) - def new_func2(*args, **kwargs): + def new_func2(*args: Any, **kwargs: Any): warnings.simplefilter("always", DeprecationWarning) warnings.warn(fmt2.format(name=func2.__name__), category=DeprecationWarning, stacklevel=2) warnings.simplefilter("default", DeprecationWarning) diff --git a/src/eduid/common/fastapi/context_request.py b/src/eduid/common/fastapi/context_request.py index c4953b007..58beee775 100644 --- a/src/eduid/common/fastapi/context_request.py +++ b/src/eduid/common/fastapi/context_request.py @@ -1,5 +1,6 @@ from collections.abc import Callable from dataclasses import asdict, dataclass +from typing import Any from fastapi import Request, Response from fastapi.routing import APIRoute @@ -12,7 +13,7 @@ def to_dict(self): class ContextRequest(Request): - def __init__(self, context_class: type[Context], *args, **kwargs): + def __init__(self, context_class: type[Context], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.contextClass = context_class diff --git a/src/eduid/common/fastapi/exceptions.py b/src/eduid/common/fastapi/exceptions.py index e150698a9..d8a1958a6 100644 --- a/src/eduid/common/fastapi/exceptions.py +++ b/src/eduid/common/fastapi/exceptions.py @@ -2,6 +2,7 @@ import logging import uuid +from typing import Any from fastapi import Request, status from fastapi.exception_handlers import http_exception_handler @@ -74,28 +75,28 @@ def extra_headers(self, headers: dict): class BadRequest(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_400_BAD_REQUEST, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Bad Request" class Unauthorized(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Unauthorized request" class NotFound(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_404_NOT_FOUND, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Resource not found" class MethodNotAllowedMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, **kwargs) if not self.error_detail.detail: allowed_methods = kwargs.get("allowed_methods") @@ -103,7 +104,7 @@ def __init__(self, **kwargs): class UnsupportedMediaTypeMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Request was made with an unsupported media type" diff --git a/src/eduid/queue/db/client.py b/src/eduid/queue/db/client.py index 82034a7dc..3accb9ccf 100644 --- a/src/eduid/queue/db/client.py +++ b/src/eduid/queue/db/client.py @@ -1,5 +1,6 @@ import logging from dataclasses import replace +from typing import Any from bson import ObjectId @@ -15,7 +16,7 @@ class QueuePayloadMixin: - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.handlers: dict[str, type[Payload]] = dict() def register_handler(self, payload: type[Payload]) -> None: diff --git a/src/eduid/queue/decorators.py b/src/eduid/queue/decorators.py index b9b84ba8d..7fdaa933b 100644 --- a/src/eduid/queue/decorators.py +++ b/src/eduid/queue/decorators.py @@ -25,7 +25,7 @@ def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: if not self.enabled: return f - def audit(*args, **kwargs): + def audit(*args: Any, **kwargs: Any): ret = f(*args, **kwargs) if not isclass(ret) and self.collection: # we can't save class objects in mongodb date = utc_now() @@ -52,7 +52,7 @@ def disable(cls): cls.enabled = False @staticmethod - def _filter(func: str, data: Any, *args, **kwargs): + def _filter(func: str, data: Any, *args: Any, **kwargs: Any): if data is False: return data if func == "_get_navet_data": diff --git a/src/eduid/satosa/scimapi/serve_static.py b/src/eduid/satosa/scimapi/serve_static.py index a2bf79a9a..dea589e0d 100644 --- a/src/eduid/satosa/scimapi/serve_static.py +++ b/src/eduid/satosa/scimapi/serve_static.py @@ -4,6 +4,7 @@ import logging import mimetypes +from typing import Any from satosa.context import Context from satosa.micro_services.base import RequestMicroService @@ -29,7 +30,7 @@ class ServeStatic(RequestMicroService): logprefix = "SERVE_STATIC_SERVICE:" - def __init__(self, config: SATOSAConfig, *args, **kwargs): + def __init__(self, config: SATOSAConfig, *args: Any, **kwargs: Any): """ :type config: satosa.satosa_config.SATOSAConfig :param config: The SATOSA proxy config diff --git a/src/eduid/scimapi/context.py b/src/eduid/scimapi/context.py index a16b07945..a301105ca 100644 --- a/src/eduid/scimapi/context.py +++ b/src/eduid/scimapi/context.py @@ -2,6 +2,7 @@ import logging.config from dataclasses import dataclass, field from datetime import datetime +from typing import Any from uuid import UUID from eduid.common.config.base import DataOwnerConfig, DataOwnerName @@ -106,7 +107,7 @@ def get_invitedb(self, data_owner: DataOwnerName) -> ScimApiInviteDB | None: def get_eventdb(self, data_owner: DataOwnerName) -> ScimApiEventDB | None: return self._get_data_owner_dbs(data_owner=data_owner).eventdb - def url_for(self, *args) -> str: + def url_for(self, *args: Any) -> str: url = self.base_url for arg in args: url = urlappend(url, f"{arg}") diff --git a/src/eduid/scimapi/exceptions.py b/src/eduid/scimapi/exceptions.py index 562378f2e..0a24ae3f0 100644 --- a/src/eduid/scimapi/exceptions.py +++ b/src/eduid/scimapi/exceptions.py @@ -2,6 +2,7 @@ import logging import uuid +from typing import Any from fastapi import Request, status from fastapi.encoders import jsonable_encoder @@ -90,28 +91,28 @@ def extra_headers(self, headers: dict): class BadRequest(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_400_BAD_REQUEST, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Bad Request" class Unauthorized(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Unauthorized request" class NotFound(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_404_NOT_FOUND, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Resource not found" class MethodNotAllowedMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, **kwargs) if not self.error_detail.detail: allowed_methods = kwargs.get("allowed_methods") @@ -119,14 +120,14 @@ def __init__(self, **kwargs): class UnsupportedMediaTypeMalformed(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Request was made with an unsupported media type" class Conflict(HTTPErrorDetail): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(status_code=status.HTTP_409_CONFLICT, **kwargs) if not self.error_detail.detail: self.error_detail.detail = "Request conflicts with the current state" diff --git a/src/eduid/scimapi/utils.py b/src/eduid/scimapi/utils.py index 0f04ec0a2..c9a05e23c 100644 --- a/src/eduid/scimapi/utils.py +++ b/src/eduid/scimapi/utils.py @@ -3,7 +3,7 @@ import logging import time from collections.abc import Callable -from typing import AnyStr, TypeVar +from typing import Any, AnyStr, TypeVar from uuid import uuid4 from jwcrypto import jwk @@ -63,7 +63,7 @@ def load_jwks(config: ScimApiConfig) -> jwk.JWKSet: def retryable_db_write(func: Callable): @functools.wraps(func) - def wrapper_run_func(*args, **kwargs): + def wrapper_run_func(*args: Any, **kwargs: Any): max_retries = 10 retry = 0 while True: diff --git a/src/eduid/userdb/tests/test_group_management.py b/src/eduid/userdb/tests/test_group_management.py index a20aa6323..bdf349f6b 100644 --- a/src/eduid/userdb/tests/test_group_management.py +++ b/src/eduid/userdb/tests/test_group_management.py @@ -13,7 +13,7 @@ class TestResetGroupInviteStateDB(MongoTestCase): user: User - def setUp(self, **kwargs): + def setUp(self): super().setUp() self.user = UserFixtures().mocked_user_standard self.invite_state_db = GroupManagementInviteStateDB(self.tmp_db.uri) diff --git a/src/eduid/userdb/tests/test_userdb.py b/src/eduid/userdb/tests/test_userdb.py index cc8448672..d38940091 100644 --- a/src/eduid/userdb/tests/test_userdb.py +++ b/src/eduid/userdb/tests/test_userdb.py @@ -1,10 +1,12 @@ import logging +from typing import Any import bson import pytest from eduid.common.testing_base import normalised_data from eduid.userdb import User +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.exceptions import UserDoesNotExist, UserOutOfSync from eduid.userdb.fixtures.passwords import signup_password from eduid.userdb.fixtures.users import UserFixtures @@ -15,7 +17,7 @@ class TestUserDB(MongoTestCase): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): self.user = UserFixtures().mocked_user_standard super().setUp(am_users=[self.user], **kwargs) @@ -71,9 +73,9 @@ def test_get_user_by_eppn_not_found(self): class UserMissingMeta(MongoTestCase): user: User - def setUp(self, *args, **kwargs): + def setUp(self): self.user = UserFixtures().mocked_user_standard - super().setUp(*args, am_users=[self.user], **kwargs) + super().setUp(am_users=[self.user]) self._remove_meta_from_user_in_db(self.user) @@ -103,7 +105,7 @@ def test_update_user_old(self): class UpdateUser(MongoTestCase): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): _users = UserFixtures() self.user = _users.mocked_user_standard super().setUp(am_users=[self.user, _users.mocked_user_standard_2], **kwargs) @@ -133,25 +135,29 @@ def test_ok(self): class TestUserDB_mail(MongoTestCase): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs) - data1 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "mail-test1", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "passwords": [signup_password.to_dict()], - } - - data2 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "mail-test2", - "mailAliases": [ - {"email": "test2@gmail.com", "primary": True, "verified": True}, - {"email": "test@gmail.com", "verified": False}, - ], - "passwords": [signup_password.to_dict()], - } + data1: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "mail-test1", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "passwords": [signup_password.to_dict()], + } + ) + + data2: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "mail-test2", + "mailAliases": [ + {"email": "test2@gmail.com", "primary": True, "verified": True}, + {"email": "test@gmail.com", "verified": False}, + ], + "passwords": [signup_password.to_dict()], + } + ) self.user1 = User.from_dict(data1) self.user2 = User.from_dict(data2) @@ -179,29 +185,33 @@ def test_get_user_by_mail_multiple(self): class TestUserDB_phone(MongoTestCase): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs) - data1 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "phone-test1", - "mail": "kalle@example.com", - "phone": [ - {"number": "+11111111111", "primary": True, "verified": True}, - {"number": "+22222222222", "primary": False, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } - data2 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "phone-test2", - "mail": "anka@example.com", - "phone": [ - {"number": "+11111111111", "primary": True, "verified": False}, - {"number": "+22222222222", "primary": False, "verified": False}, - {"number": "+33333333333", "primary": False, "verified": False}, - ], - "passwords": [signup_password.to_dict()], - } + data1: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "phone-test1", + "mail": "kalle@example.com", + "phone": [ + {"number": "+11111111111", "primary": True, "verified": True}, + {"number": "+22222222222", "primary": False, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) + data2: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "phone-test2", + "mail": "anka@example.com", + "phone": [ + {"number": "+11111111111", "primary": True, "verified": False}, + {"number": "+22222222222", "primary": False, "verified": False}, + {"number": "+33333333333", "primary": False, "verified": False}, + ], + "passwords": [signup_password.to_dict()], + } + ) self.user1 = User.from_dict(data1) self.user2 = User.from_dict(data2) @@ -237,36 +247,41 @@ def test_get_user_by_phone_multiple(self): class TestUserDB_nin(MongoTestCase): # TODO: Keep for a while to make sure the conversion to identities work as expected - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs) - data1 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test1", - "mail": "kalle@example.com", - "nins": [ - {"number": "11111111111", "primary": True, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } - data2 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test2", - "mail": "anka@example.com", - "nins": [ - {"number": "22222222222", "primary": True, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } - - data3 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test3", - "mail": "anka@example.com", - "nins": [ - {"number": "33333333333", "primary": False, "verified": False}, - ], - "passwords": [signup_password.to_dict()], - } + data1: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test1", + "mail": "kalle@example.com", + "nins": [ + {"number": "11111111111", "primary": True, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) + data2: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test2", + "mail": "anka@example.com", + "nins": [ + {"number": "22222222222", "primary": True, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) + data3: TUserDbDocument = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test3", + "mail": "anka@example.com", + "nins": [ + {"number": "33333333333", "primary": False, "verified": False}, + ], + "passwords": [signup_password.to_dict()], + } + ) self.user1 = User.from_dict(data1) self.user2 = User.from_dict(data2) diff --git a/src/eduid/webapp/authn/app.py b/src/eduid/webapp/authn/app.py index eb27a3fd5..52ac71c13 100644 --- a/src/eduid/webapp/authn/app.py +++ b/src/eduid/webapp/authn/app.py @@ -10,7 +10,7 @@ class AuthnApp(EduIDBaseApp): - def __init__(self, config: AuthnConfig, **kwargs): + def __init__(self, config: AuthnConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/authn/tests/test_authn.py b/src/eduid/webapp/authn/tests/test_authn.py index 11b7481f2..23af76efb 100644 --- a/src/eduid/webapp/authn/tests/test_authn.py +++ b/src/eduid/webapp/authn/tests/test_authn.py @@ -251,7 +251,7 @@ class AuthnAPITestCase(AuthnAPITestBase): app: AuthnApp - def setUp(self, **kwargs): + def setUp(self, **kwargs: Any): # type: ignore[override] super().setUp(users=["hubba-bubba", "hubba-fooo"], **kwargs) def test_login_authn(self): @@ -309,7 +309,7 @@ def _signup_authn_user(self, eppn: str): class AuthnTestApp(AuthnBaseApp): - def __init__(self, config: AuthnConfig, **kwargs): + def __init__(self, config: AuthnConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/common/api/debug.py b/src/eduid/webapp/common/api/debug.py index 005cec47b..dbf1e7cff 100644 --- a/src/eduid/webapp/common/api/debug.py +++ b/src/eduid/webapp/common/api/debug.py @@ -22,7 +22,7 @@ def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> I errorlog = environ["wsgi.errors"] pprint.pprint(("REQUEST", environ), stream=errorlog) - def log_response(status: str, headers: list[tuple[str, str]], *args): + def log_response(status: str, headers: list[tuple[str, str]], *args: Any): pprint.pprint(("RESPONSE", status, headers), stream=errorlog) return start_response(status, headers, *args) diff --git a/src/eduid/webapp/common/api/schemas/csrf.py b/src/eduid/webapp/common/api/schemas/csrf.py index d8c1c3436..c8893dc95 100644 --- a/src/eduid/webapp/common/api/schemas/csrf.py +++ b/src/eduid/webapp/common/api/schemas/csrf.py @@ -16,7 +16,7 @@ class CSRFRequestMixin(Schema): csrf_token = fields.String(required=True) @validates("csrf_token") - def validate_csrf_token(self, value: str, **kwargs): + def validate_csrf_token(self, value: str, **kwargs: Any): custom_header = request.headers.get("X-Requested-With") if custom_header != "XMLHttpRequest": # TODO: move value to config current_app.logger.error("CSRF check: missing custom X-Requested-With header") @@ -26,13 +26,13 @@ def validate_csrf_token(self, value: str, **kwargs): logger.debug(f"Validated CSRF token in session: {session.get_csrf_token()}") @post_load - def post_processing(self, in_data: Any, **kwargs): + def post_processing(self, in_data: Any, **kwargs: Any): # Remove token from data forwarded to views in_data = self.remove_csrf_token(in_data) return in_data @staticmethod - def remove_csrf_token(in_data: Any, **kwargs): + def remove_csrf_token(in_data: Any, **kwargs: Any): del in_data["csrf_token"] return in_data @@ -41,7 +41,7 @@ class CSRFResponseMixin(Schema): csrf_token = fields.String(required=True) @pre_dump - def get_csrf_token(self, out_data: Any, **kwargs): + def get_csrf_token(self, out_data: Any, **kwargs: Any): # Generate a new csrf token for every response out_data["csrf_token"] = session.new_csrf_token() logger.debug(f'Generated new CSRF token in CSRFResponseMixin: {out_data["csrf_token"]}') diff --git a/src/eduid/webapp/common/api/schemas/email.py b/src/eduid/webapp/common/api/schemas/email.py index 7dcf341ce..f36852e8d 100644 --- a/src/eduid/webapp/common/api/schemas/email.py +++ b/src/eduid/webapp/common/api/schemas/email.py @@ -10,12 +10,12 @@ class LowercaseEmail(Email): Email field that serializes and deserializes to a lower case string. """ - def _serialize(self, value: str | bytes, attr: Any, obj: Any, **kwargs): + def _serialize(self, value: str | bytes, attr: Any, obj: Any, **kwargs: Any): _value = super()._serialize(value, attr, obj, **kwargs) if _value is None: return None return _value.lower() - def _deserialize(self, value: str | bytes, attr: Any, data: Any, **kwargs): + def _deserialize(self, value: str | bytes, attr: Any, data: Any, **kwargs: Any): _value = super()._deserialize(value, attr, data, **kwargs) return _value.lower() diff --git a/src/eduid/webapp/common/api/schemas/password.py b/src/eduid/webapp/common/api/schemas/password.py index 3faf75f49..11b5fa71d 100644 --- a/src/eduid/webapp/common/api/schemas/password.py +++ b/src/eduid/webapp/common/api/schemas/password.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow import Schema, ValidationError __author__ = "lundberg" @@ -11,13 +13,13 @@ class Meta: min_entropy: int | None = None min_score: int | None = None - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): self.Meta.zxcvbn_terms = kwargs.pop("zxcvbn_terms", []) self.Meta.min_entropy = kwargs.pop("min_entropy") self.Meta.min_score = kwargs.pop("min_score") super().__init__(*args, **kwargs) - def validate_password(self, password: str, **kwargs): + def validate_password(self, password: str, **kwargs: Any): """ :param password: New password diff --git a/src/eduid/webapp/common/api/schemas/validators.py b/src/eduid/webapp/common/api/schemas/validators.py index 79a124340..700e27108 100644 --- a/src/eduid/webapp/common/api/schemas/validators.py +++ b/src/eduid/webapp/common/api/schemas/validators.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow import ValidationError from eduid.webapp.common.api.validation import is_valid_email, is_valid_nin @@ -5,7 +7,7 @@ __author__ = "lundberg" -def validate_nin(nin: str, **kwargs): +def validate_nin(nin: str, **kwargs: Any): """ :param nin: National Identity Number :type nin: string_types @@ -18,7 +20,7 @@ def validate_nin(nin: str, **kwargs): raise ValidationError("nin needs to be formatted as 18|19|20yymmddxxxx") -def validate_email(email: str, **kwargs): +def validate_email(email: str, **kwargs: Any): """ :param email: E-mail address :type email: string_types diff --git a/src/eduid/webapp/common/api/validation.py b/src/eduid/webapp/common/api/validation.py index 3be1f7928..ca345e1e3 100644 --- a/src/eduid/webapp/common/api/validation.py +++ b/src/eduid/webapp/common/api/validation.py @@ -1,6 +1,7 @@ import math import re from collections.abc import Sequence +from typing import Any from zxcvbn import zxcvbn @@ -24,7 +25,7 @@ def is_valid_nin(nin: str) -> bool: raise ValueError("nin needs to be formatted as 18|19|20yymmddxxxx") -def is_valid_email(email: str, **kwargs): +def is_valid_email(email: str, **kwargs: Any): """ :param email: E-mail address :return: True or raises ValueError diff --git a/src/eduid/webapp/common/authn/acs_registry.py b/src/eduid/webapp/common/authn/acs_registry.py index 15e69d697..b280056c3 100644 --- a/src/eduid/webapp/common/authn/acs_registry.py +++ b/src/eduid/webapp/common/authn/acs_registry.py @@ -13,6 +13,7 @@ from collections.abc import Callable from dataclasses import dataclass from enum import Enum +from typing import Any from flask import current_app from werkzeug.wrappers import Response as WerkzeugResponse @@ -60,7 +61,7 @@ def acs_action(action: Enum): def outer(func: Callable[[ACSArgs], ACSResult]) -> Callable[[ACSArgs], ACSResult]: _actions[action.value] = func - def inner(*args, **kwargs) -> ACSResult: + def inner(*args: Any, **kwargs: Any) -> ACSResult: return func(*args, **kwargs) return inner diff --git a/src/eduid/webapp/common/authn/tests/test_middleware.py b/src/eduid/webapp/common/authn/tests/test_middleware.py index 4ee84080d..06ac5bdd4 100644 --- a/src/eduid/webapp/common/authn/tests/test_middleware.py +++ b/src/eduid/webapp/common/authn/tests/test_middleware.py @@ -10,7 +10,7 @@ class AuthnTestApp(AuthnBaseApp): - def __init__(self, name: str, test_config: Mapping[str, Any], **kwargs): + def __init__(self, name: str, test_config: Mapping[str, Any], **kwargs: Any): # This should be an AuthnConfig instance, but an EduIDBaseAppConfig instance suffices for these # tests and we don't want eduid.webapp.common to depend on eduid.webapp. self.conf = load_config(typ=EduIDBaseAppConfig, app_name=name, ns="webapp", test_config=test_config) diff --git a/src/eduid/webapp/common/authn/tests/test_vccs.py b/src/eduid/webapp/common/authn/tests/test_vccs.py index a6807d3a1..7be1fd1ec 100644 --- a/src/eduid/webapp/common/authn/tests/test_vccs.py +++ b/src/eduid/webapp/common/authn/tests/test_vccs.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from unittest.mock import patch from eduid.userdb.fixtures.users import UserFixtures @@ -12,8 +12,8 @@ class VCCSTestCase(MongoTestCase): user: User - def setUp(self, **kwargs): - super().setUp(am_users=[UserFixtures().new_user_example], **kwargs) + def setUp(self): + super().setUp(am_users=[UserFixtures().new_user_example]) self.vccs_client = cast(VCCSClient, MockVCCSClient()) _user = self.amdb.get_user_by_mail("johnsmith@example.com") assert _user is not None @@ -149,7 +149,7 @@ def test_change_password_error_adding(self): def test_reset_password_error_revoking(self): from eduid.webapp.common.authn.testing import MockVCCSClient - def mock_revoke_creds(*args): + def mock_revoke_creds(*args: Any): raise VCCSClientHTTPError("dummy", 500) with patch.object(MockVCCSClient, "revoke_credentials", mock_revoke_creds): diff --git a/src/eduid/webapp/common/session/tests/test_eduid_session.py b/src/eduid/webapp/common/session/tests/test_eduid_session.py index b6e49c76e..1e775e440 100644 --- a/src/eduid/webapp/common/session/tests/test_eduid_session.py +++ b/src/eduid/webapp/common/session/tests/test_eduid_session.py @@ -17,7 +17,7 @@ class SessionTestConfig(EduIDBaseAppConfig): class SessionTestApp(AuthnBaseApp): - def __init__(self, config: SessionTestConfig, **kwargs): + def __init__(self, config: SessionTestConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config @@ -78,7 +78,7 @@ def logout(): class EduidSessionTests(EduidAPITestCase): app: SessionTestApp - def setUp(self, **kwargs): + def setUp(self, **kwargs: Any): # type: ignore[override] self.test_user_eppn = "hubba-bubba" super().setUp(**kwargs) diff --git a/src/eduid/webapp/freja_eid/app.py b/src/eduid/webapp/freja_eid/app.py index 3514f735c..dc5309a25 100644 --- a/src/eduid/webapp/freja_eid/app.py +++ b/src/eduid/webapp/freja_eid/app.py @@ -16,7 +16,7 @@ class FrejaEIDApp(AuthnBaseApp): - def __init__(self, config: FrejaEIDConfig, **kwargs): + def __init__(self, config: FrejaEIDConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/freja_eid/tests/test_app.py b/src/eduid/webapp/freja_eid/tests/test_app.py index 9e342ea1f..dfce75769 100644 --- a/src/eduid/webapp/freja_eid/tests/test_app.py +++ b/src/eduid/webapp/freja_eid/tests/test_app.py @@ -28,7 +28,7 @@ class FrejaEIDTests(ProofingTests[FrejaEIDApp]): """Base TestCase for those tests that need a full environment setup""" - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs, users=["hubba-bubba", "hubba-baar"]) self.unverified_test_user = self.app.central_userdb.get_user_by_eppn("hubba-baar") diff --git a/src/eduid/webapp/group_management/app.py b/src/eduid/webapp/group_management/app.py index 8ef1ccdde..0bb1667cc 100644 --- a/src/eduid/webapp/group_management/app.py +++ b/src/eduid/webapp/group_management/app.py @@ -16,7 +16,7 @@ class GroupManagementApp(AuthnBaseApp): - def __init__(self, config: GroupManagementConfig, **kwargs): + def __init__(self, config: GroupManagementConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/idp/decorators.py b/src/eduid/webapp/idp/decorators.py index 8c6736ba9..0be7e0206 100644 --- a/src/eduid/webapp/idp/decorators.py +++ b/src/eduid/webapp/idp/decorators.py @@ -1,6 +1,7 @@ import logging from collections.abc import Callable from functools import wraps +from typing import Any from flask import jsonify, request @@ -18,7 +19,7 @@ def require_ticket(f: Callable): @wraps(f) - def require_ticket_decorator(*args, **kwargs): + def require_ticket_decorator(*args: Any, **kwargs: Any): """Decorator to turn the 'ref' parameter sent by the frontend into a ticket (LoginContext)""" if "ref" not in kwargs: logger.debug("Login ref not supplied") @@ -57,7 +58,7 @@ def require_ticket_decorator(*args, **kwargs): def uses_sso_session(f: Callable): @wraps(f) - def uses_sso_session_decorator(*args, **kwargs): + def uses_sso_session_decorator(*args: Any, **kwargs: Any): """Decorator to supply the current SSO session, if one is found and still valid""" kwargs["sso_session"] = get_sso_session() diff --git a/src/eduid/webapp/idp/schemas.py b/src/eduid/webapp/idp/schemas.py index 34431b339..8d2bb2080 100644 --- a/src/eduid/webapp/idp/schemas.py +++ b/src/eduid/webapp/idp/schemas.py @@ -78,7 +78,7 @@ class MfaAuthResponsePayload(EduidSchema, CSRFResponseMixin): class ToUVersions(fields.Field): """Handle list of ToU versions available in the frontend both as comma-separated string (bug) and as list""" - def _deserialize(self, value: Any, attr: str | None, data: Any, **kwargs) -> list[str] | None: + def _deserialize(self, value: Any, attr: str | None, data: Any, **kwargs: Any) -> list[str] | None: if value is None: return None if isinstance(value, str): diff --git a/src/eduid/webapp/idp/tests/test_api.py b/src/eduid/webapp/idp/tests/test_api.py index d5c92cf9b..6f85f084a 100644 --- a/src/eduid/webapp/idp/tests/test_api.py +++ b/src/eduid/webapp/idp/tests/test_api.py @@ -96,8 +96,8 @@ class IdPAPITests(EduidAPITestCase[IdPApp]): def setUp( self, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): super().setUp(*args, **kwargs) self.idp_entity_id = "https://unittest-idp.example.edu/idp.xml" diff --git a/src/eduid/webapp/idp/tests/test_login.py b/src/eduid/webapp/idp/tests/test_login.py index 062dc8fe6..b0282b431 100644 --- a/src/eduid/webapp/idp/tests/test_login.py +++ b/src/eduid/webapp/idp/tests/test_login.py @@ -534,7 +534,7 @@ def test_geo_statistics_fail(self) -> None: class IdPTestLoginAPIManagedAccounts(IdPAPITests): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs) self.test_eppn = "ma-12345678" managed_account = self._create_managed_account_user(eppn=self.test_eppn) diff --git a/src/eduid/webapp/jsconfig/app.py b/src/eduid/webapp/jsconfig/app.py index 6eada7295..bc3d78044 100644 --- a/src/eduid/webapp/jsconfig/app.py +++ b/src/eduid/webapp/jsconfig/app.py @@ -9,7 +9,7 @@ class JSConfigApp(EduIDBaseApp): - def __init__(self, config: JSConfigConfig, **kwargs): + def __init__(self, config: JSConfigConfig, **kwargs: Any): kwargs["init_central_userdb"] = False kwargs["static_folder"] = None diff --git a/src/eduid/webapp/ladok/app.py b/src/eduid/webapp/ladok/app.py index d43076471..8de14c3cb 100644 --- a/src/eduid/webapp/ladok/app.py +++ b/src/eduid/webapp/ladok/app.py @@ -15,7 +15,7 @@ class LadokApp(AuthnBaseApp): - def __init__(self, config: LadokConfig, **kwargs): + def __init__(self, config: LadokConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/oidc_proofing/app.py b/src/eduid/webapp/oidc_proofing/app.py index b2f943dc8..97b390eca 100644 --- a/src/eduid/webapp/oidc_proofing/app.py +++ b/src/eduid/webapp/oidc_proofing/app.py @@ -18,7 +18,7 @@ class OIDCProofingApp(AuthnBaseApp): - def __init__(self, config: OIDCProofingConfig, **kwargs): + def __init__(self, config: OIDCProofingConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/oidc_proofing/tests/test_app.py b/src/eduid/webapp/oidc_proofing/tests/test_app.py index 02784f4f7..60435e2b4 100644 --- a/src/eduid/webapp/oidc_proofing/tests/test_app.py +++ b/src/eduid/webapp/oidc_proofing/tests/test_app.py @@ -21,7 +21,7 @@ class OidcProofingTests(EduidAPITestCase): app: OIDCProofingApp - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): self.test_user_eppn = "hubba-baar" self.test_user_nin = "200001023456" self.test_user_wrong_nin = "190001021234" diff --git a/src/eduid/webapp/orcid/app.py b/src/eduid/webapp/orcid/app.py index 8c71fef8a..21b450b3c 100644 --- a/src/eduid/webapp/orcid/app.py +++ b/src/eduid/webapp/orcid/app.py @@ -15,7 +15,7 @@ class OrcidApp(AuthnBaseApp): - def __init__(self, config: OrcidConfig, **kwargs): + def __init__(self, config: OrcidConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/personal_data/app.py b/src/eduid/webapp/personal_data/app.py index 313c8a1ed..b963542ae 100644 --- a/src/eduid/webapp/personal_data/app.py +++ b/src/eduid/webapp/personal_data/app.py @@ -11,7 +11,7 @@ class PersonalDataApp(AuthnBaseApp): - def __init__(self, config: PersonalDataConfig, **kwargs): + def __init__(self, config: PersonalDataConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/personal_data/tests/test_app.py b/src/eduid/webapp/personal_data/tests/test_app.py index 0d9a8ab43..633e86c3d 100644 --- a/src/eduid/webapp/personal_data/tests/test_app.py +++ b/src/eduid/webapp/personal_data/tests/test_app.py @@ -16,7 +16,7 @@ class PersonalDataTests(EduidAPITestCase[PersonalDataApp]): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, copy_user_to_private=True, **kwargs) def load_app(self, config: Mapping[str, Any]) -> PersonalDataApp: diff --git a/src/eduid/webapp/phone/app.py b/src/eduid/webapp/phone/app.py index b844d0256..5159b2f49 100644 --- a/src/eduid/webapp/phone/app.py +++ b/src/eduid/webapp/phone/app.py @@ -15,7 +15,7 @@ class PhoneApp(AuthnBaseApp): - def __init__(self, config: PhoneConfig, **kwargs): + def __init__(self, config: PhoneConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/phone/schemas.py b/src/eduid/webapp/phone/schemas.py index 3235a065b..134417264 100644 --- a/src/eduid/webapp/phone/schemas.py +++ b/src/eduid/webapp/phone/schemas.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow import fields, pre_load from eduid.webapp.common.api.schemas.base import EduidSchema, FluxStandardAction @@ -18,7 +20,7 @@ class PhoneSchema(EduidSchema, CSRFRequestMixin): primary = fields.Boolean(attribute="primary") @pre_load - def normalize_phone_number(self, in_data: dict, **kwargs): + def normalize_phone_number(self, in_data: dict, **kwargs: Any): if in_data.get("number"): in_data["number"] = normalize_to_e_164(in_data["number"]) return in_data diff --git a/src/eduid/webapp/reset_password/app.py b/src/eduid/webapp/reset_password/app.py index e401596a2..748137864 100644 --- a/src/eduid/webapp/reset_password/app.py +++ b/src/eduid/webapp/reset_password/app.py @@ -17,7 +17,7 @@ class ResetPasswordApp(EduIDBaseApp): - def __init__(self, config: ResetPasswordConfig, **kwargs): + def __init__(self, config: ResetPasswordConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/reset_password/tests/test_app.py b/src/eduid/webapp/reset_password/tests/test_app.py index df595866f..545aa3a19 100644 --- a/src/eduid/webapp/reset_password/tests/test_app.py +++ b/src/eduid/webapp/reset_password/tests/test_app.py @@ -39,7 +39,7 @@ class ResetPasswordTests(EduidAPITestCase[ResetPasswordApp]): """Base TestCase for those tests that need a full environment setup""" - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs) self.other_test_user = UserFixtures().mocked_user_standard_2 diff --git a/src/eduid/webapp/security/app.py b/src/eduid/webapp/security/app.py index df27d26ad..fd9bc1bbb 100644 --- a/src/eduid/webapp/security/app.py +++ b/src/eduid/webapp/security/app.py @@ -17,7 +17,7 @@ class SecurityApp(AuthnBaseApp): - def __init__(self, config: SecurityConfig, **kwargs): + def __init__(self, config: SecurityConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/security/schemas.py b/src/eduid/webapp/security/schemas.py index bab2760c2..f5c56f635 100644 --- a/src/eduid/webapp/security/schemas.py +++ b/src/eduid/webapp/security/schemas.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow import ValidationError, fields, validates from eduid.webapp.common.api.schemas.base import EduidSchema, FluxStandardAction @@ -59,7 +61,7 @@ class ChangePasswordSchema(PasswordSchema): authn_id = fields.String(required=False) @validates("new_password") - def validate_custom_password(self, value: str, **kwargs): + def validate_custom_password(self, value: str, **kwargs: Any): # Set a new error message try: self.validate_password(value) diff --git a/src/eduid/webapp/signup/app.py b/src/eduid/webapp/signup/app.py index 24160280b..aa9b75b3b 100644 --- a/src/eduid/webapp/signup/app.py +++ b/src/eduid/webapp/signup/app.py @@ -17,7 +17,7 @@ class SignupApp(EduIDBaseApp): - def __init__(self, config: SignupConfig, **kwargs): + def __init__(self, config: SignupConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/signup/schemas.py b/src/eduid/webapp/signup/schemas.py index 4ec88b7df..504c4458e 100644 --- a/src/eduid/webapp/signup/schemas.py +++ b/src/eduid/webapp/signup/schemas.py @@ -1,3 +1,5 @@ +from typing import Any + from marshmallow import fields, pre_dump from eduid.webapp.common.api.schemas.base import EduidSchema, FluxStandardAction @@ -62,19 +64,19 @@ class Credentials(EduidSchema): payload = fields.Nested(StatusSchema) @pre_dump - def set_already_signed_up(self, data: dict, **kwargs): + def set_already_signed_up(self, data: dict, **kwargs: Any): if data["payload"].get("state"): data["payload"]["state"]["already_signed_up"] = bool(session.common.eppn) return data @pre_dump - def set_tou_version(self, data: dict, **kwargs): + def set_tou_version(self, data: dict, **kwargs: Any): if data["payload"].get("state", {}).get("tou") and data["payload"]["state"]["tou"].get("version") is None: data["payload"]["state"]["tou"]["version"] = current_app.conf.tou_version return data @pre_dump - def throttle_delta_to_seconds(self, out_data: dict, **kwargs): + def throttle_delta_to_seconds(self, out_data: dict, **kwargs: Any): if out_data["payload"].get("state", {}).get("email", {}).get("sent_at"): sent_at = out_data["payload"]["state"]["email"]["sent_at"] throttle_time_left = time_left(sent_at, current_app.conf.throttle_resend).total_seconds() @@ -86,7 +88,7 @@ def throttle_delta_to_seconds(self, out_data: dict, **kwargs): return out_data @pre_dump - def email_verification_timeout_delta_to_seconds(self, out_data: dict, **kwargs): + def email_verification_timeout_delta_to_seconds(self, out_data: dict, **kwargs: Any): if out_data["payload"].get("state", {}).get("email", {}).get("sent_at"): sent_at = out_data["payload"]["state"]["email"]["sent_at"] verification_time_left = time_left(sent_at, current_app.conf.email_verification_timeout).total_seconds() @@ -98,7 +100,7 @@ def email_verification_timeout_delta_to_seconds(self, out_data: dict, **kwargs): return out_data @pre_dump - def bad_attempts_max(self, out_data: dict, **kwargs): + def bad_attempts_max(self, out_data: dict, **kwargs: Any): if out_data["payload"].get("state", {}).get("email"): out_data["payload"]["state"]["email"]["bad_attempts_max"] = ( current_app.conf.email_verification_max_bad_attempts diff --git a/src/eduid/webapp/signup/tests/test_app.py b/src/eduid/webapp/signup/tests/test_app.py index 7f2f672ca..7d9f9f44e 100644 --- a/src/eduid/webapp/signup/tests/test_app.py +++ b/src/eduid/webapp/signup/tests/test_app.py @@ -51,7 +51,7 @@ class SignupResult: class SignupTests(EduidAPITestCase[SignupApp], MockedScimAPIMixin): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs, copy_user_to_private=True) def load_app(self, config: Mapping[str, Any]) -> SignupApp: diff --git a/src/eduid/webapp/support/app.py b/src/eduid/webapp/support/app.py index be1afe9af..83549bc0c 100644 --- a/src/eduid/webapp/support/app.py +++ b/src/eduid/webapp/support/app.py @@ -14,7 +14,7 @@ class SupportApp(AuthnBaseApp): - def __init__(self, config: SupportConfig, **kwargs): + def __init__(self, config: SupportConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config @@ -46,7 +46,7 @@ def dateformat(value: datetime | None, format: str = "%Y-%m-%d"): return value.strftime(format) @app.template_filter("multisort") - def sort_multi(items: list, *operators, **kwargs): + def sort_multi(items: list, *operators: str, **kwargs: bool): # Don't try to sort on missing keys keys = list(operators) # operators is immutable for key in operators: diff --git a/src/eduid/webapp/svipe_id/app.py b/src/eduid/webapp/svipe_id/app.py index 63403013e..694231d9a 100644 --- a/src/eduid/webapp/svipe_id/app.py +++ b/src/eduid/webapp/svipe_id/app.py @@ -16,7 +16,7 @@ class SvipeIdApp(AuthnBaseApp): - def __init__(self, config: SvipeIdConfig, **kwargs): + def __init__(self, config: SvipeIdConfig, **kwargs: Any): super().__init__(config, **kwargs) self.conf = config diff --git a/src/eduid/webapp/svipe_id/tests/test_app.py b/src/eduid/webapp/svipe_id/tests/test_app.py index 4874a358c..55242dc9a 100644 --- a/src/eduid/webapp/svipe_id/tests/test_app.py +++ b/src/eduid/webapp/svipe_id/tests/test_app.py @@ -24,7 +24,7 @@ class SvipeIdTests(ProofingTests[SvipeIdApp]): """Base TestCase for those tests that need a full environment setup""" - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs, users=["hubba-bubba", "hubba-baar"]) self.unverified_test_user = self.app.central_userdb.get_user_by_eppn("hubba-baar") diff --git a/src/eduid/workers/am/tests/test_am.py b/src/eduid/workers/am/tests/test_am.py index 7a9071314..711547930 100644 --- a/src/eduid/workers/am/tests/test_am.py +++ b/src/eduid/workers/am/tests/test_am.py @@ -1,3 +1,5 @@ +from typing import Any + from bson import ObjectId import eduid.userdb @@ -82,7 +84,7 @@ class MessageTest(AMTestCase): transforms 'uid' to its urn:oid representation. """ - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, want_mongo_uri=True, **kwargs) self.private_db = AmTestUserDb(db_uri=self.tmp_db.uri, db_name="eduid_am_test") # register fake AMP plugin named 'test' diff --git a/src/eduid/workers/am/tests/test_tasks.py b/src/eduid/workers/am/tests/test_tasks.py index 33f339b4a..dd36229f8 100644 --- a/src/eduid/workers/am/tests/test_tasks.py +++ b/src/eduid/workers/am/tests/test_tasks.py @@ -1,3 +1,5 @@ +from typing import Any + from bson import ObjectId import eduid.userdb @@ -13,7 +15,7 @@ class TestTasks(AMTestCase): user: User - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): _users = UserFixtures() self.user = _users.mocked_user_standard super().setUp(want_mongo_uri=True, am_users=[self.user, _users.mocked_user_standard_2], **kwargs) diff --git a/src/eduid/workers/amapi/context_request.py b/src/eduid/workers/amapi/context_request.py index 2a054036a..33032bddd 100644 --- a/src/eduid/workers/amapi/context_request.py +++ b/src/eduid/workers/amapi/context_request.py @@ -2,6 +2,7 @@ from collections.abc import Callable from dataclasses import asdict, dataclass +from typing import Any from fastapi import Request, Response from fastapi.routing import APIRoute @@ -14,7 +15,7 @@ def to_dict(self): class ContextRequest(Request): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @property diff --git a/src/eduid/workers/amapi/testing.py b/src/eduid/workers/amapi/testing.py index a3904af25..17d151693 100644 --- a/src/eduid/workers/amapi/testing.py +++ b/src/eduid/workers/amapi/testing.py @@ -11,7 +11,7 @@ class TestAMBase(CommonTestCase): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(*args, **kwargs) self.path = pkg_resources.resource_filename(__name__, "tests/data") diff --git a/src/eduid/workers/amapi/tests/test_middleware.py b/src/eduid/workers/amapi/tests/test_middleware.py index 50cb93990..a33e0c738 100644 --- a/src/eduid/workers/amapi/tests/test_middleware.py +++ b/src/eduid/workers/amapi/tests/test_middleware.py @@ -1,12 +1,13 @@ import fnmatch import unittest +from typing import Any from eduid.workers.amapi.config import EndpointRestriction, SupportedMethod from eduid.workers.amapi.middleware import AuthenticationMiddleware class TestMiddleware(unittest.TestCase): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp() self.middleware = AuthenticationMiddleware diff --git a/src/eduid/workers/amapi/tests/test_user.py b/src/eduid/workers/amapi/tests/test_user.py index 6324b8fce..92308781c 100644 --- a/src/eduid/workers/amapi/tests/test_user.py +++ b/src/eduid/workers/amapi/tests/test_user.py @@ -5,9 +5,8 @@ from bson import ObjectId from fastapi import status -from httpx import Headers +from httpx import Headers, Response from jwcrypto import jwt -from requests import Response from eduid.common.clients.gnap_client.base import GNAPBearerTokenMixin from eduid.userdb.fixtures.users import UserFixtures @@ -17,7 +16,7 @@ class TestUsers(TestAMBase, GNAPBearerTokenMixin): - def setUp(self, *args, **kwargs): + def setUp(self, *args: Any, **kwargs: Any): super().setUp(am_users=[UserFixtures().new_user_example]) def _make_url(self, endpoint: str | None = None) -> str: diff --git a/src/eduid/workers/lookup_mobile/decorators.py b/src/eduid/workers/lookup_mobile/decorators.py index 533cefd96..d8a82db7f 100644 --- a/src/eduid/workers/lookup_mobile/decorators.py +++ b/src/eduid/workers/lookup_mobile/decorators.py @@ -29,7 +29,7 @@ def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: if not self.enabled: return f - def audit(*args, **kwargs): + def audit(*args: Any, **kwargs: Any): ret = f(*args, **kwargs) # XXX Ugly hack # The class that uses the decorator needs to have self.conf['MONGO_URI'] and self.transaction_audit set @@ -63,7 +63,7 @@ def enable(cls): def disable(cls): cls.enabled = False - def _filter(self, func: str, data: Any, *args, **kwargs): + def _filter(self, func: str, data: Any, *args: Any, **kwargs: Any): if data is False: return data if func == "find_mobiles_by_NIN": diff --git a/src/eduid/workers/msg/decorators.py b/src/eduid/workers/msg/decorators.py index 012d3e381..582e26ebf 100644 --- a/src/eduid/workers/msg/decorators.py +++ b/src/eduid/workers/msg/decorators.py @@ -20,7 +20,7 @@ def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: if not self.enabled: return f - def audit(*args, **kwargs): + def audit(*args: Any, **kwargs: Any): ret = f(*args, **kwargs) if not isclass(ret): # we can't save class objects in mongodb date = datetime.utcnow() @@ -52,7 +52,7 @@ def disable(cls): cls.enabled = False @staticmethod - def _filter(func: str, data: Any, *args, **kwargs): + def _filter(func: str, data: Any, *args: Any, **kwargs: Any): if data is False: return data if func == "_get_navet_data": From 4677f5b05e1e6f813a8e4256fd436dcb2f882a6b Mon Sep 17 00:00:00 2001 From: Lasse Yledahl Date: Mon, 30 Sep 2024 09:11:45 +0000 Subject: [PATCH 03/16] add return type None to functions not returning data --- .../common/clients/amapi_client/testing.py | 2 +- .../common/clients/gnap_client/testing.py | 2 +- .../common/clients/scim_client/testing.py | 2 +- .../common/config/tests/test_config_parser.py | 4 +- .../common/config/tests/test_yaml_parser.py | 14 ++-- src/eduid/common/misc/tests/test_timeutil.py | 2 +- src/eduid/common/rpc/tests/test_msg_relay.py | 14 ++-- src/eduid/common/stats/__init__.py | 6 +- src/eduid/graphdb/db.py | 2 +- src/eduid/graphdb/groupdb/db.py | 2 +- src/eduid/graphdb/testing.py | 4 +- src/eduid/graphdb/tests/test_db.py | 14 ++-- .../webapp/personal_data/tests/test_app.py | 80 +++++++++---------- 13 files changed, 76 insertions(+), 72 deletions(-) diff --git a/src/eduid/common/clients/amapi_client/testing.py b/src/eduid/common/clients/amapi_client/testing.py index d6b911dad..2ea93dd52 100644 --- a/src/eduid/common/clients/amapi_client/testing.py +++ b/src/eduid/common/clients/amapi_client/testing.py @@ -4,7 +4,7 @@ class MockedAMAPIMixin(MockedSyncAuthAPIMixin): - def start_mock_amapi(self, access_token_value: str | None = None): + def start_mock_amapi(self, access_token_value: str | None = None) -> None: self.start_mock_auth_api(access_token_value=access_token_value) self.mocked_users = respx.mock(base_url="http://localhost", assert_all_called=False) diff --git a/src/eduid/common/clients/gnap_client/testing.py b/src/eduid/common/clients/gnap_client/testing.py index 62cf913ce..ae5e7fc99 100644 --- a/src/eduid/common/clients/gnap_client/testing.py +++ b/src/eduid/common/clients/gnap_client/testing.py @@ -7,7 +7,7 @@ class MockedSyncAuthAPIMixin: - def start_mock_auth_api(self, access_token_value: str | None = None): + def start_mock_auth_api(self, access_token_value: str | None = None) -> None: if access_token_value is None: access_token_value = "mock_jwt" self.mocked_auth_api = respx.mock(base_url="http://localhost/auth", assert_all_called=False) diff --git a/src/eduid/common/clients/scim_client/testing.py b/src/eduid/common/clients/scim_client/testing.py index 9a0af7155..5c8ff6e93 100644 --- a/src/eduid/common/clients/scim_client/testing.py +++ b/src/eduid/common/clients/scim_client/testing.py @@ -79,7 +79,7 @@ class MockedScimAPIMixin(MockedSyncAuthAPIMixin): put_user_response = post_user_response - def start_mocked_scim_api(self): + def start_mocked_scim_api(self) -> None: self.start_mock_auth_api() self.mocked_scim_api = respx.mock(base_url="http://localhost/scim", assert_all_called=False) diff --git a/src/eduid/common/config/tests/test_config_parser.py b/src/eduid/common/config/tests/test_config_parser.py index ae52ecb98..e943dcf6a 100644 --- a/src/eduid/common/config/tests/test_config_parser.py +++ b/src/eduid/common/config/tests/test_config_parser.py @@ -8,10 +8,10 @@ class TestInitConfig(unittest.TestCase): - def tearDown(self): + def tearDown(self) -> None: os.environ.clear() - def test_YamlConfigParser(self): + def test_YamlConfigParser(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/test/ns/" os.environ["EDUID_CONFIG_YAML"] = "/config.yaml" parser = _choose_parser() diff --git a/src/eduid/common/config/tests/test_yaml_parser.py b/src/eduid/common/config/tests/test_yaml_parser.py index 7df9545bb..b6bd4297b 100644 --- a/src/eduid/common/config/tests/test_yaml_parser.py +++ b/src/eduid/common/config/tests/test_yaml_parser.py @@ -25,10 +25,10 @@ class TestInitConfig(unittest.TestCase): def setUp(self) -> None: self.data_dir = PurePath(__file__).with_name("data") - def tearDown(self): + def tearDown(self) -> None: os.environ.clear() - def test_YamlConfig(self): + def test_YamlConfig(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/app_one" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -49,7 +49,7 @@ def test_YamlConfig(self): assert config_two.number == 10 assert config_two.only_default == 19 - def test_YamlConfig_interpolation(self): + def test_YamlConfig_interpolation(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_interpolation" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -58,7 +58,7 @@ def test_YamlConfig_interpolation(self): assert config.number == 3 assert config.foo == "hi world" - def test_YamlConfig_missing_value(self): + def test_YamlConfig_missing_value(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_missing_value" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -75,7 +75,7 @@ def test_YamlConfig_missing_value(self): } ], f"Wrong error message: {exc_info.value.errors()}" - def test_YamlConfig_wrong_type(self): + def test_YamlConfig_wrong_type(self) -> None: os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_wrong_type" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml") @@ -92,7 +92,7 @@ def test_YamlConfig_wrong_type(self): } ], f"Wrong error message: {exc_info.value.errors()}" - def test_YamlConfig_unknown_data(self): + def test_YamlConfig_unknown_data(self) -> None: """Unknown data should not be rejected because it is an operational nightmare""" os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_unknown_data" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" @@ -102,7 +102,7 @@ def test_YamlConfig_unknown_data(self): assert config.number == 0xFF assert config.foo == "bar" - def test_YamlConfig_mixed_case_keys(self): + def test_YamlConfig_mixed_case_keys(self) -> None: """For legacy reasons, all keys should be lowercased""" os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_mixed_case_keys" os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common" diff --git a/src/eduid/common/misc/tests/test_timeutil.py b/src/eduid/common/misc/tests/test_timeutil.py index 017c4fa9c..d395394df 100644 --- a/src/eduid/common/misc/tests/test_timeutil.py +++ b/src/eduid/common/misc/tests/test_timeutil.py @@ -4,7 +4,7 @@ class TimeUtilTests(unittest.TestCase): - def test_utc_now(self): + def test_utc_now(self) -> None: t1 = utc_now() t2 = utc_now() assert t2 > t1 diff --git a/src/eduid/common/rpc/tests/test_msg_relay.py b/src/eduid/common/rpc/tests/test_msg_relay.py index 65868273a..70bbd9ac3 100644 --- a/src/eduid/common/rpc/tests/test_msg_relay.py +++ b/src/eduid/common/rpc/tests/test_msg_relay.py @@ -33,7 +33,7 @@ def _fix_relations_to(relative_nin: str, relations: Mapping[str, Any]) -> list[d return result @patch("eduid.workers.msg.tasks.get_all_navet_data.apply_async") - def test_get_all_navet_data(self, mock_get_all_navet_data: MagicMock): + def test_get_all_navet_data(self, mock_get_all_navet_data: MagicMock) -> None: mock_conf = {"get.return_value": self.message_sender.get_devel_all_navet_data()} ret = Mock(**mock_conf) mock_get_all_navet_data.return_value = ret @@ -41,7 +41,7 @@ def test_get_all_navet_data(self, mock_get_all_navet_data: MagicMock): assert res == NavetData(**self.message_sender.get_devel_all_navet_data()) @patch("eduid.workers.msg.tasks.get_all_navet_data.apply_async") - def test_get_all_navet_data_deceased(self, mock_get_all_navet_data: MagicMock): + def test_get_all_navet_data_deceased(self, mock_get_all_navet_data: MagicMock) -> None: mock_conf = {"get.return_value": self.message_sender.get_devel_all_navet_data(identity_number="189001019802")} ret = Mock(**mock_conf) mock_get_all_navet_data.return_value = ret @@ -50,7 +50,7 @@ def test_get_all_navet_data_deceased(self, mock_get_all_navet_data: MagicMock): assert res == NavetData(**self.message_sender.get_devel_all_navet_data(identity_number="189001019802")) @patch("eduid.workers.msg.tasks.get_all_navet_data.apply_async") - def test_get_all_navet_data_none_response(self, mock_get_all_navet_data: MagicMock): + def test_get_all_navet_data_none_response(self, mock_get_all_navet_data: MagicMock) -> None: mock_conf = {"get.return_value": None} ret = Mock(**mock_conf) mock_get_all_navet_data.return_value = ret @@ -58,7 +58,7 @@ def test_get_all_navet_data_none_response(self, mock_get_all_navet_data: MagicMo self.msg_relay.get_all_navet_data(nin="190102031234") @patch("eduid.workers.msg.tasks.get_postal_address.apply_async") - def test_get_postal_address(self, mock_get_postal_address: MagicMock): + def test_get_postal_address(self, mock_get_postal_address: MagicMock) -> None: mock_conf = {"get.return_value": self.message_sender.get_devel_postal_address()} ret = Mock(**mock_conf) mock_get_postal_address.return_value = ret @@ -66,7 +66,7 @@ def test_get_postal_address(self, mock_get_postal_address: MagicMock): assert res == FullPostalAddress(**self.message_sender.get_devel_postal_address()) @patch("eduid.workers.msg.tasks.get_postal_address.apply_async") - def test_get_postal_address_none_response(self, mock_get_postal_address: MagicMock): + def test_get_postal_address_none_response(self, mock_get_postal_address: MagicMock) -> None: mock_conf = {"get.return_value": None} ret = Mock(**mock_conf) mock_get_postal_address.return_value = ret @@ -74,7 +74,7 @@ def test_get_postal_address_none_response(self, mock_get_postal_address: MagicMo self.msg_relay.get_postal_address(nin="190102031234") @patch("eduid.workers.msg.tasks.get_relations_to.apply_async") - def test_get_relations_to(self, mock_get_relations: MagicMock): + def test_get_relations_to(self, mock_get_relations: MagicMock) -> None: relations_to = self._fix_relations_to( relative_nin="194004048989", relations=self.message_sender.get_devel_relations() ) @@ -86,7 +86,7 @@ def test_get_relations_to(self, mock_get_relations: MagicMock): assert res == [RelationType(item) for item in relations_to] @patch("eduid.workers.msg.tasks.get_relations_to.apply_async") - def test_get_relations_to_empty_response(self, mock_get_relations: MagicMock): + def test_get_relations_to_empty_response(self, mock_get_relations: MagicMock) -> None: mock_conf: dict[str, Any] = {"get.return_value": []} ret = Mock(**mock_conf) mock_get_relations.return_value = ret diff --git a/src/eduid/common/stats/__init__.py b/src/eduid/common/stats/__init__.py index 122501cd0..42b27017d 100644 --- a/src/eduid/common/stats/__init__.py +++ b/src/eduid/common/stats/__init__.py @@ -24,7 +24,7 @@ class AppStats(ABC): def count(self, name: str, value: int = 1) -> None: pass - def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False) -> None: pass @@ -46,7 +46,7 @@ def count(self, name: str, value: int = 1) -> None: name = f"{self.prefix!s}.{name!s}" self.logger.info(f"No-op stats count: {name!r} {value!r}") - def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False) -> None: if self.logger: if self.prefix: name = f"{self.prefix!s}.{name!s}" @@ -65,7 +65,7 @@ def count(self, name: str, value: int = 1) -> None: # for .count self.client.incr(f"{name}.count", count=value) - def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False): + def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False) -> None: self.client.gauge(f"{name}.gauge", value=value, rate=rate, delta=delta) diff --git a/src/eduid/graphdb/db.py b/src/eduid/graphdb/db.py index bb5fbf868..35c4d5295 100644 --- a/src/eduid/graphdb/db.py +++ b/src/eduid/graphdb/db.py @@ -69,7 +69,7 @@ def sanitized_uri(self) -> str: def driver(self) -> Driver: return self._driver - def close(self): + def close(self) -> None: self.driver.close() diff --git a/src/eduid/graphdb/groupdb/db.py b/src/eduid/graphdb/groupdb/db.py index 8dd4b5ce7..cbcb95000 100644 --- a/src/eduid/graphdb/groupdb/db.py +++ b/src/eduid/graphdb/groupdb/db.py @@ -37,7 +37,7 @@ def __init__(self, db_uri: str, scope: str, config: dict[str, Any] | None = None super().__init__(db_uri=db_uri, config=config) self._scope = scope - def db_setup(self): + def db_setup(self) -> None: with self.db.driver.session(default_access_mode=WRITE_ACCESS) as session: # new index creation syntax in neo4j >=5.0 statements = [ diff --git a/src/eduid/graphdb/testing.py b/src/eduid/graphdb/testing.py index 06586f641..4576b1b5c 100644 --- a/src/eduid/graphdb/testing.py +++ b/src/eduid/graphdb/testing.py @@ -104,7 +104,7 @@ def https_port(self): def bolt_port(self): return self._bolt_port - def purge_db(self): + def purge_db(self) -> None: q = """ MATCH (n) DETACH DELETE n @@ -135,5 +135,5 @@ def setUpClass(cls) -> None: cls.neo4j_instance = Neo4jTemporaryInstance.get_instance(max_retry_seconds=60) cls.neo4jdb = cls.neo4j_instance.conn - def tearDown(self): + def tearDown(self) -> None: self.neo4j_instance.purge_db() diff --git a/src/eduid/graphdb/tests/test_db.py b/src/eduid/graphdb/tests/test_db.py index b5f4a44a2..eae9cc577 100644 --- a/src/eduid/graphdb/tests/test_db.py +++ b/src/eduid/graphdb/tests/test_db.py @@ -9,12 +9,14 @@ class TestNeo4jDB(Neo4jTestCase): - def test_create_db(self): + def test_create_db(self) -> None: with self.neo4jdb.driver.session() as session: session.run("CREATE (n:Test $props)", props={"name": "test node", "testing": True}) with self.neo4jdb.driver.session() as session: result = session.run("MATCH (n {name: $name})RETURN n.testing", name="test node") - self.assertTrue(result.single().value()) + single = result.single() + assert single + self.assertTrue(single.value()) class TestBaseGraphDB(Neo4jTestCase): @@ -22,12 +24,12 @@ class TestDB(BaseGraphDB): def __init__(self, db_uri: str, config: dict[str, Any] | None = None): super().__init__(db_uri, config=config) - def db_setup(self): + def db_setup(self) -> None: with self._db.driver.session() as session: session.run("CREATE CONSTRAINT ON (n:Test) ASSERT n.name IS UNIQUE") session.run("CREATE INDEX FOR (n:Test) ON (n.testing)") - def test_base_db(self): + def test_base_db(self) -> None: db_uri = self.neo4jdb.db_uri config = {"encrypted": False, "auth": basic_auth("neo4j", "testingtesting")} @@ -36,4 +38,6 @@ def test_base_db(self): session.run("CREATE (n:Test $props)", props={"name": "test node", "testing": True}) with test_db._db.driver.session() as session: result = session.run("MATCH (n {name: $name})RETURN n.testing", name="test node") - self.assertTrue(result.single().value()) + single = result.single() + assert single + self.assertTrue(single.value()) diff --git a/src/eduid/webapp/personal_data/tests/test_app.py b/src/eduid/webapp/personal_data/tests/test_app.py index 633e86c3d..e1237f4f4 100644 --- a/src/eduid/webapp/personal_data/tests/test_app.py +++ b/src/eduid/webapp/personal_data/tests/test_app.py @@ -16,7 +16,7 @@ class PersonalDataTests(EduidAPITestCase[PersonalDataApp]): - def setUp(self, *args: Any, **kwargs: Any): + def setUp(self, *args: Any, **kwargs: Any) -> None: super().setUp(*args, copy_user_to_private=True, **kwargs) def load_app(self, config: Mapping[str, Any]) -> PersonalDataApp: @@ -193,7 +193,7 @@ def _get_user_identities(self, eppn: str | None = None): # actual test methods - def test_get_user(self): + def test_get_user(self) -> None: user_data = self._get_user() self.assertEqual(user_data["type"], "GET_PERSONAL_DATA_USER_SUCCESS") self.assertEqual(user_data["payload"]["given_name"], "John") @@ -203,11 +203,11 @@ def test_get_user(self): self.assertIsNotNone(self.test_user.to_dict().get("passwords")) self.assertIsNone(user_data["payload"].get("passwords")) - def test_get_unknown_user(self): + def test_get_unknown_user(self) -> None: with self.assertRaises(ApiException): self._get_user(eppn="fooo-fooo") - def test_get_user_all_data(self): + def test_get_user_all_data(self) -> None: response = self._get_user_all_data(eppn="hubba-bubba") expected_payload = { "emails": [ @@ -244,11 +244,11 @@ def test_get_user_all_data(self): assert self.test_user.to_dict().get("passwords") is not None assert user_data["payload"].get("passwords") is None - def test_get_unknown_user_all_data(self): + def test_get_unknown_user_all_data(self) -> None: with self.assertRaises(ApiException): self._get_user_all_data(eppn="fooo-fooo") - def test_post_user(self): + def test_post_user(self) -> None: response = self._post_user(verified_user=False) expected_payload = { "surname": "Johnson", @@ -257,7 +257,7 @@ def test_post_user(self): } self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_SUCCESS", payload=expected_payload) - def test_post_user_name(self): + def test_post_user_name(self) -> None: response = self._post_user_name(verified_user=False) expected_payload = { "surname": "Johnson", @@ -265,7 +265,7 @@ def test_post_user_name(self): } self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_NAME_SUCCESS", payload=expected_payload) - def test_post_user_language(self): + def test_post_user_language(self) -> None: response = self._post_user_language(verified_user=False) expected_payload = { "language": "en", @@ -274,7 +274,7 @@ def test_post_user_language(self): response, type_="POST_PERSONAL_DATA_USER_LANGUAGE_SUCCESS", payload=expected_payload ) - def test_set_chosen_given_name_and_language_verified_user(self): + def test_set_chosen_given_name_and_language_verified_user(self) -> None: expected_payload = { "surname": "Smith", "given_name": "John", @@ -283,7 +283,7 @@ def test_set_chosen_given_name_and_language_verified_user(self): response = self._post_user(mod_data=expected_payload) self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_SUCCESS", payload=expected_payload) - def test_post_user_name_set_chosen_given_name_verified_user(self): + def test_post_user_name_set_chosen_given_name_verified_user(self) -> None: expected_payload = { "surname": "Smith", "given_name": "John", @@ -291,7 +291,7 @@ def test_post_user_name_set_chosen_given_name_verified_user(self): response = self._post_user_name(mod_data=expected_payload) self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_NAME_SUCCESS", payload=expected_payload) - def test_post_user_language_set_language_verified_user(self): + def test_post_user_language_set_language_verified_user(self) -> None: expected_payload = { "language": "sv", } @@ -300,7 +300,7 @@ def test_post_user_language_set_language_verified_user(self): response, type_="POST_PERSONAL_DATA_USER_LANGUAGE_SUCCESS", payload=expected_payload ) - def test_set_given_name_and_surname_verified_user(self): + def test_set_given_name_and_surname_verified_user(self) -> None: mod_data = { "surname": "Johnson", "given_name": "Peter", @@ -314,7 +314,7 @@ def test_set_given_name_and_surname_verified_user(self): response = self._post_user(mod_data=mod_data) self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_SUCCESS", payload=expected_payload) - def test_post_user_name_set_given_name_and_surname_verified_user(self): + def test_post_user_name_set_given_name_and_surname_verified_user(self) -> None: mod_data = { "surname": "Johnson", "given_name": "Peter", @@ -326,57 +326,57 @@ def test_post_user_name_set_given_name_and_surname_verified_user(self): response = self._post_user_name(mod_data=mod_data) self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_NAME_SUCCESS", payload=expected_payload) - def test_post_user_bad_csrf(self): + def test_post_user_bad_csrf(self) -> None: response = self._post_user(mod_data={"csrf_token": "wrong-token"}) expected_payload = {"error": {"csrf_token": ["CSRF failed to validate"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_FAIL", payload=expected_payload) - def test_post_user__name_bad_csrf(self): + def test_post_user__name_bad_csrf(self) -> None: response = self._post_user_name(mod_data={"csrf_token": "wrong-token"}) expected_payload = {"error": {"csrf_token": ["CSRF failed to validate"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_NAME_FAIL", payload=expected_payload) - def test_post_user_no_given_name(self): + def test_post_user_no_given_name(self) -> None: response = self._post_user(mod_data={"given_name": ""}) expected_payload = {"error": {"given_name": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_FAIL", payload=expected_payload) - def test_post_user_name_no_given_name(self): + def test_post_user_name_no_given_name(self) -> None: response = self._post_user_name(mod_data={"given_name": ""}) expected_payload = {"error": {"given_name": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_NAME_FAIL", payload=expected_payload) - def test_post_user_blank_given_name(self): + def test_post_user_blank_given_name(self) -> None: response = self._post_user(mod_data={"given_name": " "}) expected_payload = {"error": {"given_name": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_FAIL", payload=expected_payload) - def test_post_user_name_blank_given_name(self): + def test_post_user_name_blank_given_name(self) -> None: response = self._post_user_name(mod_data={"given_name": " "}) expected_payload = {"error": {"given_name": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_NAME_FAIL", payload=expected_payload) - def test_post_user_no_surname(self): + def test_post_user_no_surname(self) -> None: response = self._post_user(mod_data={"surname": ""}) expected_payload = {"error": {"surname": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_FAIL", payload=expected_payload) - def test_post_user_name_no_surname(self): + def test_post_user_name_no_surname(self) -> None: response = self._post_user_name(mod_data={"surname": ""}) expected_payload = {"error": {"surname": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_NAME_FAIL", payload=expected_payload) - def test_post_user_blank_surname(self): + def test_post_user_blank_surname(self) -> None: response = self._post_user(mod_data={"surname": " "}) expected_payload = {"error": {"surname": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_FAIL", payload=expected_payload) - def test_post_user_name_blank_surname(self): + def test_post_user_name_blank_surname(self) -> None: response = self._post_user_name(mod_data={"surname": " "}) expected_payload = {"error": {"surname": ["pdata.field_required"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_NAME_FAIL", payload=expected_payload) - def test_post_user_with_chosen_given_name(self): + def test_post_user_with_chosen_given_name(self) -> None: response = self._post_user(mod_data={"chosen_given_name": "Peter"}, verified_user=False) expected_payload = { "surname": "Johnson", @@ -386,7 +386,7 @@ def test_post_user_with_chosen_given_name(self): } self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_SUCCESS", payload=expected_payload) - def test_post_user_name_with_chosen_given_name(self): + def test_post_user_name_with_chosen_given_name(self) -> None: response = self._post_user_name(mod_data={"chosen_given_name": "Peter"}, verified_user=False) expected_payload = { "surname": "Johnson", @@ -395,19 +395,19 @@ def test_post_user_name_with_chosen_given_name(self): } self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_NAME_SUCCESS", payload=expected_payload) - def test_post_user_with_bad_chosen_given_name(self): + def test_post_user_with_bad_chosen_given_name(self) -> None: response = self._post_user(mod_data={"chosen_given_name": "Michael"}, verified_user=False) self._check_error_response( response, type_="POST_PERSONAL_DATA_USER_FAIL", msg=PDataMsg.chosen_given_name_invalid ) - def test_post_user_name_with_bad_chosen_given_name(self): + def test_post_user_name_with_bad_chosen_given_name(self) -> None: response = self._post_user_name(mod_data={"chosen_given_name": "Michael"}, verified_user=False) self._check_error_response( response, type_="POST_PERSONAL_DATA_USER_NAME_FAIL", msg=PDataMsg.chosen_given_name_invalid ) - def test_post_user_to_unset_chosen_given_name(self): + def test_post_user_to_unset_chosen_given_name(self) -> None: # set test user chosen given name self.test_user.chosen_given_name = "Peter" self.app.central_userdb.save(self.test_user) @@ -422,7 +422,7 @@ def test_post_user_to_unset_chosen_given_name(self): } self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_SUCCESS", payload=expected_payload) - def test_post_user_name_to_unset_chosen_given_name(self): + def test_post_user_name_to_unset_chosen_given_name(self) -> None: # set test user chosen given name self.test_user.chosen_given_name = "Peter" self.app.central_userdb.save(self.test_user) @@ -436,34 +436,34 @@ def test_post_user_name_to_unset_chosen_given_name(self): } self._check_success_response(response, type_="POST_PERSONAL_DATA_USER_NAME_SUCCESS", payload=expected_payload) - def test_post_user_no_language(self): + def test_post_user_no_language(self) -> None: response = self._post_user(mod_data={"language": ""}) expected_payload = {"error": {"language": ["Language '' is not available"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_FAIL", payload=expected_payload) - def test_post_user_language_no_language(self): + def test_post_user_language_no_language(self) -> None: response = self._post_user_language(mod_data={"language": ""}) expected_payload = {"error": {"language": ["Language '' is not available"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_LANGUAGE_FAIL", payload=expected_payload) - def test_post_user_unknown_language(self): + def test_post_user_unknown_language(self) -> None: response = self._post_user(mod_data={"language": "es"}) expected_payload = {"error": {"language": ["Language 'es' is not available"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_FAIL", payload=expected_payload) - def test_post_user_language_unknown_language(self): + def test_post_user_language_unknown_language(self) -> None: response = self._post_user_language(mod_data={"language": "es"}) expected_payload = {"error": {"language": ["Language 'es' is not available"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_USER_LANGUAGE_FAIL", payload=expected_payload) - def test_get_preferences(self): + def test_get_preferences(self) -> None: response = self._get_preferences() expected_payload = {"always_use_security_key": True} self._check_success_response( response=response, type_="GET_PERSONAL_DATA_PREFERENCES_SUCCESS", payload=expected_payload ) - def test_update_preferences(self): + def test_update_preferences(self) -> None: self.set_authn_action( eppn=self.test_user_eppn, frontend_action=FrontendAction.CHANGE_SECURITY_PREFERENCES_AUTHN, @@ -482,7 +482,7 @@ def test_update_preferences(self): user = self.app.central_userdb.get_user_by_eppn(eppn=self.test_user.eppn) assert user.preferences.always_use_security_key is False - def test_update_preferences_no_auth(self): + def test_update_preferences_no_auth(self) -> None: user = self.app.central_userdb.get_user_by_eppn(eppn=self.test_user.eppn) assert user.preferences.always_use_security_key is True @@ -498,22 +498,22 @@ def test_update_preferences_no_auth(self): user = self.app.central_userdb.get_user_by_eppn(eppn=self.test_user.eppn) assert user.preferences.always_use_security_key is True - def test_post_preferences_bad_csrf(self): + def test_post_preferences_bad_csrf(self) -> None: response = self._post_preferences(mod_data={"csrf_token": "wrong-token", "always_use_security_key": True}) expected_payload = {"error": {"csrf_token": ["CSRF failed to validate"]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_PREFERENCES_FAIL", payload=expected_payload) - def test_post_preferences_no_always_use_security_key(self): + def test_post_preferences_no_always_use_security_key(self) -> None: response = self._post_preferences(mod_data={}) expected_payload = {"error": {"always_use_security_key": ["Missing data for required field."]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_PREFERENCES_FAIL", payload=expected_payload) - def test_post_preferences_wrong_always_use_security_key(self): + def test_post_preferences_wrong_always_use_security_key(self) -> None: response = self._post_preferences(mod_data={"always_use_security_key": "tomato"}) expected_payload = {"error": {"always_use_security_key": ["Not a valid boolean."]}} self._check_error_response(response, type_="POST_PERSONAL_DATA_PREFERENCES_FAIL", payload=expected_payload) - def test_get_user_identities(self): + def test_get_user_identities(self) -> None: response = self._get_user_identities() expected_payload = { "identities": { From 34392d288574419f71c1f1afa5e67c1c3b125f6e Mon Sep 17 00:00:00 2001 From: Lasse Yledahl Date: Wed, 2 Oct 2024 06:36:44 +0000 Subject: [PATCH 04/16] add return types to all public functions without type --- ruff.toml | 2 +- src/eduid/common/config/parsers/decorators.py | 4 +- src/eduid/common/decorators.py | 2 +- src/eduid/common/fastapi/context_request.py | 4 +- src/eduid/common/models/jose_models.py | 3 +- src/eduid/common/models/scim_base.py | 8 +- src/eduid/common/rpc/lookup_mobile_relay.py | 4 +- src/eduid/graphdb/db.py | 2 +- src/eduid/graphdb/groupdb/db.py | 8 +- src/eduid/graphdb/testing.py | 8 +- src/eduid/graphdb/tests/test_group.py | 45 +- src/eduid/graphdb/tests/test_groupdb.py | 79 +- src/eduid/maccapi/helpers.py | 10 +- src/eduid/maccapi/middleware.py | 1 + src/eduid/maccapi/routers/status.py | 2 +- src/eduid/maccapi/routers/users.py | 13 +- src/eduid/maccapi/tests/test_maccapi.py | 12 +- src/eduid/queue/db/worker.py | 2 +- src/eduid/queue/testing.py | 4 +- src/eduid/queue/tests/test_client.py | 25 +- src/eduid/queue/tests/test_mail_worker.py | 14 +- src/eduid/queue/tests/test_worker.py | 4 +- src/eduid/queue/workers/base.py | 6 +- src/eduid/queue/workers/mail.py | 4 +- src/eduid/queue/workers/scim_event.py | 2 +- src/eduid/queue/workers/sink.py | 2 +- src/eduid/satosa/scimapi/serve_static.py | 2 +- src/eduid/satosa/scimapi/stepup.py | 4 +- src/eduid/scimapi/middleware.py | 4 +- src/eduid/scimapi/routers/events.py | 8 +- src/eduid/scimapi/routers/groups.py | 16 +- src/eduid/scimapi/routers/invites.py | 12 +- src/eduid/scimapi/routers/users.py | 12 +- src/eduid/scimapi/routers/utils/events.py | 7 +- src/eduid/scimapi/routers/utils/groups.py | 11 + src/eduid/scimapi/routers/utils/invites.py | 17 +- src/eduid/scimapi/routers/utils/status.py | 4 +- src/eduid/scimapi/routers/utils/users.py | 17 +- src/eduid/scimapi/testing.py | 2 +- src/eduid/scimapi/tests/test_authn.py | 37 +- src/eduid/scimapi/tests/test_context.py | 4 +- src/eduid/scimapi/tests/test_groupdb.py | 18 +- src/eduid/scimapi/tests/test_login.py | 12 +- src/eduid/scimapi/tests/test_notifications.py | 6 +- src/eduid/scimapi/tests/test_profile.py | 2 +- src/eduid/scimapi/tests/test_scimbase.py | 2 +- src/eduid/scimapi/tests/test_scimevent.py | 10 +- src/eduid/scimapi/tests/test_scimgroup.py | 95 ++- src/eduid/scimapi/tests/test_sciminvite.py | 38 +- src/eduid/scimapi/tests/test_scimuser.py | 127 +++- src/eduid/scimapi/tests/test_search_filter.py | 10 +- src/eduid/scimapi/utils.py | 4 +- src/eduid/userdb/db/async_db.py | 4 +- src/eduid/userdb/db/sync_db.py | 4 +- src/eduid/userdb/element.py | 2 +- src/eduid/userdb/event.py | 7 +- src/eduid/userdb/locked_identity.py | 4 +- src/eduid/userdb/mail.py | 2 +- src/eduid/userdb/meta.py | 2 +- src/eduid/userdb/reset_password/state.py | 2 +- src/eduid/userdb/scimapi/invitedb.py | 2 +- src/eduid/userdb/scimapi/userdb.py | 6 +- src/eduid/userdb/testing/__init__.py | 12 +- src/eduid/userdb/testing/temp_instance.py | 2 +- src/eduid/userdb/tests/test_app_user.py | 16 +- src/eduid/userdb/tests/test_async_db.py | 24 +- src/eduid/userdb/tests/test_credentials.py | 30 +- src/eduid/userdb/tests/test_db.py | 34 +- src/eduid/userdb/tests/test_element.py | 24 +- src/eduid/userdb/tests/test_event.py | 38 +- src/eduid/userdb/tests/test_exceptions.py | 2 +- .../userdb/tests/test_group_management.py | 8 +- src/eduid/userdb/tests/test_identities.py | 66 +- src/eduid/userdb/tests/test_idp_user.py | 11 +- src/eduid/userdb/tests/test_ladok.py | 2 +- src/eduid/userdb/tests/test_logs.py | 38 +- src/eduid/userdb/tests/test_mail.py | 112 +-- src/eduid/userdb/tests/test_nin.py | 85 ++- src/eduid/userdb/tests/test_orcid.py | 15 +- src/eduid/userdb/tests/test_password.py | 16 +- src/eduid/userdb/tests/test_phone.py | 127 ++-- src/eduid/userdb/tests/test_profile.py | 8 +- src/eduid/userdb/tests/test_proofing.py | 10 +- src/eduid/userdb/tests/test_resetpw.py | 34 +- src/eduid/userdb/tests/test_signup_invite.py | 2 +- src/eduid/userdb/tests/test_signup_user.py | 12 +- src/eduid/userdb/tests/test_support_models.py | 14 +- src/eduid/userdb/tests/test_tou.py | 39 +- src/eduid/userdb/tests/test_u2f.py | 26 +- src/eduid/userdb/tests/test_user.py | 697 +++++++++--------- src/eduid/userdb/tests/test_userdb.py | 96 ++- src/eduid/userdb/tests/test_webauthn.py | 22 +- src/eduid/userdb/userdb.py | 2 +- src/eduid/userdb/util.py | 6 +- src/eduid/vccs/client/tests/test_client.py | 50 +- src/eduid/vccs/server/endpoints/misc.py | 2 +- src/eduid/vccs/server/hasher.py | 32 +- src/eduid/vccs/server/log.py | 5 +- src/eduid/vccs/server/password.py | 2 +- src/eduid/vccs/server/run.py | 4 +- src/eduid/vccs/server/tests/test_db.py | 6 +- src/eduid/webapp/authn/tests/test_authn.py | 54 +- src/eduid/webapp/bankid/tests/test_app.py | 46 +- src/eduid/webapp/common/api/app.py | 3 +- src/eduid/webapp/common/api/debug.py | 6 +- src/eduid/webapp/common/api/exceptions.py | 4 +- src/eduid/webapp/common/api/helpers.py | 2 +- src/eduid/webapp/common/api/messages.py | 2 +- src/eduid/webapp/common/api/request.py | 24 +- src/eduid/webapp/common/api/schemas/csrf.py | 6 +- .../webapp/common/api/schemas/password.py | 2 +- .../webapp/common/api/schemas/validators.py | 4 +- src/eduid/webapp/common/api/testing.py | 6 +- .../webapp/common/api/tests/test_backdoor.py | 15 +- .../common/api/tests/test_decorators.py | 6 +- .../webapp/common/api/tests/test_inputs.py | 58 +- .../webapp/common/api/tests/test_logging.py | 8 +- .../webapp/common/api/tests/test_messages.py | 58 +- .../common/api/tests/test_nin_helpers.py | 42 +- .../common/api/tests/test_validation.py | 12 +- src/eduid/webapp/common/api/validation.py | 2 +- src/eduid/webapp/common/authn/acs_registry.py | 2 +- .../webapp/common/authn/tests/test_cache.py | 28 +- .../common/authn/tests/test_fido_tokens.py | 18 +- .../common/authn/tests/test_middleware.py | 12 +- .../webapp/common/authn/tests/test_vccs.py | 26 +- .../webapp/common/proofing/saml_helpers.py | 3 +- .../webapp/common/session/eduid_session.py | 10 +- src/eduid/webapp/common/session/namespaces.py | 2 +- .../webapp/common/session/redis_session.py | 4 +- src/eduid/webapp/common/session/testing.py | 2 +- .../session/tests/test_eduid_session.py | 31 +- .../common/session/tests/test_namespaces.py | 12 +- .../session/tests/test_redis_session.py | 15 +- .../session/tests/test_session_cookie.py | 2 +- src/eduid/webapp/eidas/tests/test_app.py | 78 +- src/eduid/webapp/email/tests/test_app.py | 64 +- src/eduid/webapp/email/tests/test_msgs.py | 2 +- src/eduid/webapp/email/validators.py | 4 +- src/eduid/webapp/email/verifications.py | 5 +- src/eduid/webapp/email/views.py | 2 +- src/eduid/webapp/freja_eid/tests/test_app.py | 35 +- src/eduid/webapp/group_management/helpers.py | 4 +- .../webapp/group_management/tests/test_app.py | 114 +-- src/eduid/webapp/idp/decorators.py | 4 +- src/eduid/webapp/idp/helpers.py | 4 +- src/eduid/webapp/idp/idp_saml.py | 2 +- src/eduid/webapp/idp/known_device.py | 2 +- src/eduid/webapp/idp/other_device/db.py | 2 +- src/eduid/webapp/idp/tests/test_SSO.py | 70 +- src/eduid/webapp/idp/tests/test_SSOSession.py | 12 +- src/eduid/webapp/idp/tests/test_api.py | 14 +- src/eduid/webapp/idp/tests/test_idPUserDb.py | 26 +- .../webapp/idp/tests/test_known_device.py | 20 +- src/eduid/webapp/idp/tests/test_login.py | 2 +- src/eduid/webapp/idp/tests/test_logout.py | 19 +- src/eduid/webapp/jsconfig/tests/test_app.py | 14 +- src/eduid/webapp/ladok/tests/test_app.py | 22 +- src/eduid/webapp/letter_proofing/ekopost.py | 3 +- src/eduid/webapp/letter_proofing/helpers.py | 4 +- src/eduid/webapp/letter_proofing/pdf.py | 4 +- .../webapp/letter_proofing/tests/test_app.py | 75 +- .../webapp/letter_proofing/tests/test_msgs.py | 2 +- .../webapp/letter_proofing/tests/test_pdf.py | 12 +- .../lookup_mobile_proofing/tests/test_app.py | 24 +- .../tests/test_helpers.py | 2 +- .../lookup_mobile_proofing/tests/test_msgs.py | 2 +- .../webapp/oidc_proofing/tests/test_app.py | 49 +- .../webapp/oidc_proofing/tests/test_msgs.py | 2 +- src/eduid/webapp/oidc_proofing/views.py | 4 +- src/eduid/webapp/orcid/tests/test_app.py | 14 +- src/eduid/webapp/orcid/tests/test_msgs.py | 2 +- src/eduid/webapp/personal_data/validators.py | 4 +- src/eduid/webapp/phone/schemas.py | 2 +- src/eduid/webapp/phone/tests/test_app.py | 68 +- src/eduid/webapp/phone/tests/test_msgs.py | 2 +- src/eduid/webapp/phone/validators.py | 10 +- src/eduid/webapp/reset_password/helpers.py | 4 +- .../webapp/reset_password/tests/test_app.py | 155 ++-- .../reset_password/views/reset_password.py | 2 +- src/eduid/webapp/security/helpers.py | 2 +- src/eduid/webapp/security/schemas.py | 2 +- src/eduid/webapp/security/tests/test_app.py | 66 +- .../security/tests/test_change_password.py | 64 +- .../webapp/security/tests/test_webauthn.py | 59 +- src/eduid/webapp/security/views/security.py | 5 +- src/eduid/webapp/signup/helpers.py | 2 +- src/eduid/webapp/signup/schemas.py | 10 +- src/eduid/webapp/signup/tests/test_app.py | 169 +++-- src/eduid/webapp/signup/views.py | 16 +- src/eduid/webapp/support/tests/test_app.py | 13 +- src/eduid/webapp/support/views.py | 2 +- src/eduid/webapp/svipe_id/tests/test_app.py | 39 +- src/eduid/workers/am/ams/__init__.py | 4 +- src/eduid/workers/am/fetcher_registry.py | 2 +- src/eduid/workers/am/tasks.py | 4 +- src/eduid/workers/am/testing.py | 12 +- src/eduid/workers/am/tests/test_am.py | 88 ++- src/eduid/workers/am/tests/test_index.py | 8 +- .../am/tests/test_proofing_fetchers.py | 137 +++- src/eduid/workers/am/tests/test_signup.py | 19 +- src/eduid/workers/am/tests/test_tasks.py | 66 +- src/eduid/workers/amapi/context_request.py | 4 +- src/eduid/workers/amapi/middleware.py | 2 +- src/eduid/workers/amapi/routers/users.py | 12 +- .../workers/amapi/routers/utils/status.py | 2 +- src/eduid/workers/amapi/testing.py | 2 +- .../workers/amapi/tests/test_middleware.py | 8 +- src/eduid/workers/amapi/tests/test_status.py | 4 +- src/eduid/workers/amapi/tests/test_user.py | 34 +- src/eduid/workers/job_runner/app.py | 4 +- src/eduid/workers/job_runner/jobs/skv.py | 6 +- src/eduid/workers/job_runner/scheduler.py | 2 +- src/eduid/workers/job_runner/status.py | 7 +- .../job_runner/tests/test_user_cleaner.py | 10 +- src/eduid/workers/lookup_mobile/__init__.py | 2 +- .../development/development_search_result.py | 2 +- .../development/nin_mobile_db.py | 4 +- src/eduid/workers/lookup_mobile/tasks.py | 2 +- .../lookup_mobile/test/test_decorators.py | 8 +- .../workers/lookup_mobile/test/test_lookup.py | 4 +- .../workers/lookup_mobile/test/test_tasks.py | 4 +- src/eduid/workers/msg/cache.py | 8 +- src/eduid/workers/msg/tasks.py | 2 +- .../workers/msg/tests/test_decorators.py | 14 +- src/eduid/workers/msg/tests/test_mongo.py | 8 +- .../workers/msg/tests/test_postaladdress.py | 18 +- src/eduid/workers/msg/tests/test_tasks.py | 12 +- src/eduid/workers/msg/tests/test_utils.py | 6 +- 229 files changed, 2844 insertions(+), 2147 deletions(-) diff --git a/ruff.toml b/ruff.toml index 3f1bfa947..ccff4b6c1 100644 --- a/ruff.toml +++ b/ruff.toml @@ -3,6 +3,6 @@ line-length = 120 target-version = "py310" [lint] -select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN0", ] +select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN0", "ANN201"] ignore = ["E501"] diff --git a/src/eduid/common/config/parsers/decorators.py b/src/eduid/common/config/parsers/decorators.py index d948e1d1c..137e9d618 100644 --- a/src/eduid/common/config/parsers/decorators.py +++ b/src/eduid/common/config/parsers/decorators.py @@ -11,7 +11,7 @@ from eduid.common.config.parsers.exceptions import SecretKeyException -def decrypt(f: Callable): +def decrypt(f: Callable) -> Callable: @wraps(f) def decrypt_decorator(*args: Any, **kwargs: Any): config_dict = f(*args, **kwargs) @@ -83,7 +83,7 @@ def decrypt_config(config_dict: Mapping[str, Any]) -> Mapping[str, Any]: return new_config_dict -def interpolate(f: Callable): +def interpolate(f: Callable) -> Callable: @wraps(f) def interpolation_decorator(*args: Any, **kwargs: Any): config_dict = f(*args, **kwargs) diff --git a/src/eduid/common/decorators.py b/src/eduid/common/decorators.py index 9f49bcd60..8fa4311ec 100644 --- a/src/eduid/common/decorators.py +++ b/src/eduid/common/decorators.py @@ -6,7 +6,7 @@ # https://stackoverflow.com/questions/2536307/how-do-i-deprecate-python-functions/40301488#40301488 -def deprecated(reason: str | type | Callable): +def deprecated(reason: str | type | Callable) -> Callable: """ This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted diff --git a/src/eduid/common/fastapi/context_request.py b/src/eduid/common/fastapi/context_request.py index 58beee775..622b0ac08 100644 --- a/src/eduid/common/fastapi/context_request.py +++ b/src/eduid/common/fastapi/context_request.py @@ -8,7 +8,7 @@ @dataclass class Context: - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -18,7 +18,7 @@ def __init__(self, context_class: type[Context], *args: Any, **kwargs: Any): self.contextClass = context_class @property - def context(self): + def context(self) -> Context: try: return self.state.context except AttributeError: diff --git a/src/eduid/common/models/jose_models.py b/src/eduid/common/models/jose_models.py index 7d2ecd67a..9e323e161 100644 --- a/src/eduid/common/models/jose_models.py +++ b/src/eduid/common/models/jose_models.py @@ -1,5 +1,6 @@ import datetime from enum import Enum +from typing import Any from pydantic import AnyUrl, BaseModel, Field @@ -115,7 +116,7 @@ class RegisteredClaims(BaseModel): iat: datetime.datetime = Field(default_factory=utc_now) # Issued At jti: str | None = None # JWT ID - def to_rfc7519(self): + def to_rfc7519(self) -> dict[str, Any]: d = self.dict(exclude_none=True) if self.exp: d["exp"] = int((self.iat + self.exp).timestamp()) diff --git a/src/eduid/common/models/scim_base.py b/src/eduid/common/models/scim_base.py index 4ec826b6f..5ebb757d0 100644 --- a/src/eduid/common/models/scim_base.py +++ b/src/eduid/common/models/scim_base.py @@ -104,12 +104,12 @@ class SubResource(EduidBaseModel): display: str @property - def is_user(self): - return self.ref and "/Users/" in self.ref + def is_user(self) -> bool: + return self.ref is not None and "/Users/" in self.ref @property - def is_group(self): - return self.ref and "/Groups/" in self.ref + def is_group(self) -> bool: + return self.ref is not None and "/Groups/" in self.ref @classmethod def from_mapping(cls, data: Any): diff --git a/src/eduid/common/rpc/lookup_mobile_relay.py b/src/eduid/common/rpc/lookup_mobile_relay.py index 302162d36..7633aa863 100644 --- a/src/eduid/common/rpc/lookup_mobile_relay.py +++ b/src/eduid/common/rpc/lookup_mobile_relay.py @@ -1,3 +1,5 @@ +from typing import Any + import eduid.workers.lookup_mobile __author__ = "mathiashedstrom" @@ -26,7 +28,7 @@ def find_nin_by_mobile(self, mobile_number: str) -> str | None: raise LookupMobileTaskFailed(f"find_nin_by_mobile task failed: {e}") @deprecated("This task seems unused") - def find_mobiles_by_nin(self, nin: str): + def find_mobiles_by_nin(self, nin: str) -> Any: try: result = self._find_mobiles_by_NIN.delay(nin) result = result.get(timeout=10) # Lower timeout than standard gunicorn worker timeout (25) diff --git a/src/eduid/graphdb/db.py b/src/eduid/graphdb/db.py index 35c4d5295..91b20df53 100644 --- a/src/eduid/graphdb/db.py +++ b/src/eduid/graphdb/db.py @@ -85,7 +85,7 @@ def __repr__(self) -> str: return f"" @property - def db(self): + def db(self) -> Neo4jDB: return self._db def db_setup(self) -> None: diff --git a/src/eduid/graphdb/groupdb/db.py b/src/eduid/graphdb/groupdb/db.py index cbcb95000..bdfad4ff2 100644 --- a/src/eduid/graphdb/groupdb/db.py +++ b/src/eduid/graphdb/groupdb/db.py @@ -73,7 +73,7 @@ def db_setup(self) -> None: logger.info(f"{self} setup done.") @property - def scope(self): + def scope(self) -> str: return self._scope def _create_or_update_group(self, tx: Transaction, group: Group) -> Group: @@ -288,7 +288,7 @@ def remove_group(self, identifier: str) -> None: with self.db.driver.session(default_access_mode=WRITE_ACCESS) as session: session.run(q, scope=self.scope, identifier=identifier) - def get_groups_by_property(self, key: str, value: str, skip: int = 0, limit: int = 100): + def get_groups_by_property(self, key: str, value: str, skip: int = 0, limit: int = 100) -> list[Group]: res: list[Group] = [] q = f""" MATCH (g: Group {{scope: $scope}}) @@ -377,7 +377,9 @@ def group_exists(self, identifier: str) -> bool: RETURN count(*) as exists LIMIT 1 """ with self.db.driver.session(default_access_mode=READ_ACCESS) as session: - ret = session.run(q, scope=self.scope, identifier=identifier).single()["exists"] + single_value = session.run(q, scope=self.scope, identifier=identifier).single() + assert single_value is not None + ret = single_value["exists"] return bool(ret) def save(self, group: Group) -> Group: diff --git a/src/eduid/graphdb/testing.py b/src/eduid/graphdb/testing.py index 4576b1b5c..ba35476a6 100644 --- a/src/eduid/graphdb/testing.py +++ b/src/eduid/graphdb/testing.py @@ -89,19 +89,19 @@ def conn(self) -> Neo4jDB: return self._conn @property - def host(self): + def host(self) -> str: return self._host @property - def http_port(self): + def http_port(self) -> int: return self._http_port @property - def https_port(self): + def https_port(self) -> int: return self._https_port @property - def bolt_port(self): + def bolt_port(self) -> int: return self._bolt_port def purge_db(self) -> None: diff --git a/src/eduid/graphdb/tests/test_group.py b/src/eduid/graphdb/tests/test_group.py index 2fa36bf17..13c65bb00 100644 --- a/src/eduid/graphdb/tests/test_group.py +++ b/src/eduid/graphdb/tests/test_group.py @@ -1,51 +1,66 @@ +from typing import TypedDict from unittest import TestCase +from typing_extensions import NotRequired + from eduid.graphdb.groupdb import Group, User __author__ = "lundberg" +class GroupData(TypedDict): + identifier: str + display_name: str + members: NotRequired[set[Group | User]] + owners: NotRequired[set[Group | User]] + + +class UserData(TypedDict): + identifier: str + display_name: str + + class TestGroup(TestCase): def setUp(self) -> None: - self.group1: dict[str, str | list] = { + self.group1: GroupData = { "identifier": "test1", "display_name": "Test Group 1", } - self.group2: dict[str, str | list] = { + self.group2: GroupData = { "identifier": "test2", "display_name": "Test Group 2", } - self.user1: dict[str, str] = {"identifier": "user1", "display_name": "Test Testsson"} - self.user2: dict[str, str] = {"identifier": "user2", "display_name": "Namn Namnsson"} + self.user1: UserData = {"identifier": "user1", "display_name": "Test Testsson"} + self.user2: UserData = {"identifier": "user2", "display_name": "Namn Namnsson"} - def test_init_group(self): + def test_init_group(self) -> None: group = Group(**self.group1) assert self.group1["identifier"] == group.identifier assert self.group1["display_name"] == group.display_name - def test_init_group_with_members(self): + def test_init_group_with_members(self) -> None: user = User(**self.user1) - self.group1["members"] = [user] + self.group1["members"] = set([user]) group = Group(**self.group1) assert user in group.members assert user in group.member_users assert 0 == len(group.member_groups) - group.members.append(Group(**self.group2)) + group.members.add(Group(**self.group2)) assert 1 == len(group.member_groups) - group.members.append(User(**self.user2)) + group.members.add(User(**self.user2)) assert 3 == len(group.members) - def test_init_group_with_owner_and_member(self): + def test_init_group_with_owner_and_member(self) -> None: user = User(**self.user1) owner = User(**self.user2) - self.group1["members"] = [user] - self.group1["owners"] = [owner] + self.group1["members"] = set([user]) + self.group1["owners"] = set([owner]) group = Group(**self.group1) assert user in group.members assert user in group.member_users assert owner in group.owners - def test_get_users_and_groups(self): + def test_get_users_and_groups(self) -> None: member1 = User(**self.user1) member2 = User(**self.user2) member3 = Group(**self.group2) @@ -53,8 +68,8 @@ def test_get_users_and_groups(self): owner2 = User(**self.user2) owner3 = Group(**self.group2) - self.group1["members"] = [member1, member2, member3] - self.group1["owners"] = [owner1, owner2, owner3] + self.group1["members"] = set([member1, member2, member3]) + self.group1["owners"] = set([owner1, owner2, owner3]) group = Group(**self.group1) assert owner2 == group.get_owner_user(identifier=owner2.identifier) diff --git a/src/eduid/graphdb/tests/test_groupdb.py b/src/eduid/graphdb/tests/test_groupdb.py index 50602fef0..aa6f20c64 100644 --- a/src/eduid/graphdb/tests/test_groupdb.py +++ b/src/eduid/graphdb/tests/test_groupdb.py @@ -44,16 +44,17 @@ def _assert_user(expected: User, testing: User): assert expected.identifier == testing.identifier assert expected.display_name == testing.display_name - def test_create_group(self): + def test_create_group(self) -> None: group = Group.from_mapping(self.group1) post_save_group = self.group_db.save(group) assert 1 == self.group_db.db.count_nodes(label="Group") self._assert_group(group, post_save_group) get_group = self.group_db.get_group(identifier="test1") + assert isinstance(get_group, Group) self._assert_group(group, get_group) - def test_update_group(self): + def test_update_group(self) -> None: group = Group.from_mapping(self.group1) post_save_group = self.group_db.save(group) assert 1 == self.group_db.db.count_nodes(label="Group") @@ -65,7 +66,7 @@ def test_update_group(self): assert 1 == self.group_db.db.count_nodes(label="Group") self._assert_group(group, post_save_group2, modified=True) - def test_get_group_by_property(self): + def test_get_group_by_property(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) @@ -73,25 +74,25 @@ def test_get_group_by_property(self): assert 1 == len(post_get_group) self._assert_group(group, post_get_group[0]) - def test_get_non_existing_group(self): + def test_get_non_existing_group(self) -> None: group = self.group_db.get_group(identifier="test1") self.assertIsNone(group) - def test_group_exists(self): + def test_group_exists(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) self.assertTrue(self.group_db.group_exists(identifier=group.identifier)) self.assertFalse(self.group_db.group_exists(identifier="wrong-identifier")) - def test_get_groups(self): + def test_get_groups(self) -> None: self.group_db.save(Group.from_mapping(self.group1)) self.group_db.save(Group.from_mapping(self.group2)) groups = self.group_db.get_groups() assert 2 == len(groups) - def test_save_with_wrong_group_version(self): + def test_save_with_wrong_group_version(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) group = replace(group, display_name="Another display name") @@ -100,7 +101,7 @@ def test_save_with_wrong_group_version(self): self.group_db.save(group) assert 1 == self.group_db.db.count_nodes(label="Group") - def test_create_group_with_user_member(self): + def test_create_group_with_user_member(self) -> None: group = Group.from_mapping(self.group1) user = User.from_mapping(self.user1) group.members.add(user) @@ -111,10 +112,11 @@ def test_create_group_with_user_member(self): assert 1 == self.group_db.db.count_nodes(label="User") post_save_group = self.group_db.get_group(identifier="test1") + assert post_save_group post_save_user = post_save_group.member_users[0] self._assert_user(user, post_save_user) - def test_create_group_with_group_member(self): + def test_create_group_with_group_member(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -124,10 +126,11 @@ def test_create_group_with_group_member(self): assert 2 == self.group_db.db.count_nodes(label="Group") post_save_group = self.group_db.get_group(identifier="test1") + assert post_save_group post_save_member_group = post_save_group.member_groups[0] self._assert_group(member_group, post_save_member_group) - def test_create_group_with_group_member_and_user_owner(self): + def test_create_group_with_group_member_and_user_owner(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -143,6 +146,7 @@ def test_create_group_with_group_member_and_user_owner(self): assert 2 == self.group_db.db.count_nodes(label="Group") post_save_group = self.group_db.get_group(identifier="test1") + assert post_save_group post_save_member_group = post_save_group.member_groups[0] self._assert_group(member_group, post_save_member_group) @@ -150,9 +154,10 @@ def test_create_group_with_group_member_and_user_owner(self): self._assert_user(member_user, post_save_user) post_save_owner = post_save_group.owners.pop() + assert isinstance(post_save_owner, User) self._assert_user(owner, post_save_owner) - def test_remove_group(self): + def test_remove_group(self) -> None: group = Group.from_mapping(self.group1) self.group_db.save(group) assert self.group_db.group_exists(group.identifier) is True @@ -160,7 +165,7 @@ def test_remove_group(self): self.group_db.remove_group(group.identifier) assert self.group_db.group_exists(group.identifier) is False - def test_get_groups_for_user_member(self): + def test_get_groups_for_user_member(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -180,11 +185,17 @@ def test_get_groups_for_user_member(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(group.owners) - self._assert_group(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, User) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, User) + self._assert_user(group_owners, groups_0_owners) assert 1 == len(groups[0].members) - self._assert_user(member_user, groups[0].members.pop()) + groups_0_members = groups[0].members.pop() + assert isinstance(groups_0_members, User) + self._assert_user(member_user, groups_0_members) - def test_get_groups_for_user_member_2(self): + def test_get_groups_for_user_member_2(self) -> None: group1 = Group.from_mapping(self.group1) group2 = Group.from_mapping(self.group2) member_user = User.from_mapping(self.user1) @@ -202,7 +213,7 @@ def test_get_groups_for_user_member_2(self): assert sorted([group1.display_name, group2.display_name]) == sorted([x.display_name for x in groups]) assert groups[0].created_ts is not None - def test_get_groups_for_group_member(self): + def test_get_groups_for_group_member(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -222,11 +233,17 @@ def test_get_groups_for_group_member(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(group.owners) - self._assert_user(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, User) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, User) + self._assert_user(group_owners, groups_0_owners) assert 1 == len(groups[0].members) - self._assert_group(member_group, groups[0].members.pop()) + groups_0_members = groups[0].members.pop() + assert isinstance(groups_0_members, Group) + self._assert_group(member_group, groups_0_members) - def test_get_groups_for_user_owner(self): + def test_get_groups_for_user_owner(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -246,10 +263,14 @@ def test_get_groups_for_user_owner(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(groups[0].owners) - self._assert_group(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, User) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, User) + self._assert_user(group_owners, groups_0_owners) assert 2 == len(groups[0].members) - def test_get_groups_for_user_owner_2(self): + def test_get_groups_for_user_owner_2(self) -> None: group1 = Group.from_mapping(self.group1) group2 = Group.from_mapping(self.group2) owner_user = User.from_mapping(self.user1) @@ -272,7 +293,7 @@ def test_get_groups_for_user_owner_2(self): assert 1 == len(groups[0].members) assert 1 == len(groups[1].members) - def test_get_groups_for_group_owner(self): + def test_get_groups_for_group_owner(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -292,10 +313,14 @@ def test_get_groups_for_group_owner(self): assert 1 == len(groups) self._assert_group(group, groups[0]) assert 1 == len(groups[0].owners) - self._assert_group(group.owners.pop(), groups[0].owners.pop()) + group_owners = group.owners.pop() + assert isinstance(group_owners, Group) + groups_0_owners = groups[0].owners.pop() + assert isinstance(groups_0_owners, Group) + self._assert_group(group_owners, groups_0_owners) assert 2 == len(groups[0].members) - def test_get_groups_and_users_by_role(self): + def test_get_groups_and_users_by_role(self) -> None: group = Group.from_mapping(self.group1) member_group = Group.from_mapping(self.group2) group.members.add(member_group) @@ -324,7 +349,7 @@ def test_get_groups_and_users_by_role(self): assert owner_user.identifier in [owner.identifier for owner in owners] assert 2 == len(owners) - def test_remove_user_from_group(self): + def test_remove_user_from_group(self) -> None: group = Group.from_mapping(self.group1) member_user1 = User.from_mapping(self.user1) member_user2 = User.from_mapping(self.user2) @@ -347,10 +372,11 @@ def test_remove_user_from_group(self): assert post_remove_group.has_member(member_user2.identifier) is True get_group = self.group_db.get_group(identifier="test1") + assert get_group assert get_group.has_member(member_user1.identifier) is False assert get_group.has_member(member_user2.identifier) is True - def test_remove_group_from_group(self): + def test_remove_group_from_group(self) -> None: group = Group.from_mapping(self.group1) member_user1 = User.from_mapping(self.user1) member_group1 = Group.from_mapping(self.group2) @@ -373,5 +399,6 @@ def test_remove_group_from_group(self): assert post_remove_group.has_member(member_user1.identifier) is True get_group = self.group_db.get_group(identifier=group.identifier) + assert get_group assert get_group.has_member(member_group1.identifier) is False assert get_group.has_member(member_user1.identifier) is True diff --git a/src/eduid/maccapi/helpers.py b/src/eduid/maccapi/helpers.py index db955ce9c..de95e291b 100644 --- a/src/eduid/maccapi/helpers.py +++ b/src/eduid/maccapi/helpers.py @@ -20,7 +20,7 @@ class UnableToAddPassword(Exception): pass -def list_users(context: Context, data_owner: str): +def list_users(context: Context, data_owner: str) -> list[ManagedAccount]: managed_accounts: list[ManagedAccount] = context.db.get_users(data_owner=data_owner) context.logger.info(f"Listing {managed_accounts.__len__()} users") return managed_accounts @@ -44,7 +44,7 @@ def add_password(context: Context, managed_account: ManagedAccount, password: st return True -def revoke_passwords(context: Context, managed_account: ManagedAccount, reason: str): +def revoke_passwords(context: Context, managed_account: ManagedAccount, reason: str) -> bool: vccs = context.vccs_client revoke_factors = [] @@ -66,7 +66,7 @@ def revoke_passwords(context: Context, managed_account: ManagedAccount, reason: return True -def save_and_sync_user(context: Context, managed_account: ManagedAccount): +def save_and_sync_user(context: Context, managed_account: ManagedAccount) -> None: context.logger.debug(f"Saving and syncing user {managed_account}") result = context.db.save(managed_account) context.logger.debug(f"Saved user {managed_account} with result {result}") @@ -123,7 +123,7 @@ def deactivate_user(context: Context, eppn: str, data_owner: str) -> ManagedAcco return managed_account -def replace_password(context: Context, eppn: str, new_password: str): +def replace_password(context: Context, eppn: str, new_password: str) -> None: managed_account: ManagedAccount = context.db.get_user_by_eppn(eppn) if managed_account is None: raise UserDoesNotExist(f"User {eppn} not found") @@ -144,7 +144,7 @@ def get_user(context: Context, eppn: str, data_owner: str) -> ManagedAccount: return managed_account -def add_api_event(context: Context, eppn: str, action: str, action_by: str, data_owner: str): +def add_api_event(context: Context, eppn: str, action: str, action_by: str, data_owner: str) -> None: expiration: datetime = utc_now() + timedelta(days=context.config.log_retention_days) log_element = ManagedAccountLogElement( eppn=eppn, diff --git a/src/eduid/maccapi/middleware.py b/src/eduid/maccapi/middleware.py index cb896897e..ebdb296e7 100644 --- a/src/eduid/maccapi/middleware.py +++ b/src/eduid/maccapi/middleware.py @@ -91,6 +91,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - self.context.logger.error(f"Data owner {repr(data_owner)} not configured") return return_error_response(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unknown data_owner") + assert isinstance(request.context, MaccAPIContext) request.context.data_owner = data_owner request.context.manager_eppn = token.saml_eppn diff --git a/src/eduid/maccapi/routers/status.py b/src/eduid/maccapi/routers/status.py index 2b49c3580..4d01cfd2c 100644 --- a/src/eduid/maccapi/routers/status.py +++ b/src/eduid/maccapi/routers/status.py @@ -24,7 +24,7 @@ class StatusResponse(BaseModel): reason: str -def check_mongo(request: ContextRequest): +def check_mongo(request: ContextRequest) -> bool | None: db = request.app.context.db try: db.is_healthy() diff --git a/src/eduid/maccapi/routers/users.py b/src/eduid/maccapi/routers/users.py index 5642a9dff..9e7dfb037 100644 --- a/src/eduid/maccapi/routers/users.py +++ b/src/eduid/maccapi/routers/users.py @@ -2,7 +2,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.utils import generate_password -from eduid.maccapi.context_request import MaccAPIRoute +from eduid.maccapi.context_request import MaccAPIContext, MaccAPIRoute from eduid.maccapi.helpers import ( UnableToAddPassword, add_api_event, @@ -35,6 +35,8 @@ async def get_users(request: ContextRequest) -> UserListResponse: return all users that the calling user has access to in current context """ + assert isinstance(request.context, MaccAPIContext) + assert request.context.data_owner is not None manages_accounts = list_users(context=request.app.context, data_owner=request.context.data_owner) users = [ApiUser(eppn=user.eppn, given_name=user.given_name, surname=user.surname) for user in manages_accounts] @@ -56,6 +58,9 @@ async def add_user( password = generate_password() presentable_password = make_presentable_password(password) + assert isinstance(request.context, MaccAPIContext) + assert request.context.data_owner is not None + assert request.context.manager_eppn is not None managed_account: ManagedAccount = create_and_sync_user( context=request.app.context, data_owner=request.context.data_owner, @@ -101,6 +106,9 @@ async def remove_user( request.app.context.logger.debug(f"remove_user: {remove_request}") try: + assert isinstance(request.context, MaccAPIContext) + assert request.context.data_owner is not None + assert request.context.manager_eppn is not None managed_account: ManagedAccount = deactivate_user( context=request.app.context, eppn=remove_request.eppn, data_owner=request.context.data_owner ) @@ -145,6 +153,9 @@ async def reset_password( new_password = generate_password() presentable_password = make_presentable_password(new_password) try: + assert isinstance(request.context, MaccAPIContext) + assert request.context.data_owner is not None + assert request.context.manager_eppn is not None managed_account = get_user(context=request.app.context, eppn=eppn, data_owner=request.context.data_owner) replace_password(context=request.app.context, eppn=eppn, new_password=new_password) diff --git a/src/eduid/maccapi/tests/test_maccapi.py b/src/eduid/maccapi/tests/test_maccapi.py index 79c769938..5f8322bff 100644 --- a/src/eduid/maccapi/tests/test_maccapi.py +++ b/src/eduid/maccapi/tests/test_maccapi.py @@ -32,7 +32,7 @@ def _make_bearer_token(self, claims: Mapping[str, Any]) -> str: def _is_presentable_format(self, password: str) -> bool: return len(password) == 14 and password[4] == " " and password[9] == " " - def test_create_user(self): + def test_create_user(self) -> None: domain = "eduid.se" claims = { "saml_eppn": "test@eduid.se", @@ -60,7 +60,7 @@ def test_create_user(self): assert payload["user"]["eppn"] is not None assert payload["user"]["password"] is not None - def test_create_multiple_users(self): + def test_create_multiple_users(self) -> None: claims = { "saml_eppn": "test@eduid.se", "version": 1, @@ -95,7 +95,7 @@ def test_create_multiple_users(self): assert payload["status"] == "success" assert len(payload["users"]) == 2 - def test_remove_user(self): + def test_remove_user(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers @@ -117,7 +117,7 @@ def test_remove_user(self): assert payload["user"]["given_name"] == self.user1["given_name"] assert payload["user"]["surname"] == self.user1["surname"] - def test_reset_password(self): + def test_reset_password(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers @@ -145,7 +145,7 @@ def test_reset_password(self): new_password = payload["user"]["password"] assert self._is_presentable_format(new_password) - def test_remove_error(self): + def test_remove_error(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers @@ -154,7 +154,7 @@ def test_remove_error(self): response = self.client.post(url="/Users/remove", json={"eppn": "made_up"}, headers=headers) assert response.status_code == 422 - def test_reset_error(self): + def test_reset_error(self) -> None: token = self._make_bearer_token(claims=self.claims) headers = self.headers diff --git a/src/eduid/queue/db/worker.py b/src/eduid/queue/db/worker.py index 8fd083e7e..673ba1131 100644 --- a/src/eduid/queue/db/worker.py +++ b/src/eduid/queue/db/worker.py @@ -33,7 +33,7 @@ async def create(cls, db_uri: str, collection: str, db_name: str = "eduid_queue" await instance.setup_indexes(indexes) return instance - def parse_queue_item(self, doc: Mapping, parse_payload: bool = True): + def parse_queue_item(self, doc: Mapping, parse_payload: bool = True) -> QueueItem: item = QueueItem.from_dict(doc) if parse_payload is False: # Return the item with the generic RawPayload diff --git a/src/eduid/queue/testing.py b/src/eduid/queue/testing.py index 7c68c66eb..663abfe8c 100644 --- a/src/eduid/queue/testing.py +++ b/src/eduid/queue/testing.py @@ -87,7 +87,7 @@ def get_instance( return cast(MongoTemporaryInstanceReplicaSet, super().get_instance(max_retry_seconds=max_retry_seconds)) @property - def uri(self): + def uri(self) -> str: return f"mongodb://localhost:{self.port}" @@ -209,7 +209,7 @@ async def _assert_item_gets_processed(self, queue_item: QueueItem, retry: bool = class IsolatedWorkerDBMixin(MixinBase): # override run so we can mock cache of database clients - async def run(self): + async def run(self) -> None: # Init db in the correct loop # Make sure the isolated test cases get to create their own mongodb clients with patch("eduid.userdb.db.async_db.AsyncClientCache._clients", {}): diff --git a/src/eduid/queue/tests/test_client.py b/src/eduid/queue/tests/test_client.py index e1ed03d8b..6e5d52e36 100644 --- a/src/eduid/queue/tests/test_client.py +++ b/src/eduid/queue/tests/test_client.py @@ -12,7 +12,7 @@ class TestClient(TestCase): - def test_queue_item(self): + def test_queue_item(self) -> None: expires_at = utc_now() + timedelta(days=180) discard_at = expires_at + timedelta(days=7) sender_info = SenderInfo(hostname="testhost", node_id="userdb@testhost") @@ -46,7 +46,7 @@ def _create_queue_item(self, payload: Payload): payload=payload, ) - def test_eduid_invite_mail(self): + def test_eduid_invite_mail(self) -> None: payload = EduidInviteEmail( email="mail@example.com", reference="ref_id", @@ -60,7 +60,7 @@ def test_eduid_invite_mail(self): assert normalised_data(item.to_dict()) == normalised_data(loaded_message_dict) assert normalised_data(payload.to_dict()) == normalised_data(item.payload.to_dict()) - def test_eduid_signup_mail(self): + def test_eduid_signup_mail(self) -> None: payload = EduidSignupEmail( email="mail@example.com", reference="ref_id", @@ -75,7 +75,7 @@ def test_eduid_signup_mail(self): class TestMessageDB(EduidQueueTestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.messagedb = MessageDB(self.mongo_uri) self.messagedb.register_handler(TestPayload) @@ -86,7 +86,7 @@ def setUp(self): self.discard_at = self.expires_at + timedelta(days=7) self.sender_info = SenderInfo(hostname="testhost", node_id="userdb@testhost") - def tearDown(self): + def tearDown(self) -> None: super().tearDown() self.messagedb._drop_whole_collection() @@ -100,33 +100,36 @@ def _create_queue_item(self, payload: Payload): payload=payload, ) - def test_save_load(self): + def test_save_load(self) -> None: payload = TestPayload(message="this is a test payload") item = self._create_queue_item(payload) self.messagedb.save(item) assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert loaded_item.payload_type == payload.get_type() assert isinstance(loaded_item.payload, TestPayload) is True assert normalised_data(item.to_dict()) == normalised_data(loaded_item.to_dict()) - def test_save_load_raw_payload(self): + def test_save_load_raw_payload(self) -> None: payload = TestPayload(message="this is a test payload") item = self._create_queue_item(payload) self.messagedb.save(item) assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert loaded_item.payload_type == payload.get_type() assert isinstance(loaded_item.payload, TestPayload) is True raw_loaded_item = self.messagedb.get_item_by_id(item.item_id, parse_payload=False) + assert raw_loaded_item assert raw_loaded_item.payload_type == payload.get_type() assert isinstance(raw_loaded_item.payload, RawPayload) is True assert normalised_data(item.payload.to_dict()) == normalised_data(raw_loaded_item.payload.to_dict()) - def test_save_load_eduid_email_invite(self): + def test_save_load_eduid_email_invite(self) -> None: payload = EduidInviteEmail( email="mail@example.com", reference="ref_id", @@ -140,10 +143,11 @@ def test_save_load_eduid_email_invite(self): assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert normalised_data(item.to_dict()) == normalised_data(loaded_item.to_dict()) assert normalised_data(item.payload.to_dict()), normalised_data(loaded_item.payload.to_dict()) - def test_save_load_eduid_email_signup(self): + def test_save_load_eduid_email_signup(self) -> None: payload = EduidSignupEmail( email="mail@example.com", reference="ref_id", @@ -156,12 +160,13 @@ def test_save_load_eduid_email_signup(self): assert 1 == self.messagedb.db_count() loaded_item = self.messagedb.get_item_by_id(item.item_id) + assert loaded_item assert normalised_data(item.to_dict()) == normalised_data(loaded_item.to_dict()) assert normalised_data(item.payload.to_dict()) == normalised_data(loaded_item.payload.to_dict()) @skip("It takes mongo a couple of seconds to actually remove the document, skip for now.") # TODO: Investigate if it is possible to force a expire check in mongodb - def test_auto_discard(self): + def test_auto_discard(self) -> None: self.discard_at = utc_now() - timedelta(seconds=-10) payload = TestPayload(message="this is a test payload") item = self._create_queue_item(payload) diff --git a/src/eduid/queue/tests/test_mail_worker.py b/src/eduid/queue/tests/test_mail_worker.py index 635751f7c..184d972a9 100644 --- a/src/eduid/queue/tests/test_mail_worker.py +++ b/src/eduid/queue/tests/test_mail_worker.py @@ -66,7 +66,7 @@ async def asyncSetUp(self) -> None: async def asyncTearDown(self) -> None: await super().asyncTearDown() - async def test_eduid_signup_mail_from_stream(self): + async def test_eduid_signup_mail_from_stream(self) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -82,7 +82,7 @@ async def test_eduid_signup_mail_from_stream(self): await self._assert_item_gets_processed(queue_item) @patch("aiosmtplib.SMTP.sendmail") - async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail: MagicMock): + async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail: MagicMock) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -99,7 +99,7 @@ async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_send await self._assert_item_gets_processed(queue_item) @patch("aiosmtplib.SMTP.sendmail") - async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail: MagicMock): + async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail: MagicMock) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -118,7 +118,7 @@ async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail: Ma self.client_db.save(queue_item) await self._assert_item_gets_processed(queue_item, retry=True) - async def test_register_mail_translations(self): + async def test_register_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidSignupEmail( email="noone@example.com", @@ -143,7 +143,7 @@ async def test_register_mail_translations(self): assert "Subject: eduID-registrering" in msg_string assert "Du har registrerat noone@example.com som e-postadress" in msg_string - async def test_reset_password_mail_translations(self): + async def test_reset_password_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidResetPasswordEmail( email="noone@example.com", @@ -171,7 +171,7 @@ async def test_reset_password_mail_translations(self): assert "Du har bett om att byta" in msg_string assert "giltig i 2 timmar." in msg_string - async def test_verification_mail_translations(self): + async def test_verification_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidVerificationEmail( email="noone@example.com", @@ -197,7 +197,7 @@ async def test_verification_mail_translations(self): assert "Du har nyligen lagt till den" in msg_string assert "Skriv in koden nedan" in msg_string - async def test_termination_mail_translations(self): + async def test_termination_mail_translations(self) -> None: for lang in ["en", "sv"]: payload = EduidTerminationEmail( email="noone@example.com", diff --git a/src/eduid/queue/tests/test_worker.py b/src/eduid/queue/tests/test_worker.py index 0ec1aade5..23c2423bc 100644 --- a/src/eduid/queue/tests/test_worker.py +++ b/src/eduid/queue/tests/test_worker.py @@ -62,7 +62,7 @@ async def asyncSetUp(self) -> None: async def asyncTearDown(self) -> None: await super().asyncTearDown() - async def test_worker_item_from_stream(self): + async def test_worker_item_from_stream(self) -> None: """ Test that saved queue items are handled by the handle_new_item method """ @@ -74,7 +74,7 @@ async def test_worker_item_from_stream(self): self.client_db.save(queue_item) await self._assert_item_gets_processed(queue_item) - async def test_worker_expired_item(self): + async def test_worker_expired_item(self) -> None: """ Test that expired queue items are handled by the handle_expired_item method """ diff --git a/src/eduid/queue/workers/base.py b/src/eduid/queue/workers/base.py index 93415f729..734e5e4fe 100644 --- a/src/eduid/queue/workers/base.py +++ b/src/eduid/queue/workers/base.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -def cancel_task(signame: str, task: Task): +def cancel_task(signame: str, task: Task) -> None: logger.info(f"got signal {signame}: exit") task.cancel() @@ -47,7 +47,7 @@ def add_task(tasks: set[Task], task: Task) -> set[Task]: tasks.add(task) return tasks - async def run(self): + async def run(self) -> None: # Init db in the correct loop self.db = await AsyncQueueDB.create(db_uri=self.config.mongo_uri, collection=self.config.mongo_collection) # Register payloads to handle @@ -64,7 +64,7 @@ async def run(self): logger.info(f"Running: {main_task.get_name()}") await main_task - async def run_subtasks(self): + async def run_subtasks(self) -> None: logger.info(f"Initiating event stream for: {self.db}") watch_collection_task = asyncio.create_task( self.watch_collection(), name=f"Watch collection {self.config.mongo_collection}" diff --git a/src/eduid/queue/workers/mail.py b/src/eduid/queue/workers/mail.py index d92eb9bd2..40fc77705 100644 --- a/src/eduid/queue/workers/mail.py +++ b/src/eduid/queue/workers/mail.py @@ -46,7 +46,7 @@ def __init__(self, config: QueueWorkerConfig): self._jinja2 = Jinja2Env() @property - async def smtp(self): + async def smtp(self) -> SMTP: if self._smtp is None: logger.debug(f"Creating SMTP client for {self.config.mail_host}:{self.config.mail_port}") validate_certs = self.config.mail_verify_tls @@ -293,7 +293,7 @@ def init_mail_worker(name: str = "mail_worker", test_config: Mapping[str, Any] | return MailQueueWorker(config=config) -def start_worker(): +def start_worker() -> None: worker = init_mail_worker() if worker.smtp is None: # fail fast if we can't connect to the SMTP server diff --git a/src/eduid/queue/workers/scim_event.py b/src/eduid/queue/workers/scim_event.py index 31d7fbf4c..b406de80a 100644 --- a/src/eduid/queue/workers/scim_event.py +++ b/src/eduid/queue/workers/scim_event.py @@ -62,7 +62,7 @@ def init_scim_event_worker( return ScimEventQueueWorker(config=config) -def start_worker(): +def start_worker() -> None: worker = init_scim_event_worker() exit(asyncio.run(worker.run())) diff --git a/src/eduid/queue/workers/sink.py b/src/eduid/queue/workers/sink.py index 5643b5b5e..4d531377a 100644 --- a/src/eduid/queue/workers/sink.py +++ b/src/eduid/queue/workers/sink.py @@ -91,7 +91,7 @@ def init_sink_worker(name: str = "sink_worker", test_config: Mapping[str, Any] | return SinkQueueWorker(config=config) -def start_worker(): +def start_worker() -> None: worker = init_sink_worker() exit(asyncio.run(worker.run())) diff --git a/src/eduid/satosa/scimapi/serve_static.py b/src/eduid/satosa/scimapi/serve_static.py index dea589e0d..438929e06 100644 --- a/src/eduid/satosa/scimapi/serve_static.py +++ b/src/eduid/satosa/scimapi/serve_static.py @@ -38,7 +38,7 @@ def __init__(self, config: SATOSAConfig, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.locations = config.get("locations", {}) - def register_endpoints(self): + def register_endpoints(self) -> list: url_map = [] for endpoint, path in self.locations.items(): endpoint = endpoint.strip("/") diff --git a/src/eduid/satosa/scimapi/stepup.py b/src/eduid/satosa/scimapi/stepup.py index 4b14fffc4..bb65c745f 100644 --- a/src/eduid/satosa/scimapi/stepup.py +++ b/src/eduid/satosa/scimapi/stepup.py @@ -36,7 +36,7 @@ RequestMicroService, ResponseMicroService, ) -from satosa.response import Response +from satosa.response import Response, SeeOther from satosa.saml_util import make_saml_response try: @@ -625,7 +625,7 @@ def __init__(self, *args: Any, **kwargs: Any): raise StepUpError(f"The configuration for this plugin is not valid: {e}") self.mfa = parsed_config.mfa - def authn_request(self, context: satosa.context.Context, entity_id: str): + def authn_request(self, context: satosa.context.Context, entity_id: str) -> SeeOther | Response: logger.debug(f"Processing AuthnRequest with entity id {repr(entity_id)}") if self.mfa and AuthnContext.sp_wants_mfa(context=context): diff --git a/src/eduid/scimapi/middleware.py b/src/eduid/scimapi/middleware.py index 05139729f..6689dda2b 100644 --- a/src/eduid/scimapi/middleware.py +++ b/src/eduid/scimapi/middleware.py @@ -29,7 +29,7 @@ # Hack to be able to get request body both now and later # https://github.com/encode/starlette/issues/495#issuecomment-513138055 -async def set_body(request: Request, body: bytes): +async def set_body(request: Request, body: bytes) -> None: async def receive() -> Message: return {"type": "http.request", "body": body} @@ -81,6 +81,8 @@ def _is_no_auth_path(self, url: URL) -> bool: async def dispatch(self, req: Request, call_next: RequestResponseEndpoint) -> Response: req = self.make_context_request(request=req, context_class=ScimApiContext) + assert isinstance(req.context, ScimApiContext) # please mypy + if self._is_no_auth_path(req.url): return await call_next(req) diff --git a/src/eduid/scimapi/routers/events.py b/src/eduid/scimapi/routers/events.py index 62948a6fd..dad8e1360 100644 --- a/src/eduid/scimapi/routers/events.py +++ b/src/eduid/scimapi/routers/events.py @@ -5,7 +5,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import SCIMResourceType from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, ErrorDetail, NotFound from eduid.scimapi.models.event import EventCreateRequest, EventResponse from eduid.scimapi.routers.utils.events import db_event_to_response, get_scim_referenced @@ -31,6 +31,8 @@ async def on_get(req: ContextRequest, resp: Response, scim_id: str | None = None if scim_id is None: raise BadRequest(detail="Not implemented") req.app.context.logger.info(f"Fetching event {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.eventdb is not None db_event = req.context.eventdb.get_event_by_scim_id(scim_id) if not db_event: raise NotFound(detail="Event not found") @@ -91,6 +93,10 @@ async def on_post(req: ContextRequest, resp: Response, create_request: EventCrea _timestamp = create_request.nutid_event_v1.timestamp _expires_at = utc_now() + timedelta(days=1) + assert isinstance(req.context, ScimApiContext) + assert req.context.data_owner is not None + assert req.context.eventdb is not None + event = ScimApiEvent( resource=ScimApiEventResource( resource_type=create_request.nutid_event_v1.resource.resource_type, diff --git a/src/eduid/scimapi/routers/groups.py b/src/eduid/scimapi/routers/groups.py index 3bd96a365..471547302 100644 --- a/src/eduid/scimapi/routers/groups.py +++ b/src/eduid/scimapi/routers/groups.py @@ -3,7 +3,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import ListResponse, SCIMResourceType, SearchRequest from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, ErrorDetail, NotFound from eduid.scimapi.models.group import GroupCreateRequest, GroupResponse, GroupUpdateRequest from eduid.scimapi.routers.utils.events import add_api_event @@ -29,6 +29,8 @@ @groups_router.get("/", response_model=ListResponse) async def on_get_all(req: ContextRequest) -> ListResponse: + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None db_groups = req.context.groupdb.get_groups() resources = [{"id": str(db_group.scim_id), "displayName": db_group.graph.display_name} for db_group in db_groups] return ListResponse(total_results=len(db_groups), resources=resources) @@ -67,6 +69,8 @@ async def on_get_one(req: ContextRequest, resp: Response, scim_id: str) -> Group """ req.app.context.logger.info(f"Fetching group {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None db_group = req.context.groupdb.get_group_by_scim_id(scim_id) req.app.context.logger.debug(f"Found group: {db_group}") if not db_group: @@ -135,6 +139,8 @@ async def on_put( raise BadRequest(detail="Id mismatch") req.app.context.logger.info(f"Fetching group {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None db_group = req.context.groupdb.get_group_by_scim_id(str(update_request.id)) req.app.context.logger.debug(f"Found group: {db_group}") if not db_group: @@ -145,6 +151,7 @@ async def on_put( raise BadRequest(detail="Version mismatch") # Check that members exists in their respective db + assert req.context.userdb is not None req.app.context.logger.info("Checking if group and user members exists") for member in update_request.members: if member.is_group: @@ -162,6 +169,7 @@ async def on_put( db_group = req.context.groupdb.get_group_by_scim_id(str(updated_group.scim_id)) assert db_group # please mypy + assert req.context.data_owner is not None if changed: add_api_event( context=req.app.context, @@ -209,12 +217,15 @@ async def on_post(req: ContextRequest, resp: Response, create_request: GroupCrea """ req.app.context.logger.info("Creating group") req.app.context.logger.debug(create_request) + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None created_group = req.context.groupdb.create_group(create_request=create_request) # Load the group from the database to ensure results are consistent with subsequent GETs. # For example, timestamps have higher resolution in created_group than after a load. db_group = req.context.groupdb.get_group_by_scim_id(str(created_group.scim_id)) assert db_group # please mypy + assert req.context.data_owner is not None add_api_event( context=req.app.context, data_owner=req.context.data_owner, @@ -237,6 +248,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: GroupCrea ) async def on_delete(req: ContextRequest, scim_id: str) -> None: req.app.context.logger.info(f"Deleting group {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None db_group = req.context.groupdb.get_group_by_scim_id(scim_id=scim_id) req.app.context.logger.debug(f"Found group: {db_group}") if not db_group: @@ -248,6 +261,7 @@ async def on_delete(req: ContextRequest, scim_id: str) -> None: res = req.context.groupdb.remove_group(db_group) + assert req.context.data_owner is not None add_api_event( context=req.app.context, data_owner=req.context.data_owner, diff --git a/src/eduid/scimapi/routers/invites.py b/src/eduid/scimapi/routers/invites.py index 9f375a8cb..ab1ac9db5 100644 --- a/src/eduid/scimapi/routers/invites.py +++ b/src/eduid/scimapi/routers/invites.py @@ -6,7 +6,7 @@ from eduid.common.models.scim_base import ListResponse, SCIMResourceType, SearchRequest from eduid.common.models.scim_invite import InviteCreateRequest, InviteResponse, InviteUpdateRequest from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, ErrorDetail, NotFound from eduid.scimapi.routers.utils.events import add_api_event from eduid.scimapi.routers.utils.invites import ( @@ -40,6 +40,8 @@ async def on_get(req: ContextRequest, resp: Response, scim_id: str | None = None if scim_id is None: raise BadRequest(detail="Not implemented") req.app.context.logger.info(f"Fetching invite {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.invitedb is not None db_invite = req.context.invitedb.get_invite_by_scim_id(scim_id) if not db_invite: raise NotFound(detail="Invite not found") @@ -60,6 +62,8 @@ async def on_put( raise BadRequest(detail="Id mismatch") req.app.context.logger.info(f"Updating invite {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.invitedb is not None db_invite = req.context.invitedb.get_invite_by_scim_id(scim_id) if not db_invite: raise NotFound(detail="Invite not found") @@ -78,6 +82,7 @@ async def on_put( db_invite = replace(db_invite, completed=update_request.nutid_invite_v1.completed) invite_changed = True + assert req.context.data_owner is not None if invite_changed: save_invite( req=req, @@ -173,6 +178,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: InviteCre if signup_invite.send_email: send_invite_mail(req, signup_invite) + assert isinstance(req.context, ScimApiContext) + assert req.context.data_owner is not None add_api_event( context=req.app.context, data_owner=req.context.data_owner, @@ -190,6 +197,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: InviteCre @invites_router.delete("/{scim_id}", status_code=204, responses={204: {"description": "No Content"}}) async def on_delete(req: ContextRequest, scim_id: str) -> None: req.app.context.logger.info(f"Deleting invite {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.invitedb is not None db_invite = req.context.invitedb.get_invite_by_scim_id(scim_id=scim_id) req.app.context.logger.debug(f"Found invite: {db_invite}") @@ -209,6 +218,7 @@ async def on_delete(req: ContextRequest, scim_id: str) -> None: # Remove scim invite res = req.context.invitedb.remove(db_invite) + assert req.context.data_owner is not None add_api_event( context=req.app.context, data_owner=req.context.data_owner, diff --git a/src/eduid/scimapi/routers/users.py b/src/eduid/scimapi/routers/users.py index 7aac70c58..938799976 100644 --- a/src/eduid/scimapi/routers/users.py +++ b/src/eduid/scimapi/routers/users.py @@ -8,7 +8,7 @@ from eduid.common.models.scim_base import ListResponse, SCIMResourceType, SCIMSchema, SearchRequest from eduid.common.models.scim_user import UserCreateRequest, UserResponse, UserUpdateRequest from eduid.scimapi.api_router import APIRouter -from eduid.scimapi.context_request import ScimApiRoute +from eduid.scimapi.context_request import ScimApiContext, ScimApiRoute from eduid.scimapi.exceptions import BadRequest, Conflict, ErrorDetail, MaxRetriesReached, NotFound from eduid.scimapi.routers.utils.events import add_api_event from eduid.scimapi.routers.utils.users import ( @@ -48,6 +48,8 @@ async def on_get(req: ContextRequest, resp: Response, scim_id: str | None = None if scim_id is None: raise BadRequest(detail="Not implemented") req.app.context.logger.info(f"Fetching user {scim_id}") + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None db_user = req.context.userdb.get_user_by_scim_id(scim_id) if not db_user: raise NotFound(detail="User not found") @@ -64,6 +66,8 @@ async def on_put(req: ContextRequest, resp: Response, update_request: UserUpdate req.app.context.logger.debug(f"{scim_id} != {update_request.id}") raise BadRequest(detail="Id mismatch") + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None db_user = req.context.userdb.get_user_by_scim_id(scim_id) if not db_user: raise NotFound(detail="User not found") @@ -143,6 +147,7 @@ async def on_put(req: ContextRequest, resp: Response, update_request: UserUpdate req.app.context.logger.debug(f"Core changed: {core_changed}, nutid_changed: {nutid_changed}") if core_changed or nutid_changed: + assert req.context.data_owner is not None save_user(req, db_user) add_api_event( context=req.app.context, @@ -238,6 +243,8 @@ async def on_post(req: ContextRequest, resp: Response, create_request: UserCreat ) save_user(req, db_user) + assert isinstance(req.context, ScimApiContext) + assert req.context.data_owner is not None add_api_event( context=req.app.context, data_owner=req.context.data_owner, @@ -259,7 +266,9 @@ async def on_post(req: ContextRequest, resp: Response, create_request: UserCreat responses={204: {"description": "No Content"}}, ) async def on_delete(req: ContextRequest, scim_id: str) -> None: + assert isinstance(req.context, ScimApiContext) req.app.context.logger.info(f"Deleting user {scim_id}") + assert req.context.userdb is not None db_user = req.context.userdb.get_user_by_scim_id(scim_id=scim_id) req.app.context.logger.debug(f"Found user: {db_user}") if not db_user: @@ -279,6 +288,7 @@ async def on_delete(req: ContextRequest, scim_id: str) -> None: res = req.context.userdb.remove(db_user) + assert req.context.data_owner is not None add_api_event( context=req.app.context, data_owner=req.context.data_owner, diff --git a/src/eduid/scimapi/routers/utils/events.py b/src/eduid/scimapi/routers/utils/events.py index 1a7ecdbc6..39990616b 100644 --- a/src/eduid/scimapi/routers/utils/events.py +++ b/src/eduid/scimapi/routers/utils/events.py @@ -8,6 +8,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import Meta, SCIMResourceType, SCIMSchema, WeakVersion from eduid.common.utils import make_etag, urlappend +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.models.event import EventResponse, NutidEventExtensionV1, NutidEventResource from eduid.userdb.scimapi import EventLevel, EventStatus, ScimApiEvent, ScimApiEventResource, ScimApiResourceBase @@ -20,7 +21,7 @@ __author__ = "lundberg" -def db_event_to_response(req: ContextRequest, resp: Response, db_event: ScimApiEvent): +def db_event_to_response(req: ContextRequest, resp: Response, db_event: ScimApiEvent) -> EventResponse: location = req.app.context.resource_url(SCIMResourceType.EVENT, db_event.scim_id) meta = Meta( location=location, @@ -61,6 +62,10 @@ def db_event_to_response(req: ContextRequest, resp: Response, db_event: ScimApiE def get_scim_referenced(req: ContextRequest, resource: NutidEventResource) -> ScimApiResourceBase | None: + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None + assert req.context.groupdb is not None + assert req.context.invitedb is not None if resource.resource_type == SCIMResourceType.USER: return req.context.userdb.get_user_by_scim_id(str(resource.scim_id)) elif resource.resource_type == SCIMResourceType.GROUP: diff --git a/src/eduid/scimapi/routers/utils/groups.py b/src/eduid/scimapi/routers/utils/groups.py index 31a428ae4..ad438a5c7 100644 --- a/src/eduid/scimapi/routers/utils/groups.py +++ b/src/eduid/scimapi/routers/utils/groups.py @@ -7,6 +7,7 @@ from eduid.common.fastapi.context_request import ContextRequest from eduid.common.models.scim_base import Meta, SCIMResourceType, SCIMSchema from eduid.common.utils import make_etag +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.models.group import GroupMember, GroupResponse, NutidGroupExtensionV1 from eduid.scimapi.search import SearchFilter @@ -69,6 +70,10 @@ def filter_display_name( raise BadRequest(scim_type="invalidFilter", detail="Invalid displayName") req.app.context.logger.debug(f"Searching for group with display name {repr(filter.val)}") + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None + assert skip is not None + assert limit is not None groups, count = req.context.groupdb.get_groups_by_property( key="display_name", value=filter.val, skip=skip, limit=limit ) @@ -90,6 +95,8 @@ def filter_lastmodified( _parsed = datetime.fromisoformat(filter.val) except Exception: raise BadRequest(scim_type="invalidFilter", detail="Invalid datetime") + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None return req.context.groupdb.get_groups_by_last_modified(operator=filter.op, value=_parsed, skip=skip, limit=limit) @@ -107,6 +114,10 @@ def filter_extensions_data( raise BadRequest(scim_type="invalidFilter", detail="Unsupported extension search key") req.app.context.logger.debug(f"Searching for groups with {filter.attr} {filter.op} {repr(filter.val)}") + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None + assert skip is not None + assert limit is not None groups, count = req.context.groupdb.get_groups_by_property( key=filter.attr, value=filter.val, skip=skip, limit=limit ) diff --git a/src/eduid/scimapi/routers/utils/invites.py b/src/eduid/scimapi/routers/utils/invites.py index 7fc405f23..1eb412074 100644 --- a/src/eduid/scimapi/routers/utils/invites.py +++ b/src/eduid/scimapi/routers/utils/invites.py @@ -14,6 +14,7 @@ from eduid.common.utils import get_short_hash, make_etag from eduid.queue.db import QueueItem, SenderInfo from eduid.queue.db.message import EduidInviteEmail +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.search import SearchFilter from eduid.scimapi.utils import get_unique_hash @@ -27,6 +28,8 @@ def create_signup_invite( req: ContextRequest, create_request: InviteCreateRequest, db_invite: ScimApiInvite ) -> SignupInvite: + assert isinstance(req.context, ScimApiContext) + assert req.context.data_owner is not None invite_reference = SCIMReference(data_owner=req.context.data_owner, scim_id=db_invite.scim_id) if create_request.nutid_invite_v1.send_email is False: @@ -64,7 +67,9 @@ def create_signup_invite( return signup_invite -def db_invite_to_response(req: Request, resp: Response, db_invite: ScimApiInvite, signup_invite: SignupInvite): +def db_invite_to_response( + req: Request, resp: Response, db_invite: ScimApiInvite, signup_invite: SignupInvite +) -> InviteResponse: location = req.app.context.url_for("Invites", db_invite.scim_id) meta = Meta( location=location, @@ -112,11 +117,13 @@ def db_invite_to_response(req: Request, resp: Response, db_invite: ScimApiInvite return scim_invite -def create_signup_ref(req: ContextRequest, db_invite: ScimApiInvite): +def create_signup_ref(req: ContextRequest, db_invite: ScimApiInvite) -> SCIMReference: + assert isinstance(req.context, ScimApiContext) + assert req.context.data_owner is not None return SCIMReference(data_owner=req.context.data_owner, scim_id=db_invite.scim_id) -def send_invite_mail(req: ContextRequest, signup_invite: SignupInvite): +def send_invite_mail(req: ContextRequest, signup_invite: SignupInvite) -> bool: try: email = [email.email for email in signup_invite.mail_addresses if email.primary][0] except IndexError: @@ -164,6 +171,8 @@ def save_invite( signup_invite_is_in_database: bool, ) -> None: try: + assert isinstance(req.context, ScimApiContext) + assert req.context.invitedb is not None req.context.invitedb.save(db_invite) except DuplicateKeyError as e: assert e.details is not None # please mypy @@ -187,6 +196,8 @@ def filter_lastmodified( raise BadRequest(scim_type="invalidFilter", detail="Unsupported operator") if not isinstance(filter.val, str): raise BadRequest(scim_type="invalidFilter", detail="Invalid datetime") + assert isinstance(req.context, ScimApiContext) + assert req.context.invitedb is not None return req.context.invitedb.get_invites_by_last_modified( operator=filter.op, value=datetime.fromisoformat(filter.val), skip=skip, limit=limit ) diff --git a/src/eduid/scimapi/routers/utils/status.py b/src/eduid/scimapi/routers/utils/status.py index 804026152..3815f9f5f 100644 --- a/src/eduid/scimapi/routers/utils/status.py +++ b/src/eduid/scimapi/routers/utils/status.py @@ -4,7 +4,7 @@ __author__ = "lundberg" -def check_mongo(req: ContextRequest, default_data_owner: str): +def check_mongo(req: ContextRequest, default_data_owner: str) -> bool | None: user_db = req.app.context.get_userdb(default_data_owner) group_db = req.app.context.get_groupdb(default_data_owner) try: @@ -18,7 +18,7 @@ def check_mongo(req: ContextRequest, default_data_owner: str): return False -def check_neo4j(req: ContextRequest, default_data_owner: str): +def check_neo4j(req: ContextRequest, default_data_owner: str) -> bool | None: group_db = req.app.context.get_groupdb(default_data_owner) try: # TODO: Implement is_healthy, check if there is a better way for neo4j diff --git a/src/eduid/scimapi/routers/utils/users.py b/src/eduid/scimapi/routers/utils/users.py index c3a68422c..24006f026 100644 --- a/src/eduid/scimapi/routers/utils/users.py +++ b/src/eduid/scimapi/routers/utils/users.py @@ -11,6 +11,7 @@ from eduid.common.models.scim_base import Email, Meta, Name, PhoneNumber, SCIMResourceType, SCIMSchema, SearchRequest from eduid.common.models.scim_user import Group, LinkedAccount, NutidUserExtensionV1, Profile, UserResponse from eduid.common.utils import make_etag +from eduid.scimapi.context_request import ScimApiContext from eduid.scimapi.exceptions import BadRequest from eduid.scimapi.routers.utils.events import add_api_event from eduid.scimapi.search import SearchFilter @@ -21,6 +22,8 @@ def get_user_groups(req: ContextRequest, db_user: ScimApiUser) -> list[Group]: """Return the groups for a user formatted as SCIM search sub-resources""" + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None user_groups = req.context.groupdb.get_groups_for_user_identifer(db_user.scim_id) groups = [] for group in user_groups: @@ -33,9 +36,13 @@ def get_user_groups(req: ContextRequest, db_user: ScimApiUser) -> list[Group]: def remove_user_from_all_groups(req: ContextRequest, db_user: ScimApiUser) -> None: """Remove a user from all groups""" # Remove user from groups + assert isinstance(req.context, ScimApiContext) + assert req.context.groupdb is not None + assert req.context.data_owner is not None for member_group in req.context.groupdb.get_groups_for_user_identifer(db_user.scim_id): # we need to get the full group object to get all the members group = req.context.groupdb.get_group_by_scim_id(str(member_group.scim_id)) + assert group is not None for member in group.graph.members.copy(): if member.identifier == str(db_user.scim_id): req.app.context.logger.debug( @@ -119,6 +126,8 @@ def db_user_to_response(req: ContextRequest, resp: Response, db_user: ScimApiUse def save_user(req: ContextRequest, db_user: ScimApiUser) -> None: try: + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None req.context.userdb.save(db_user) except DuplicateKeyError as e: assert e.details is not None # please mypy @@ -127,7 +136,7 @@ def save_user(req: ContextRequest, db_user: ScimApiUser) -> None: raise BadRequest(detail="Duplicated key error") -def acceptable_linked_accounts(value: list[LinkedAccount], environment: EduidEnvironment): +def acceptable_linked_accounts(value: list[LinkedAccount], environment: EduidEnvironment) -> bool: """ Setting linked_accounts through SCIM with limited issuer and value. If we need to support stepup with someone other than eduID this needs to change. @@ -161,6 +170,8 @@ def filter_externalid(req: ContextRequest, search_filter: SearchFilter) -> list[ if not isinstance(search_filter.val, str): raise BadRequest(scim_type="invalidFilter", detail="Invalid externalId") + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None user = req.context.userdb.get_user_by_external_id(search_filter.val) if not user: @@ -176,6 +187,8 @@ def filter_lastmodified( raise BadRequest(scim_type="invalidFilter", detail="Unsupported operator") if not isinstance(search_filter.val, str): raise BadRequest(scim_type="invalidFilter", detail="Invalid datetime") + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None return req.context.userdb.get_users_by_last_modified( operator=search_filter.op, value=datetime.fromisoformat(search_filter.val), skip=skip, limit=limit ) @@ -195,6 +208,8 @@ def filter_profile_data( req.app.context.logger.debug( f"Searching for users with {search_filter.attr} {search_filter.op} {repr(search_filter.val)}" ) + assert isinstance(req.context, ScimApiContext) + assert req.context.userdb is not None users, count = req.context.userdb.get_user_by_profile_data( profile=profile, operator=search_filter.op, key=key, value=search_filter.val, skip=skip, limit=limit ) diff --git a/src/eduid/scimapi/testing.py b/src/eduid/scimapi/testing.py index 2b2e11212..ede63d7c7 100644 --- a/src/eduid/scimapi/testing.py +++ b/src/eduid/scimapi/testing.py @@ -172,7 +172,7 @@ def add_owner_to_group(self, group_identifier: str, user_identifier: str) -> Sci self.groupdb.save(group) return self.groupdb.get_group_by_scim_id(scim_id=group_identifier) - def tearDown(self): + def tearDown(self) -> None: super().tearDown() if self.userdb: self.userdb._drop_whole_collection() diff --git a/src/eduid/scimapi/tests/test_authn.py b/src/eduid/scimapi/tests/test_authn.py index f0f74fdbc..a1e033358 100644 --- a/src/eduid/scimapi/tests/test_authn.py +++ b/src/eduid/scimapi/tests/test_authn.py @@ -150,7 +150,7 @@ def test_requested_access_canonicalization(self) -> None: assert token.scopes == {domain} assert token.requested_access == [RequestedAccess(type=_requested_access_type, scope=domain)] - def test_invalid_requested_access_scope(self): + def test_invalid_requested_access_scope(self) -> None: # test too short domain name with pytest.raises(ValueError) as exc_info: AuthnBearerToken( @@ -160,6 +160,7 @@ def test_invalid_requested_access_scope(self): requested_access=[RequestedAccess(type=self.config.requested_access_type, scope=".se")], auth_source=AuthSource.CONFIG, ) + assert isinstance(exc_info.value, ValidationError) assert normalised_data(exc_info.value.errors(), exclude_keys=["url"]) == normalised_data( [ { @@ -172,7 +173,7 @@ def test_invalid_requested_access_scope(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_requested_access_not_for_us(self): + def test_requested_access_not_for_us(self) -> None: """Test with a 'requested_access' field with the wrong 'type' value.""" domain = "eduid.se" # test no canonization @@ -184,6 +185,7 @@ def test_requested_access_not_for_us(self): requested_access=[RequestedAccess(type="someone else", scope=domain)], auth_source=AuthSource.CONFIG, ) + assert isinstance(exc_info.value, ValidationError) assert normalised_data(exc_info.value.errors(), exclude_keys=["url"]) == normalised_data( [ { @@ -196,7 +198,7 @@ def test_requested_access_not_for_us(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_regular_token(self): + def test_regular_token(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = "eduid.se" claims = { @@ -212,7 +214,7 @@ def test_regular_token(self): assert token.auth_source == AuthSource.CONFIG assert token.requested_access == [RequestedAccess(type="scim-api", scope=ScopeName("eduid.se"))] - def test_multiple_access_requests_including_us(self): + def test_multiple_access_requests_including_us(self) -> None: """Test when requested access has multiple requests. Only keep the request for the current resource.""" domain = "eduid.se" token = AuthnBearerToken( @@ -235,7 +237,7 @@ def test_multiple_access_requests_including_us(self): RequestedAccess(type="scim-api", scope=ScopeName(domain)), ] - def test_interaction_token(self): + def test_interaction_token(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = "eduid.se" claims = { @@ -251,7 +253,7 @@ def test_interaction_token(self): assert token.auth_source == AuthSource.INTERACTION assert token.requested_access == [RequestedAccess(type="scim-api", scope=ScopeName("eduid.se"))] - def test_regular_token_with_canonisation(self): + def test_regular_token_with_canonisation(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = "eduid.se" domain_alias = "eduid.example.edu" @@ -261,7 +263,7 @@ def test_regular_token_with_canonisation(self): token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() == domain - def test_interaction_token_with_canonisation(self): + def test_interaction_token_with_canonisation(self) -> None: """Test the normal case. Login with access granted based on the single scope in the request.""" domain = DataOwnerName("eduid.se") domain_alias = ScopeName("eduid.example.edu") @@ -271,7 +273,7 @@ def test_interaction_token_with_canonisation(self): token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() == domain - def test_regular_token_upper_case(self): + def test_regular_token_upper_case(self) -> None: """ Test the normal case. Login with access granted based on the single scope in the request. Scope provided in upper-case in the request. @@ -283,21 +285,21 @@ def test_regular_token_upper_case(self): assert token.scopes == {domain} assert token.get_data_owner() == domain - def test_unknown_scope(self): + def test_unknown_scope(self) -> None: """Test login with a scope that has no data owner in the configuration.""" domain = "example.org" claims = {"version": 1, "scopes": [domain], "auth_source": "config"} token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() is None - def test_interaction_token_unknown_scope(self): + def test_interaction_token_unknown_scope(self) -> None: """Test login with a scope that has no data owner in the configuration.""" domain = "example.org" claims = {"version": 1, "saml_eppn": f"eppn{domain}", "auth_source": "interaction"} token = AuthnBearerToken(config=self.config, **claims) assert token.get_data_owner() is None - def test_regular_token_multiple_scopes(self): + def test_regular_token_multiple_scopes(self) -> None: """Test the normal case. Login with access granted based on the scope in the request that has a data owner in configuration (one extra scope provided in the request, named 'aaa' so it is checked first - and skipped). """ @@ -432,17 +434,17 @@ def _make_bearer_token(self, claims: Mapping[str, Any]) -> str: token.make_signed_token(jwk) return token.serialize() - def test_get_user_no_authn(self): + def test_get_user_no_authn(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self._get_user_from_api(db_user) self._assertScimError(response.json(), status=401, detail="No authentication header found") - def test_get_user_bogus_token(self): + def test_get_user_bogus_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self._get_user_from_api(db_user, bearer_token="not a jws token") self._assertScimError(response.json(), status=401, detail="Bearer token error") - def test_get_user_untrusted_token(self): + def test_get_user_untrusted_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self._get_user_from_api( @@ -457,7 +459,7 @@ def test_get_user_untrusted_token(self): self._assertScimError(response.json(), status=401, detail="Bearer token error") - def test_get_user_correct_token(self): + def test_get_user_correct_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) claims = {"scopes": ["eduid.se"], "version": 1, "auth_source": "config"} @@ -472,7 +474,7 @@ def test_get_user_correct_token(self): } self._assertUserUpdateSuccess(_req, response, db_user) - def test_get_user_interaction_token(self): + def test_get_user_interaction_token(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) db_group = self.add_group_with_member( group_identifier=str(uuid4()), @@ -480,6 +482,7 @@ def test_get_user_interaction_token(self): user_identifier=str(db_user.scim_id), ) + assert self.groupdb is not None claims = { "saml_eppn": "eppn@eduid.se", "version": 1, @@ -505,7 +508,7 @@ def test_get_user_interaction_token(self): } self._assertUserUpdateSuccess(_req, response, db_user) - def test_get_user_data_owner_not_configured(self): + def test_get_user_data_owner_not_configured(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) claims = {"scopes": ["not_configured.se"], "version": 1, "auth_source": "config"} token = self._make_bearer_token(claims=claims) diff --git a/src/eduid/scimapi/tests/test_context.py b/src/eduid/scimapi/tests/test_context.py index 00ac60835..d88ce8f5e 100644 --- a/src/eduid/scimapi/tests/test_context.py +++ b/src/eduid/scimapi/tests/test_context.py @@ -6,12 +6,12 @@ class TestContext(ScimApiTestCase): - def test_init(self): + def test_init(self) -> None: config = load_config(typ=ScimApiConfig, app_name="scimapi", ns="api", test_config=self.test_config) ctx = Context(config=config) self.assertEqual(ctx.base_url, "http://localhost:8000") - def test_load_many_data_owners(self): + def test_load_many_data_owners(self) -> None: # Add 99 more data owners to the config for i in range(99): self.test_config["data_owners"][f"owner{i}"] = {"db_name": f"owner_{i}"} diff --git a/src/eduid/scimapi/tests/test_groupdb.py b/src/eduid/scimapi/tests/test_groupdb.py index 6a6ce1f36..0b6c8c281 100644 --- a/src/eduid/scimapi/tests/test_groupdb.py +++ b/src/eduid/scimapi/tests/test_groupdb.py @@ -23,8 +23,9 @@ def setUp(self) -> None: for i in range(9): self.add_group(uuid4(), f"Test Group-{i}") - def tearDown(self): + def tearDown(self) -> None: super().tearDown() + assert self.groupdb is not None self.groupdb._drop_whole_collection() def add_group(self, scim_id: UUID, display_name: str, extensions: GroupExtensions | None = None) -> ScimApiGroup: @@ -37,22 +38,27 @@ def add_group(self, scim_id: UUID, display_name: str, extensions: GroupExtension logger.info(f"TEST saved group {group}") return group - def test_full_search(self): + def test_full_search(self) -> None: + assert self.groupdb is not None groups = self.groupdb.get_groups() self.assertEqual(len(groups), 9) - def test_documents_and_count_first_page(self): + def test_documents_and_count_first_page(self) -> None: + assert self.groupdb is not None groups, count = self.groupdb._get_documents_and_count_by_filter(spec={}, limit=3) - [logger.info(f"Group {x}") for x in groups] + for x in groups: + logger.info(f"Group {x}") self.assertEqual(len(groups), 3) self.assertEqual(count, 9) - def test_documents_and_count_last_page(self): + def test_documents_and_count_last_page(self) -> None: + assert self.groupdb is not None groups, count = self.groupdb._get_documents_and_count_by_filter(spec={}, skip=6, limit=3) self.assertEqual(len(groups), 3) self.assertEqual(count, 9) - def test_documents_and_count_partial_last_page(self): + def test_documents_and_count_partial_last_page(self) -> None: + assert self.groupdb is not None groups, count = self.groupdb._get_documents_and_count_by_filter(spec={}, skip=8, limit=3) self.assertEqual(len(groups), 1) self.assertEqual(count, 9) diff --git a/src/eduid/scimapi/tests/test_login.py b/src/eduid/scimapi/tests/test_login.py index c5167836f..162cd76e0 100644 --- a/src/eduid/scimapi/tests/test_login.py +++ b/src/eduid/scimapi/tests/test_login.py @@ -15,12 +15,12 @@ def _get_config(self) -> dict: config["login_enabled"] = True return config - def test_get_token(self): - response = self.client.post(url="/login", data=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) + def test_get_token(self) -> None: + response = self.client.post(url="/login", content=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) self._assertResponse(response) - def test_use_token(self): - response = self.client.post(url="/login", data=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) + def test_use_token(self) -> None: + response = self.client.post(url="/login", content=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) token = response.headers.get("Authorization") headers = { "Content-Type": "application/scim+json", @@ -36,6 +36,6 @@ class TestLoginResourceNotEnabled(ScimApiTestCase): def setUp(self) -> None: super().setUp() - def test_get_token(self): - response = self.client.post(url="/login", data=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) + def test_get_token(self) -> None: + response = self.client.post(url="/login", content=json.dumps({"data_owner": "eduid.se"}), headers=self.headers) assert response.status_code == 404 diff --git a/src/eduid/scimapi/tests/test_notifications.py b/src/eduid/scimapi/tests/test_notifications.py index fadf9d5e5..d24741ebf 100644 --- a/src/eduid/scimapi/tests/test_notifications.py +++ b/src/eduid/scimapi/tests/test_notifications.py @@ -19,7 +19,7 @@ def _get_config(self) -> dict[str, Any]: config["data_owners"]["eduid.se"]["notify"] = ["https://example.org/notify"] return config - def test_create_user_notification(self): + def test_create_user_notification(self) -> None: assert len(self._get_notifications()) == 0 req = {"schemas": [SCIMSchema.CORE_20_USER.value], "externalId": "test-id-1"} @@ -28,7 +28,7 @@ def test_create_user_notification(self): assert len(self._get_notifications()) == 1 - def test_create_group_notification(self): + def test_create_group_notification(self) -> None: assert len(self._get_notifications()) == 0 req = {"schemas": [SCIMSchema.CORE_20_GROUP.value], "externalId": "test-id-1", "displayName": "Test Group"} @@ -37,7 +37,7 @@ def test_create_group_notification(self): assert len(self._get_notifications()) == 1 - def test_create_event_notification(self): + def test_create_event_notification(self) -> None: assert len(self._get_notifications()) == 0 user = self.add_user(identifier=str(uuid4()), external_id="test@example.org") diff --git a/src/eduid/scimapi/tests/test_profile.py b/src/eduid/scimapi/tests/test_profile.py index a80515756..e058eafbe 100644 --- a/src/eduid/scimapi/tests/test_profile.py +++ b/src/eduid/scimapi/tests/test_profile.py @@ -4,7 +4,7 @@ class TestProfile(TestCase): - def test_parse(self): + def test_parse(self) -> None: displayname = "Musse Pigg" data = {"profiles": {"student": {"attributes": {"displayName": displayname}}}} extension = NutidUserExtensionV1.model_validate(data) diff --git a/src/eduid/scimapi/tests/test_scimbase.py b/src/eduid/scimapi/tests/test_scimbase.py index 6ca4a62ab..68b78e8d7 100644 --- a/src/eduid/scimapi/tests/test_scimbase.py +++ b/src/eduid/scimapi/tests/test_scimbase.py @@ -34,7 +34,7 @@ def test_base_response(self) -> None: loaded_base = BaseResponse.parse_raw(base_dump) assert normalised_data(base.dict()) == normalised_data(loaded_base.dict()) - def test_hashable_subresources(self): + def test_hashable_subresources(self) -> None: a = { "$ref": "http://localhost:8000/Users/78130160-b63d-4303-99cd-73767e2a999f", "display": "Test User 1 (updated)", diff --git a/src/eduid/scimapi/tests/test_scimevent.py b/src/eduid/scimapi/tests/test_scimevent.py index 931ddc9dc..8b05571c7 100644 --- a/src/eduid/scimapi/tests/test_scimevent.py +++ b/src/eduid/scimapi/tests/test_scimevent.py @@ -26,8 +26,9 @@ class TestEventResource(ScimApiTestCase): def setUp(self) -> None: super().setUp() - def tearDown(self): + def tearDown(self) -> None: super().tearDown() + assert self.eventdb self.eventdb._drop_whole_collection() def _create_event(self, event: dict[str, Any], expect_success: bool = True) -> EventApiResult: @@ -65,7 +66,7 @@ def _assertEventUpdateSuccess(self, req: Mapping, response: Response, event: Sci self._assertScimResponseProperties(response, resource=event, expected_schemas=expected_schemas) - def test_create_event(self): + def test_create_event(self) -> None: user = self.add_user(identifier=str(uuid4()), external_id="test@example.org") event = { "resource": { @@ -80,6 +81,7 @@ def test_create_event(self): result = self._create_event(event=event) # check that the create resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, scim_id=user.scim_id) assert len(events) == 1 db_event = events[0] @@ -92,9 +94,10 @@ def test_create_event(self): assert db_event.data == event["data"] # Verify what is returned in the parsed_response assert result.parsed_response.id == db_event.scim_id + assert result.request self._assertEventUpdateSuccess(req=result.request, response=result.response, event=db_event) - def test_create_and_fetch_event(self): + def test_create_and_fetch_event(self) -> None: user = self.add_user(identifier=str(uuid4()), external_id="test@example.org") event = { "resource": { @@ -109,6 +112,7 @@ def test_create_and_fetch_event(self): created = self._create_event(event=event) # check that the creation resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, scim_id=user.scim_id) assert len(events) == 1 db_event = events[0] diff --git a/src/eduid/scimapi/tests/test_scimgroup.py b/src/eduid/scimapi/tests/test_scimgroup.py index 221e86b2c..5fada4dab 100644 --- a/src/eduid/scimapi/tests/test_scimgroup.py +++ b/src/eduid/scimapi/tests/test_scimgroup.py @@ -56,8 +56,9 @@ def setUp(self) -> None: super().setUp() self.groupdb = self.context.get_groupdb(DataOwnerName("eduid.se")) - def tearDown(self): + def tearDown(self) -> None: super().tearDown() + assert self.groupdb self.groupdb._drop_whole_collection() def add_group(self, scim_id: UUID, display_name: str, extensions: GroupExtensions | None = None) -> ScimApiGroup: @@ -182,12 +183,13 @@ def _assertGroupUpdateSuccess(self, req: Mapping, response: Response, group: Sci class TestGroupResource_GET(TestGroupResource): - def test_get_groups(self): + def test_get_groups(self) -> None: for i in range(9): self.add_group(uuid4(), f"Test Group {i}") response = self.client.get(url="/Groups", headers=self.headers) self.assertEqual([SCIMSchema.API_MESSAGES_20_LIST_RESPONSE.value], response.json().get("schemas")) resources = response.json().get("Resources") + assert self.groupdb expected_num_resources = self.groupdb.graphdb.db.count_nodes() self.assertEqual( expected_num_resources, @@ -200,18 +202,18 @@ def test_get_groups(self): f"Response totalResults does not match number of groups in the database: {expected_num_resources}", ) - def test_get_group(self): + def test_get_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") response = self.client.get(url=f"/Groups/{db_group.scim_id}", headers=self.headers) self._assertGroupUpdateSuccess({"members": []}, response, db_group) - def test_get_group_not_found(self): + def test_get_group_not_found(self) -> None: response = self.client.get(url=f"/Groups/{uuid4()}", headers=self.headers) self._assertScimError(response.json(), status=404, detail="Group not found") class TestGroupResource_POST(TestGroupResource): - def test_create_group(self): + def test_create_group(self) -> None: req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], "displayName": "Test Group 1", @@ -221,6 +223,8 @@ def test_create_group(self): response = self.client.post(url="/Groups/", json=req, headers=self.headers) # Load the created group from the database, ensuring it was in fact created + assert self.groupdb + assert isinstance(req["displayName"], str) # please mypy _groups, _count = self.groupdb.get_groups_by_property("display_name", req["displayName"]) self.assertEqual(1, _count, "More or less than one group found in the database after create") db_group = _groups[0] @@ -228,13 +232,14 @@ def test_create_group(self): self._assertGroupUpdateSuccess(req, response, db_group) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.GROUP, db_group.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.CREATED.value - def test_schema_violation(self): + def test_schema_violation(self) -> None: # request missing displayName req = {"schemas": [SCIMSchema.CORE_20_GROUP.value], "members": []} response = self.client.post(url="/Groups/", json=req, headers=self.headers) @@ -253,7 +258,7 @@ def test_schema_violation(self): class TestGroupResource_PUT(TestGroupResource): - def test_update_group(self): + def test_update_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") subgroup = self.add_group(uuid4(), "Test Group 2") user = self.add_user(identifier=str(uuid4()), external_id="not-used") @@ -280,7 +285,7 @@ def test_update_group(self): self._assertGroupUpdateSuccess(req, response, db_group) - def test_update_existing_group(self): + def test_update_existing_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") subgroup = self.add_group(uuid4(), "Test Group 2") user = self.add_user(identifier=str(uuid4()), external_id="not-used") @@ -319,13 +324,14 @@ def test_update_existing_group(self): self._assertGroupUpdateSuccess(req, response, db_group) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.GROUP, db_group.scim_id) assert len(events) == 2 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.UPDATED.value - def test_add_member_to_existing_group(self): + def test_add_member_to_existing_group(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") user = self.add_user(identifier=str(uuid4()), external_id="not-used") members = [ @@ -361,15 +367,18 @@ def test_add_member_to_existing_group(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertGroupUpdateSuccess(req, response, db_group) - def test_removing_group_member(self): + def test_removing_group_member(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") subgroup = self.add_group(uuid4(), "Test Group 2") db_group = self.add_member(db_group, subgroup, "Test User") user = self.add_user(identifier=str(uuid4()), external_id="not-used") db_group = self.add_member(db_group, user, "Test User") + assert self.groupdb + # Load group to verify it has two members _g1 = self.groupdb.get_group_by_scim_id(str(db_group.scim_id)) + assert _g1 self.assertEqual(2, len(_g1.graph.members), "Group loaded from database does not have two members") self.assertEqual(1, len(_g1.graph.member_users), "Group loaded from database does not have one member user") self.assertEqual(1, len(_g1.graph.member_groups), "Group loaded from database does not have one member group") @@ -396,11 +405,12 @@ def test_removing_group_member(self): # Load group to verify it has one less member now _g2 = self.groupdb.get_group_by_scim_id(str(db_group.scim_id)) + assert _g2 self.assertEqual(1, len(_g2.graph.members), "Group loaded from database does not have two members") self.assertEqual(1, len(_g2.graph.member_users), "Group loaded from database does not have one member user") self.assertEqual(0, len(_g2.graph.member_groups), "Group loaded from database does not have one member group") - def test_update_group_id_mismatch(self): + def test_update_group_id_mismatch(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], @@ -411,7 +421,7 @@ def test_update_group_id_mismatch(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail="Id mismatch") - def test_update_group_not_found(self): + def test_update_group_not_found(self) -> None: req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], "id": str(uuid4()), @@ -421,7 +431,7 @@ def test_update_group_not_found(self): response = self.client.put(url=f'/Groups/{req["id"]}', json=req, headers=self.headers) self._assertScimError(response.json(), status=404, detail="Group not found") - def test_version_mismatch(self): + def test_version_mismatch(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], @@ -432,7 +442,7 @@ def test_version_mismatch(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail="Version mismatch") - def test_update_group_member_does_not_exist(self): + def test_update_group_member_does_not_exist(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") _user_scim_id = str(uuid4()) members = [ @@ -452,7 +462,7 @@ def test_update_group_member_does_not_exist(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail=f"User {_user_scim_id} not found") - def test_update_group_subgroup_does_not_exist(self): + def test_update_group_subgroup_does_not_exist(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") _subgroup_scim_id = str(uuid4()) members = [ @@ -472,7 +482,7 @@ def test_update_group_subgroup_does_not_exist(self): response = self.client.put(url=f"/Groups/{db_group.scim_id}", json=req, headers=self.headers) self._assertScimError(response.json(), detail=f"Group {_subgroup_scim_id} not found") - def test_schema_violation(self): + def test_schema_violation(self) -> None: # request missing displayName req = { "schemas": [SCIMSchema.CORE_20_GROUP.value], @@ -494,20 +504,22 @@ def test_schema_violation(self): class TestGroupResource_DELETE(TestGroupResource): - def test_delete_group(self): + def test_delete_group(self) -> None: group = self.add_group(uuid4(), "Test Group 1") + assert self.groupdb + assert self.eventdb # Verify we can find the group in the database db_group1 = self.groupdb.get_group_by_scim_id(str(group.scim_id)) - self.assertIsNotNone(db_group1) + assert db_group1 is not None self.headers["IF-MATCH"] = make_etag(group.version) response = self.client.delete(url=f"/Groups/{group.scim_id}", headers=self.headers) self.assertEqual(204, response.status_code) # Verify the group is no longer in the database - db_group2 = self.groupdb.get_group_by_scim_id(group.scim_id) - self.assertIsNone(db_group2) + db_group2 = self.groupdb.get_group_by_scim_id(str(group.scim_id)) + assert db_group2 is None # check that the action resulted in an event in the database events = self.eventdb.get_events_by_resource(SCIMResourceType.GROUP, db_group1.scim_id) @@ -516,50 +528,51 @@ def test_delete_group(self): assert event.resource.external_id is None assert event.data["status"] == EventStatus.DELETED.value - def test_version_mismatch(self): + def test_version_mismatch(self) -> None: group = self.add_group(uuid4(), "Test Group 1") self.headers["IF-MATCH"] = make_etag(ObjectId()) response = self.client.delete(url=f"/Groups/{group.scim_id}", headers=self.headers) self._assertScimError(response.json(), detail="Version mismatch") - def test_group_not_found(self): + def test_group_not_found(self) -> None: response = self.client.delete(url=f"/Groups/{uuid4()}", headers=self.headers) self._assertScimError(response.json(), status=404, detail="Group not found") class TestGroupSearchResource(TestGroupResource): - def test_search_group_display_name(self): + def test_search_group_display_name(self) -> None: db_group = self.add_group(uuid4(), "Test Group 1") self.add_group(uuid4(), "Test Group 2") self._perform_search(filter='displayName eq "Test Group 1"', expected_group=db_group) - def test_search_group_display_name_not_found(self): + def test_search_group_display_name_not_found(self) -> None: self._perform_search(filter='displayName eq "Test No Such Group"', expected_total_results=0) - def test_search_group_display_name_bad_operator(self): + def test_search_group_display_name_bad_operator(self) -> None: json = self._perform_search(filter="displayName lt 1", return_json=True) self._assertScimError(json, scim_type="invalidFilter", detail="Unsupported operator") - def test_search_group_display_name_not_string(self): + def test_search_group_display_name_not_string(self) -> None: json = self._perform_search(filter="displayName eq 1", return_json=True) self._assertScimError(json, scim_type="invalidFilter", detail="Invalid displayName") - def test_search_group_unknown_attribute(self): + def test_search_group_unknown_attribute(self) -> None: json = self._perform_search(filter="no_such_attribute lt 1", return_json=True) self._assertScimError(json, scim_type="invalidFilter", detail="Can't filter on attribute no_such_attribute") - def test_search_group_start_index(self): + def test_search_group_start_index(self) -> None: for i in range(9): self.add_group(uuid4(), "Test Group") self._perform_search( filter='displayName eq "Test Group"', start=5, expected_num_resources=5, expected_total_results=9 ) - def test_search_group_count(self): + def test_search_group_count(self) -> None: for i in range(9): self.add_group(uuid4(), "Test Group") + assert self.groupdb groups = self.groupdb.get_groups() self.assertEqual(len(groups), 9) @@ -567,24 +580,24 @@ def test_search_group_count(self): filter='displayName eq "Test Group"', start=1, count=5, expected_num_resources=5, expected_total_results=9 ) - def test_search_group_extension_data_attribute_str(self): + def test_search_group_extension_data_attribute_str(self) -> None: ext = GroupExtensions(data={"some_key": "20072009"}) db_group = self.add_group(uuid4(), "Test Group with extension", extensions=ext) self._perform_search(filter='extensions.data.some_key eq "20072009"', expected_group=db_group) - def test_search_group_extension_data_bad_op(self): + def test_search_group_extension_data_bad_op(self) -> None: json = self._perform_search(filter='extensions.data.some_key XY "20072009"', return_json=True) self._assertScimError(json, detail="Unsupported operator") - def test_search_group_extension_data_invalid_key(self): + def test_search_group_extension_data_invalid_key(self) -> None: json = self._perform_search(filter='extensions.data.some.key eq "20072009"', return_json=True) self._assertScimError(json, detail="Unsupported extension search key") - def test_search_group_extension_data_not_found(self): + def test_search_group_extension_data_not_found(self) -> None: self._perform_search(filter='extensions.data.some_key eq "20072009"', expected_num_resources=0) - def test_search_group_extension_data_attribute_int(self): + def test_search_group_extension_data_attribute_int(self) -> None: ext1 = GroupExtensions(data={"some_key": 20072009}) group = self.add_group(uuid4(), "Test Group with extension", extensions=ext1) @@ -594,7 +607,7 @@ def test_search_group_extension_data_attribute_int(self): self._perform_search(filter="extensions.data.some_key eq 20072009", expected_group=group) - def test_search_group_last_modified(self): + def test_search_group_last_modified(self) -> None: group1 = self.add_group(uuid4(), "Test Group 1") group2 = self.add_group(uuid4(), "Test Group 2") self.assertGreater(group2.last_modified, group1.last_modified) @@ -605,15 +618,15 @@ def test_search_group_last_modified(self): self._perform_search(filter=f'meta.lastModified gt "{group1.last_modified.isoformat()}"', expected_group=group2) - def test_search_group_last_modified_invalid_datetime_1(self): + def test_search_group_last_modified_invalid_datetime_1(self) -> None: json = self._perform_search(filter="meta.lastModified ge 1", return_json=True) self._assertScimError(json, detail="Invalid datetime") - def test_search_group_last_modified_invalid_datetime_2(self): + def test_search_group_last_modified_invalid_datetime_2(self) -> None: json = self._perform_search(filter='meta.lastModified ge "2020-05-12_15:36:99+00:00"', return_json=True) self._assertScimError(json, detail="Invalid datetime") - def test_schema_violation(self): + def test_schema_violation(self) -> None: # request missing filter req = { "schemas": [SCIMSchema.API_MESSAGES_20_SEARCH_REQUEST.value], @@ -634,7 +647,7 @@ def test_schema_violation(self): class TestGroupExtensionData(TestGroupResource): - def test_nutid_extension(self): + def test_nutid_extension(self) -> None: display_name = "Test Group with Nutid extension" nutid_data = {"data": {"testing": "certainly"}} req = { @@ -648,7 +661,9 @@ def test_nutid_extension(self): # Load the newly created group from the database in order to validate the SCIM parsed_response better scim_id = post_resp.json().get("id") self.assertIsNotNone(scim_id, "Group creation parsed_response id not present") + assert self.groupdb db_group = self.groupdb.get_group_by_scim_id(scim_id) + assert db_group expected_schemas = [SCIMSchema.CORE_20_GROUP.value, SCIMSchema.NUTID_GROUP_V1.value] self._assertScimResponseProperties(post_resp, db_group, expected_schemas=expected_schemas) self.assertEqual([], post_resp.json().get("members"), "Group was not expected to have members") @@ -686,7 +701,9 @@ def test_nutid_extension(self): get_resp2 = self.client.get(url=f"/Groups/{scim_id}", headers=self.headers) self.assertEqual(put_resp.json(), get_resp2.json()) + assert self.groupdb db_group = self.groupdb.get_group_by_scim_id(scim_id) + assert db_group self._assertScimResponseProperties(get_resp2, db_group, expected_schemas=expected_schemas) self.assertEqual([], get_resp2.json().get("members"), "Group was not expected to have members") diff --git a/src/eduid/scimapi/tests/test_sciminvite.py b/src/eduid/scimapi/tests/test_sciminvite.py index 71978c0c9..ac3784e74 100644 --- a/src/eduid/scimapi/tests/test_sciminvite.py +++ b/src/eduid/scimapi/tests/test_sciminvite.py @@ -1,7 +1,7 @@ import json import logging import unittest -from collections.abc import Mapping +from collections.abc import Mapping, MutableMapping from copy import copy from dataclasses import asdict from datetime import datetime, timedelta @@ -56,7 +56,7 @@ def setUp(self) -> None: "profiles": {"student": {"attributes": {"displayName": "Test"}}}, } - def test_load_invite(self): + def test_load_invite(self) -> None: invite = ScimApiInvite.from_dict(self.invite_doc1) # test to-dict+from-dict consistency invite2 = ScimApiInvite.from_dict(invite.to_dict()) @@ -64,8 +64,9 @@ def test_load_invite(self): assert asdict(invite) == asdict(invite2) assert invite.to_dict() == invite2.to_dict() - def test_to_sciminvite_response(self): + def test_to_sciminvite_response(self) -> None: db_invite = ScimApiInvite.from_dict(self.invite_doc1) + assert db_invite meta = Meta( location=f"http://example.org/Invites/{db_invite.scim_id}", resource_type=SCIMResourceType.INVITE, @@ -74,6 +75,8 @@ def test_to_sciminvite_response(self): version=db_invite.version, ) + assert db_invite.emails[0].primary is not None + assert db_invite.emails[1].primary is not None signup_invite = SignupInvite( invite_type=InviteType.SCIM, invite_reference=SCIMReference(data_owner="test_data_owner", scim_id=db_invite.scim_id), @@ -344,7 +347,7 @@ def _perform_search( resources = response.json().get("Resources") return resources - def test_create_invite(self): + def test_create_invite(self) -> None: req = { "schemas": [ SCIMSchema.NUTID_INVITE_CORE_V1.value, @@ -382,20 +385,24 @@ def test_create_invite(self): response = self.client.post(url="/Invites/", json=req, headers=self.headers) self._assertResponse(response, status_code=201) + assert self.invitedb db_invite = self.invitedb.get_invite_by_scim_id(response.json().get("id")) + assert db_invite reference = SCIMReference(data_owner=self.data_owner, scim_id=db_invite.scim_id) signup_invite = self.signup_invitedb.get_invite_by_reference(reference) + assert signup_invite self._assertUpdateSuccess(req, response, db_invite, signup_invite) self.assertEqual(1, self.messagedb.db_count()) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.INVITE, db_invite.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.CREATED.value - def test_create_invite_missing_mandatory_attributes(self): + def test_create_invite_missing_mandatory_attributes(self) -> None: req = { "schemas": [ SCIMSchema.NUTID_INVITE_CORE_V1.value, @@ -408,7 +415,7 @@ def test_create_invite_missing_mandatory_attributes(self): }, } - req1 = copy(req) + req1: MutableMapping[str, Any] = copy(req) del req1[SCIMSchema.NUTID_INVITE_V1.value]["inviterName"] response = self.client.post(url="/Invites/", json=req1, headers=self.headers) self._assertScimError( @@ -425,7 +432,7 @@ def test_create_invite_missing_mandatory_attributes(self): exclude_keys=["input", "url"], ) - req2 = copy(req) + req2: MutableMapping[str, Any] = copy(req) del req2[SCIMSchema.NUTID_INVITE_V1.value]["sendEmail"] response = self.client.post(url="/Invites/", json=req2, headers=self.headers) self._assertScimError( @@ -442,7 +449,7 @@ def test_create_invite_missing_mandatory_attributes(self): exclude_keys=["input", "url"], ) - def test_create_invite_do_not_send_email(self): + def test_create_invite_do_not_send_email(self) -> None: req = { "schemas": [ SCIMSchema.NUTID_INVITE_CORE_V1.value, @@ -479,13 +486,16 @@ def test_create_invite_do_not_send_email(self): response = self.client.post(url="/Invites/", json=req, headers=self.headers) self._assertResponse(response, status_code=201) + assert self.invitedb db_invite = self.invitedb.get_invite_by_scim_id(response.json().get("id")) + assert db_invite reference = SCIMReference(data_owner=self.data_owner, scim_id=db_invite.scim_id) signup_invite = self.signup_invitedb.get_invite_by_reference(reference) + assert signup_invite self._assertUpdateSuccess(req, response, db_invite, signup_invite) self.assertEqual(0, self.messagedb.db_count()) - def test_get_invite(self): + def test_get_invite(self) -> None: db_invite = self.add_invite() response = self.client.get(url=f"/Invites/{db_invite.scim_id}", headers=self.headers) expected_schemas = [ @@ -495,7 +505,7 @@ def test_get_invite(self): ] self._assertScimResponseProperties(response, resource=db_invite, expected_schemas=expected_schemas) - def test_update_invite(self): + def test_update_invite(self) -> None: # TODO: For now we only support updating completed db_invite = self.add_invite() invite = self.client.get(url=f"/Invites/{db_invite.scim_id}", headers=self.headers) @@ -508,7 +518,9 @@ def test_update_invite(self): response = self.client.put(url=f"/Invites/{db_invite.scim_id}", json=update_req, headers=self.headers) assert response.status_code == 200 + assert self.invitedb updated_invite = self.invitedb.get_invite_by_scim_id(str(db_invite.scim_id)) + assert updated_invite assert updated_invite.completed is not None expected_schemas = [ @@ -518,21 +530,23 @@ def test_update_invite(self): ] self._assertScimResponseProperties(response, resource=db_invite, expected_schemas=expected_schemas) - def test_delete_invite(self): + def test_delete_invite(self) -> None: db_invite = self.add_invite() self.headers["IF-MATCH"] = make_etag(db_invite.version) self.client.delete(url=f"/Invites/{db_invite.scim_id}", headers=self.headers) reference = SCIMReference(data_owner=self.data_owner, scim_id=db_invite.scim_id) + assert self.invitedb self.assertIsNone(self.invitedb.get_invite_by_scim_id(str(db_invite.scim_id))) self.assertIsNone(self.signup_invitedb.get_invite_by_reference(reference)) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.INVITE, db_invite.scim_id) assert len(events) == 1 event = events[0] assert event.data["status"] == EventStatus.DELETED.value - def test_search_user_last_modified(self): + def test_search_user_last_modified(self) -> None: db_invite1 = self.add_invite() db_invite2 = self.add_invite(data={"invite_code": "another_invite_code"}, update=True) self.assertGreater(db_invite2.last_modified, db_invite1.last_modified) diff --git a/src/eduid/scimapi/tests/test_scimuser.py b/src/eduid/scimapi/tests/test_scimuser.py index 17ac5d915..d65ed0a69 100644 --- a/src/eduid/scimapi/tests/test_scimuser.py +++ b/src/eduid/scimapi/tests/test_scimuser.py @@ -19,7 +19,7 @@ from eduid.common.utils import make_etag from eduid.scimapi.testing import ScimApiTestCase from eduid.scimapi.utils import filter_none -from eduid.userdb.scimapi import EventStatus, ScimApiLinkedAccount +from eduid.userdb.scimapi import EventStatus, ScimApiGroup, ScimApiLinkedAccount from eduid.userdb.scimapi.userdb import ScimApiProfile, ScimApiUser logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def setUp(self) -> None: "profiles": {"student": {"attributes": {"displayName": "Test"}}}, } - def test_load_old_user(self): + def test_load_old_user(self) -> None: user = ScimApiUser.from_dict(self.user_doc1) self.assertEqual(user.profiles["student"].attributes["displayName"], "Test") @@ -57,7 +57,7 @@ def test_load_old_user(self): user2 = ScimApiUser.from_dict(user.to_dict()) self.assertEqual(asdict(user), asdict(user2)) - def test_to_scimuser_doc(self): + def test_to_scimuser_doc(self) -> None: db_user = ScimApiUser.from_dict(self.user_doc1) meta = Meta( location=f"http://example.org/Users/{db_user.scim_id}", @@ -114,7 +114,7 @@ def test_to_scimuser_doc(self): loaded_user_response = json.loads(user_response_json) assert loaded_user_response == expected - def test_to_scimuser_no_external_id(self): + def test_to_scimuser_no_external_id(self) -> None: user_doc2 = { "_id": ObjectId("5e81c5f849ac2cd87580e500"), "scim_id": "a7851d21-eab9-4caa-ba5d-49653d65c452", @@ -167,7 +167,7 @@ def test_to_scimuser_no_external_id(self): loaded_user_response = json.loads(user_response_json) assert loaded_user_response == expected - def test_bson_serialization(self): + def test_bson_serialization(self) -> None: user = ScimApiUser.from_dict(self.user_doc1) x = bson.encode(user.to_dict()) self.assertTrue(x) @@ -278,7 +278,7 @@ def _update_user( class TestUserResource(ScimApiTestUserResourceBase): - def test_get_user(self): + def test_get_user(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) response = self.client.get(url=f"/Users/{db_user.scim_id}", headers=self.headers) @@ -287,11 +287,11 @@ def test_get_user(self): } self._assertUserUpdateSuccess(_req, response, db_user) - def test_create_users_with_no_external_id(self): + def test_create_users_with_no_external_id(self) -> None: self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) - def test_create_user(self): + def test_create_user(self) -> None: req = { "externalId": "test-id-1", "name": {"familyName": "Testsson", "givenName": "Test", "middleName": "Testaren"}, @@ -308,19 +308,23 @@ def test_create_user(self): result = self._create_user(req) # Load the created user from the database, ensuring it was in fact created + assert self.userdb + assert isinstance(req["externalId"], str) db_user = self.userdb.get_user_by_external_id(req["externalId"]) + assert db_user self.assertIsNotNone(db_user, "Created user not found in the database") self._assertUserUpdateSuccess(result.request, result.response, db_user) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, db_user.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.CREATED.value - def test_create_and_update_user(self): + def test_create_and_update_user(self) -> None: """Test that creating a user and then updating it without changes only results in one event""" req = { "externalId": "test-id-1", @@ -335,8 +339,10 @@ def test_create_and_update_user(self): }, } result1 = self._create_user(req) + assert result1.parsed_response # check that the action resulted in an event in the database + assert self.eventdb events1 = self.eventdb.get_events_by_resource(SCIMResourceType.USER, result1.parsed_response.id) assert len(events1) == 1 event = events1[0] @@ -345,6 +351,7 @@ def test_create_and_update_user(self): # Update the user without making any changes result2 = self._update_user(req, result1.parsed_response.id, result1.parsed_response.meta.version) + assert result2.parsed_response # Make sure the version wasn't updated assert result1.parsed_response.meta.version == result2.parsed_response.meta.version # Make sure no additional event was created @@ -352,7 +359,7 @@ def test_create_and_update_user(self): assert len(events2) == 1 assert events1 == events2 - def test_create_user_no_external_id(self): + def test_create_user_no_external_id(self) -> None: req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], SCIMSchema.NUTID_USER_V1.value: { @@ -366,12 +373,14 @@ def test_create_user_no_external_id(self): self._assertResponse(response, status_code=201) # Load the created user from the database, ensuring it was in fact created + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self.assertIsNotNone(db_user, "Created user not found in the database") self._assertUserUpdateSuccess(req, response, db_user) - def test_create_user_duplicated_external_id(self): + def test_create_user_duplicated_external_id(self) -> None: external_id = "test-id-1" # Create an existing user in the db self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) @@ -388,8 +397,11 @@ def test_create_user_duplicated_external_id(self): response.json(), schemas=["urn:ietf:params:scim:api:messages:2.0:Error"], detail="externalID must be unique" ) - def test_update_user(self): - db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) + def test_update_user(self) -> None: + db_user: ScimApiUser | None = self.add_user( + identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile} + ) + assert db_user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], "id": str(db_user.scim_id), @@ -408,17 +420,20 @@ def test_update_user(self): self.headers["IF-MATCH"] = make_etag(db_user.version) response = self.client.put(url=f"/Users/{db_user.scim_id}", json=req, headers=self.headers) self._assertResponse(response) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self._assertUserUpdateSuccess(req, response, db_user) # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, db_user.scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == req["externalId"] assert event.data["status"] == EventStatus.UPDATED.value - def test_update_user_change_properties(self): + def test_update_user_change_properties(self) -> None: # Create the user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], @@ -460,11 +475,14 @@ def test_update_user_change_properties(self): response = self.client.put(url=f'/Users/{create_response.json()["id"]}', json=req, headers=self.headers) self._assertResponse(response) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self._assertUserUpdateSuccess(req, response, db_user) - def test_update_user_set_external_id(self): - db_user = self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) + def test_update_user_set_external_id(self) -> None: + db_user: ScimApiUser | None = self.add_user(identifier=str(uuid4()), profiles={"test": self.test_profile}) + assert db_user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], "id": str(db_user.scim_id), @@ -479,10 +497,12 @@ def test_update_user_set_external_id(self): self.headers["IF-MATCH"] = make_etag(db_user.version) response = self.client.put(url=f"/Users/{db_user.scim_id}", json=req, headers=self.headers) self._assertResponse(response) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(response.json()["id"]) + assert db_user self._assertUserUpdateSuccess(req, response, db_user) - def test_update_user_duplicated_external_id(self): + def test_update_user_duplicated_external_id(self) -> None: external_id = "test-id-1" # Create two existing users with different external_id self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) @@ -504,40 +524,52 @@ def test_update_user_duplicated_external_id(self): response.json(), schemas=["urn:ietf:params:scim:api:messages:2.0:Error"], detail="externalID must be unique" ) - def test_delete_user(self): + def test_delete_user(self) -> None: external_id = "test-id-1" - db_user = self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) + db_user: ScimApiUser | None = self.add_user( + identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile} + ) + assert db_user user_scim_id = db_user.scim_id self.headers["IF-MATCH"] = make_etag(db_user.version) response = self.client.delete(url=f"/Users/{db_user.scim_id}", headers=self.headers) self._assertResponse(response, status_code=204) # No content - db_user = self.userdb.get_user_by_scim_id(user_scim_id) + assert self.userdb + db_user = self.userdb.get_user_by_scim_id(str(user_scim_id)) assert db_user is None # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, user_scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == external_id assert event.data["status"] == EventStatus.DELETED.value - def test_delete_user_with_groups(self): + def test_delete_user_with_groups(self) -> None: external_id = "test-id-1" - db_user = self.add_user(identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile}) + db_user: ScimApiUser | None = self.add_user( + identifier=str(uuid4()), external_id=external_id, profiles={"test": self.test_profile} + ) + assert db_user user_scim_id = db_user.scim_id - group1 = self.add_group_with_member( + group1: ScimApiGroup | None = self.add_group_with_member( group_identifier=str(uuid4()), display_name="Group 1", user_identifier=str(user_scim_id) ) + assert group1 group1 = self.add_owner_to_group(group_identifier=str(group1.scim_id), user_identifier=str(user_scim_id)) + assert group1 extra_user = self.add_user( identifier=str(uuid4()), external_id="other external id", profiles={"test": self.test_profile} ) - group2 = self.add_group_with_member( + group2: ScimApiGroup | None = self.add_group_with_member( group_identifier=str(uuid4()), display_name="Group 2", user_identifier=str(user_scim_id) ) + assert group2 group2 = self.add_member_to_group(group_identifier=str(group2.scim_id), user_identifier=str(extra_user.scim_id)) + assert group2 assert len(group1.members) == 1 assert len(group1.owners) == 1 @@ -547,28 +579,33 @@ def test_delete_user_with_groups(self): response = self.client.delete(url=f"/Users/{db_user.scim_id}", headers=self.headers) self._assertResponse(response, status_code=204) # No content - db_user = self.userdb.get_user_by_scim_id(user_scim_id) + assert self.userdb + db_user = self.userdb.get_user_by_scim_id(str(user_scim_id)) assert db_user is None + assert self.groupdb group1 = self.groupdb.get_group_by_scim_id(str(group1.scim_id)) + assert group1 assert len(group1.graph.members) == 0 assert len(group1.graph.owners) == 0 group2 = self.groupdb.get_group_by_scim_id(str(group2.scim_id)) + assert group2 assert len(group2.graph.members) == 1 # check that the action resulted in an event in the database + assert self.eventdb events = self.eventdb.get_events_by_resource(SCIMResourceType.USER, user_scim_id) assert len(events) == 1 event = events[0] assert event.resource.external_id == external_id assert event.data["status"] == EventStatus.DELETED.value - def test_search_user_external_id(self): + def test_search_user_external_id(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) self.add_user(identifier=str(uuid4()), external_id="test-id-2", profiles={"test": self.test_profile}) self._perform_search(search_filter=f'externalId eq "{db_user.external_id}"', expected_user=db_user) - def test_search_user_last_modified(self): + def test_search_user_last_modified(self) -> None: db_user1 = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) db_user2 = self.add_user(identifier=str(uuid4()), external_id="test-id-2", profiles={"test": self.test_profile}) self.assertGreater(db_user2.last_modified, db_user1.last_modified) @@ -583,14 +620,15 @@ def test_search_user_last_modified(self): search_filter=f'meta.lastModified gt "{db_user1.last_modified.isoformat()}"', expected_user=db_user2 ) - def test_search_user_profile_data(self): + def test_search_user_profile_data(self) -> None: db_user = self.add_user(identifier=str(uuid4()), external_id="test-id-1", profiles={"test": self.test_profile}) self.add_user(identifier=str(uuid4()), external_id="test-id-2", profiles={"test": self.test_profile2}) self._perform_search(search_filter='profiles.test.data.test_key eq "test_value"', expected_user=db_user) - def test_search_user_start_index(self): + def test_search_user_start_index(self) -> None: for i in range(9): self.add_user(identifier=str(uuid4()), external_id=f"test-id-{i}", profiles={"test": self.test_profile}) + assert self.userdb self.assertEqual(9, self.userdb.db_count()) last_modified = datetime.utcnow() - timedelta(hours=1) self._perform_search( @@ -601,9 +639,10 @@ def test_search_user_start_index(self): expected_total_results=9, ) - def test_search_user_count(self): + def test_search_user_count(self) -> None: for i in range(9): self.add_user(identifier=str(uuid4()), external_id=f"test-id-{i}", profiles={"test": self.test_profile}) + assert self.userdb self.assertEqual(9, self.userdb.db_count()) last_modified = datetime.utcnow() - timedelta(hours=1) self._perform_search( @@ -614,9 +653,10 @@ def test_search_user_count(self): expected_total_results=9, ) - def test_search_user_start_index_and_count(self): + def test_search_user_start_index_and_count(self) -> None: for i in range(9): self.add_user(identifier=str(uuid4()), external_id=f"test-id-{i}", profiles={"test": self.test_profile}) + assert self.userdb self.assertEqual(9, self.userdb.db_count()) last_modified = datetime.utcnow() - timedelta(hours=1) self._perform_search( @@ -628,7 +668,7 @@ def test_search_user_start_index_and_count(self): expected_total_results=9, ) - def test_create_and_update_user_with_linked_accounts(self): + def test_create_and_update_user_with_linked_accounts(self) -> None: """Test that creating a user and then updating it without changes only results in one event""" account = LinkedAccount(issuer="eduid.se", value="test@dev.eduid.se") _db_account = ScimApiLinkedAccount(issuer=account.issuer, value=account.value, parameters=account.parameters) @@ -636,6 +676,9 @@ def test_create_and_update_user_with_linked_accounts(self): SCIMSchema.NUTID_USER_V1.value: {"profiles": {}, "linked_accounts": [account.to_dict()]}, } result1 = self._create_user(req) + assert result1.parsed_response + + assert self.userdb self._assertResponse(result1.response, status_code=201) db_user = self.userdb.get_user_by_scim_id(str(result1.parsed_response.id)) @@ -652,6 +695,7 @@ def test_create_and_update_user_with_linked_accounts(self): # Update the user result2 = self._update_user(req, result1.parsed_response.id, result1.parsed_response.meta.version) + assert result2.parsed_response self._assertResponse(result2.response, status_code=200) db_user = self.userdb.get_user_by_scim_id(str(result2.parsed_response.id)) assert db_user @@ -663,7 +707,7 @@ def test_create_and_update_user_with_linked_accounts(self): # Verify the updated account made it into the database assert db_user.linked_accounts == [_db_account] - def test_create_user_with_invalid_linked_accounts_issuer(self): + def test_create_user_with_invalid_linked_accounts_issuer(self) -> None: """Test that creating a user with an invalid issuer and valid value fails""" account = LinkedAccount(issuer="NOT-eduid.se", value="test@dev.eduid.se") req = { @@ -674,7 +718,7 @@ def test_create_user_with_invalid_linked_accounts_issuer(self): result1 = self._create_user(req, expect_success=False) self._assertScimError(json=result1.response.json(), detail="Invalid nutid linked_accounts") - def test_create_user_with_invalid_linked_accounts_value(self): + def test_create_user_with_invalid_linked_accounts_value(self) -> None: """Test that creating a user with valid issuer and invalid value fails""" account = LinkedAccount(issuer="eduid.se", value="test@eduid.com") req = { @@ -685,7 +729,7 @@ def test_create_user_with_invalid_linked_accounts_value(self): result1 = self._create_user(req, expect_success=False) self._assertScimError(json=result1.response.json(), detail="Invalid nutid linked_accounts") - def test_update_user_set_linked_accounts(self): + def test_update_user_set_linked_accounts(self) -> None: db_account1 = ScimApiLinkedAccount(issuer="eduid.se", value="test1@dev.eduid.se") account2 = LinkedAccount(issuer="eduid.se", value="test2@eduid.se", parameters={"mfa_stepup": True}) db_user = self.add_user(identifier=str(uuid4()), linked_accounts=[db_account1]) @@ -698,12 +742,13 @@ def test_update_user_set_linked_accounts(self): self._assertResponse(result.response) self._assertUserUpdateSuccess(req, result.response, db_user) - def test_update_user_set_linked_accounts2(self): + def test_update_user_set_linked_accounts2(self) -> None: """Test updating linked accounts sorted 'wrong'""" db_account1 = ScimApiLinkedAccount(issuer="eduid.se", value="test1@dev.eduid.se") account1 = LinkedAccount(issuer=db_account1.issuer, value=db_account1.value) account2 = LinkedAccount(issuer="eduid.se", value="test2@eduid.se", parameters={"mfa_stepup": True}) - db_user = self.add_user(identifier=str(uuid4()), linked_accounts=[db_account1]) + db_user: ScimApiUser | None = self.add_user(identifier=str(uuid4()), linked_accounts=[db_account1]) + assert db_user req = { "schemas": [SCIMSchema.CORE_20_USER.value, SCIMSchema.NUTID_USER_V1.value], "id": str(db_user.scim_id), @@ -715,6 +760,7 @@ def test_update_user_set_linked_accounts2(self): result = self._update_user(req, db_user.scim_id, version=db_user.version) self._assertResponse(result.response) self._assertUserUpdateSuccess(req, result.response, db_user) + assert self.userdb db_user = self.userdb.get_user_by_scim_id(str(db_user.scim_id)) assert db_user db_account2 = ScimApiLinkedAccount(issuer=account2.issuer, value=account2.value, parameters=account2.parameters) @@ -801,7 +847,8 @@ def setUp(self) -> None: for user in self.users[1:]: self.add_member_to_group(group_identifier=str(self.group.scim_id), user_identifier=str(user.scim_id)) - async def test_delete_user_with_groups(self): + async def test_delete_user_with_groups(self) -> None: + assert self.groupdb group = self.groupdb.get_group_by_scim_id(str(self.group.scim_id)) assert group is not None # please mypy assert len(group.members) == self.user_count @@ -820,8 +867,10 @@ async def test_delete_user_with_groups(self): for task in tasks: self._assertResponse(task.result(), status_code=204) # No content + assert self.userdb for user in self.users[: self.user_count // 2]: - assert self.userdb.get_user_by_scim_id(user.scim_id) is None + assert self.userdb.get_user_by_scim_id(str(user.scim_id)) is None group = self.groupdb.get_group_by_scim_id(str(group.scim_id)) + assert group assert len(group.graph.members) == self.user_count // 2 diff --git a/src/eduid/scimapi/tests/test_search_filter.py b/src/eduid/scimapi/tests/test_search_filter.py index 5e83bbf80..8e2c9da80 100644 --- a/src/eduid/scimapi/tests/test_search_filter.py +++ b/src/eduid/scimapi/tests/test_search_filter.py @@ -6,7 +6,7 @@ class TestSearchFilter(unittest.TestCase): - def test_lastmodified(self): + def test_lastmodified(self) -> None: now = datetime.utcnow() filter = f'meta.lastModified gt "{now.isoformat()}"' sf = parse_search_filter(filter) @@ -14,7 +14,7 @@ def test_lastmodified(self): self.assertEqual(sf.op, "gt") self.assertEqual(sf.val, now.isoformat()) - def test_lastmodified_with_tz(self): + def test_lastmodified_with_tz(self) -> None: nowstr = "2020-05-05T09:13:43.916000+00:00" filter = f'meta.lastModified gt "{nowstr}"' sf = parse_search_filter(filter) @@ -22,21 +22,21 @@ def test_lastmodified_with_tz(self): self.assertEqual(sf.op, "gt") self.assertEqual(sf.val, nowstr) - def test_str(self): + def test_str(self) -> None: filter = 'foo eq "123"' sf = parse_search_filter(filter) self.assertEqual(sf.attr, "foo") self.assertEqual(sf.op, "eq") self.assertEqual(sf.val, "123") - def test_int(self): + def test_int(self) -> None: filter = "foo eq 123" sf = parse_search_filter(filter) self.assertEqual(sf.attr, "foo") self.assertEqual(sf.op, "eq") self.assertEqual(sf.val, 123) - def test_not_printable(self): + def test_not_printable(self) -> None: filter = "foo eq 12\u00093" with self.assertRaises(BadRequest): parse_search_filter(filter) diff --git a/src/eduid/scimapi/utils.py b/src/eduid/scimapi/utils.py index c9a05e23c..54134a22d 100644 --- a/src/eduid/scimapi/utils.py +++ b/src/eduid/scimapi/utils.py @@ -48,7 +48,7 @@ def filter_none(x: Filtered) -> Filtered: return x -def get_unique_hash(): +def get_unique_hash() -> str: return str(uuid4()) @@ -61,7 +61,7 @@ def load_jwks(config: ScimApiConfig) -> jwk.JWKSet: return jwks -def retryable_db_write(func: Callable): +def retryable_db_write(func: Callable) -> Callable: @functools.wraps(func) def wrapper_run_func(*args: Any, **kwargs: Any): max_retries = 10 diff --git a/src/eduid/userdb/db/async_db.py b/src/eduid/userdb/db/async_db.py index 2247d04a3..4134ca317 100644 --- a/src/eduid/userdb/db/async_db.py +++ b/src/eduid/userdb/db/async_db.py @@ -80,7 +80,7 @@ def get_collection(self, collection: str, database_name: str | None = None) -> A _db = self.get_database(database_name) return _db[collection] - async def is_healthy(self): + async def is_healthy(self) -> bool: """ From mongo_client.py: Starting with version 3.0 the :class:`MongoClient` @@ -111,7 +111,7 @@ async def is_healthy(self): logger.error(f"{self} not healthy: {e}") return False - async def close(self): + async def close(self) -> None: self._client.close() diff --git a/src/eduid/userdb/db/sync_db.py b/src/eduid/userdb/db/sync_db.py index a34dfe9c9..6bf4aaa2d 100644 --- a/src/eduid/userdb/db/sync_db.py +++ b/src/eduid/userdb/db/sync_db.py @@ -88,7 +88,7 @@ def get_collection( _db = self.get_database(database_name) return _db[collection] - def is_healthy(self): + def is_healthy(self) -> bool: """ From mongo_client.py: Starting with version 3.0 the :class:`MongoClient` @@ -119,7 +119,7 @@ def is_healthy(self): logger.error(f"{self} not healthy: {e}") return False - def close(self): + def close(self) -> None: self._client.close() diff --git a/src/eduid/userdb/element.py b/src/eduid/userdb/element.py index eb7e5136f..5d840eaca 100644 --- a/src/eduid/userdb/element.py +++ b/src/eduid/userdb/element.py @@ -349,7 +349,7 @@ def find(self, key: ElementKey | str | None) -> ListElement | None: raise EduIDUserDBError("More than one element found") return res[0] - def add(self, element: ListElement): + def add(self, element: ListElement) -> None: """ Add an element to the list. diff --git a/src/eduid/userdb/event.py b/src/eduid/userdb/event.py index b995839a7..728f4e8ca 100644 --- a/src/eduid/userdb/event.py +++ b/src/eduid/userdb/event.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from typing import Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from uuid import uuid4 from pydantic import Field @@ -11,6 +11,9 @@ TEventSubclass = TypeVar("TEventSubclass", bound="Event") +if TYPE_CHECKING: + from eduid.userdb.tou import ToUEvent + class Event(Element): """ """ @@ -69,7 +72,7 @@ class EventList(ElementList[ListElement], Generic[ListElement], ABC): pass -def event_from_dict(data: dict[str, Any]): +def event_from_dict(data: dict[str, Any]) -> ToUEvent: """ Create an Event instance (probably really a subclass of Event) from a dict. diff --git a/src/eduid/userdb/locked_identity.py b/src/eduid/userdb/locked_identity.py index 31e648a8a..05e05d828 100644 --- a/src/eduid/userdb/locked_identity.py +++ b/src/eduid/userdb/locked_identity.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, NoReturn from pydantic import field_validator @@ -35,7 +35,7 @@ def replace(self, element: IdentityElement) -> None: self.add(element=element) return None - def remove(self, key: ElementKey): + def remove(self, key: ElementKey) -> NoReturn: """ Override remove method as an element should be set once, remove never. """ diff --git a/src/eduid/userdb/mail.py b/src/eduid/userdb/mail.py index 66b3fa618..8c20b6d8f 100644 --- a/src/eduid/userdb/mail.py +++ b/src/eduid/userdb/mail.py @@ -53,7 +53,7 @@ def from_list_of_dicts(cls: type[MailAddressList], items: list[dict[str, Any]]) return cls(elements=[MailAddress.from_dict(this) for this in items]) -def address_from_dict(data: dict[str, Any]): +def address_from_dict(data: dict[str, Any]) -> MailAddress: """ Create a MailAddress instance from a dict. diff --git a/src/eduid/userdb/meta.py b/src/eduid/userdb/meta.py index dd076aab2..b963a9fc0 100644 --- a/src/eduid/userdb/meta.py +++ b/src/eduid/userdb/meta.py @@ -24,5 +24,5 @@ class Meta(BaseModel): is_in_database: Annotated[bool, Field(exclude=True)] = False # this is set to True when userdb loads the object model_config = ConfigDict(arbitrary_types_allowed=True) - def new_version(self): + def new_version(self) -> None: self.version = ObjectId() diff --git a/src/eduid/userdb/reset_password/state.py b/src/eduid/userdb/reset_password/state.py index 695e23d3d..9760fb5e5 100644 --- a/src/eduid/userdb/reset_password/state.py +++ b/src/eduid/userdb/reset_password/state.py @@ -88,7 +88,7 @@ def __post_init__(self): self.method = "email" self.email_code = CodeElement.parse(application="security", code_or_element=self.email_code) - def to_dict(self): + def to_dict(self) -> TUserDbDocument: res = super().to_dict() res["email_code"] = self.email_code.to_dict() return res diff --git a/src/eduid/userdb/scimapi/invitedb.py b/src/eduid/userdb/scimapi/invitedb.py index b72e58475..dfab01047 100644 --- a/src/eduid/userdb/scimapi/invitedb.py +++ b/src/eduid/userdb/scimapi/invitedb.py @@ -107,7 +107,7 @@ def save(self, invite: ScimApiInvite) -> bool: return result.acknowledged - def remove(self, invite: ScimApiInvite): + def remove(self, invite: ScimApiInvite) -> bool: return self.remove_document(invite.invite_id) def get_invite_by_scim_id(self, scim_id: str) -> ScimApiInvite | None: diff --git a/src/eduid/userdb/scimapi/userdb.py b/src/eduid/userdb/scimapi/userdb.py index c41103246..9580c397c 100644 --- a/src/eduid/userdb/scimapi/userdb.py +++ b/src/eduid/userdb/scimapi/userdb.py @@ -40,7 +40,7 @@ class ScimApiUser(ScimApiResourceBase): linked_accounts: list[ScimApiLinkedAccount] = field(default_factory=list) @property - def etag(self): + def etag(self) -> str: return f'W/"{self.version}"' def to_dict(self) -> TUserDbDocument: @@ -128,7 +128,7 @@ def save(self, user: ScimApiUser) -> None: return None - def remove(self, user: ScimApiUser): + def remove(self, user: ScimApiUser) -> bool: return self.remove_document(user.user_id) def get_user_by_scim_id(self, scim_id: str) -> ScimApiUser | None: @@ -157,7 +157,7 @@ def get_user_by_profile_data( profile: str, key: str, operator: str, - value: datetime, + value: str | int, limit: int | None = None, skip: int | None = None, ) -> tuple[list[ScimApiUser], int]: diff --git a/src/eduid/userdb/testing/__init__.py b/src/eduid/userdb/testing/__init__.py index c5689022d..ae569a4ba 100644 --- a/src/eduid/userdb/testing/__init__.py +++ b/src/eduid/userdb/testing/__init__.py @@ -57,10 +57,10 @@ def conn(self) -> pymongo.MongoClient[TUserDbDocument]: return self._conn @property - def uri(self): + def uri(self) -> str: return f"mongodb://localhost:{self.port}" - def shutdown(self): + def shutdown(self) -> None: if self._conn: logger.info(f"Closing connection {self._conn}") self._conn.close() @@ -82,7 +82,7 @@ class MongoTestCase(unittest.TestCase): A test can access the port using the attribute `port` """ - def setUp(self, am_users: list[User] | None = None): + def setUp(self, am_users: list[User] | None = None) -> None: """ Test case initialization. :return: @@ -139,7 +139,7 @@ def _reset_databases(self): self.tmp_db.conn.drop_database(db_name) self.amdb._drop_whole_collection() - def tearDown(self): + def tearDown(self) -> None: for userdoc in self.amdb._get_all_docs(): assert User.from_dict(data=userdoc) self._reset_databases() @@ -156,7 +156,7 @@ class AsyncMongoTestCase(unittest.IsolatedAsyncioTestCase): A test can access the port using the attribute `port` """ - def setUp(self, *args: list[Any], **kwargs: dict[str, Any]): + def setUp(self, *args: list[Any], **kwargs: dict[str, Any]) -> None: """ Test case initialization. :return: @@ -205,6 +205,6 @@ def _reset_databases(self): if db_name not in ["local", "admin", "config"]: # Do not drop mongo internal dbs self.tmp_db.conn.drop_database(db_name) - def tearDown(self): + def tearDown(self) -> None: self._reset_databases() super().tearDown() diff --git a/src/eduid/userdb/testing/temp_instance.py b/src/eduid/userdb/testing/temp_instance.py index faa69f3aa..3b1b9893a 100644 --- a/src/eduid/userdb/testing/temp_instance.py +++ b/src/eduid/userdb/testing/temp_instance.py @@ -106,7 +106,7 @@ def output(self) -> str: _output = "".join(fd.readlines()) return _output - def shutdown(self): + def shutdown(self) -> None: logger.debug(f"{self} output at shutdown:\n{self.output}") if self._process: self._process.terminate() diff --git a/src/eduid/userdb/tests/test_app_user.py b/src/eduid/userdb/tests/test_app_user.py index 155d70325..6e6e0f752 100644 --- a/src/eduid/userdb/tests/test_app_user.py +++ b/src/eduid/userdb/tests/test_app_user.py @@ -12,41 +12,43 @@ class TestAppUser(TestCase): users: UserFixtures user_data: TUserDbDocument - def setUp(self): + def setUp(self) -> None: _users = UserFixtures() self.user = _users.new_user_example self.user_data = self.user.to_dict() - def test_proper_user(self): + def test_proper_user(self) -> None: user = User.from_dict(data=self.user_data) self.assertEqual(user.user_id, self.user_data["_id"]) self.assertEqual(user.eppn, self.user_data["eduPersonPrincipalName"]) - def test_proper_new_user(self): + def test_proper_new_user(self) -> None: user = User(user_id=self.user.user_id, eppn=self.user.eppn, credentials=self.user.credentials) self.assertEqual(user.user_id, self.user.user_id) self.assertEqual(user.eppn, self.user.eppn) - def test_missing_id(self): + def test_missing_id(self) -> None: user = User(eppn=self.user.eppn, credentials=self.user.credentials) self.assertNotEqual(user.user_id, self.user.user_id) - def test_missing_eppn(self): + def test_missing_eppn(self) -> None: _data = self.user.to_dict() _data.pop("eduPersonPrincipalName") with self.assertRaises(ValidationError): User.from_dict(_data) - def test_identities_created_ts_true(self): + def test_identities_created_ts_true(self) -> None: _data = self.user.to_dict() _data["identities"][0]["created_ts"] = True user = User.from_dict(_data) identity = user.identities.find(_data["identities"][0]["identity_type"]) + assert identity assert isinstance(identity.created_ts, datetime) is True - def test_locked_identity_created_ts_true(self): + def test_locked_identity_created_ts_true(self) -> None: _data = self.user.to_dict() _data["locked_identity"][0]["created_ts"] = True user = User.from_dict(_data) locked_identity = user.locked_identity.find(_data["locked_identity"][0]["identity_type"]) + assert locked_identity assert isinstance(locked_identity.created_ts, datetime) is True diff --git a/src/eduid/userdb/tests/test_async_db.py b/src/eduid/userdb/tests/test_async_db.py index b23dc41c5..f92f9fa3f 100644 --- a/src/eduid/userdb/tests/test_async_db.py +++ b/src/eduid/userdb/tests/test_async_db.py @@ -6,7 +6,7 @@ class TestAsyncMongoDB(IsolatedAsyncioTestCase): - async def test_full_uri(self): + async def test_full_uri(self) -> None: # full specified uri uri = "mongodb://db.example.com:1111/testdb" mdb = AsyncMongoDB(uri, db_name="testdb") @@ -17,7 +17,7 @@ async def test_full_uri(self): self.assertEqual(mdb._db_uri, uri) self.assertEqual(mdb._database_name, "testdb") - async def test_uri_without_path_component(self): + async def test_uri_without_path_component(self) -> None: uri = "mongodb://db.example.com:1111" mdb = AsyncMongoDB(uri, db_name="testdb") database = mdb.get_database() @@ -25,7 +25,7 @@ async def test_uri_without_path_component(self): self.assertEqual(mdb._db_uri, uri + "/testdb") self.assertEqual(mdb._database_name, "testdb") - async def test_uri_without_port(self): + async def test_uri_without_port(self) -> None: uri = "mongodb://db.example.com/" mdb = AsyncMongoDB(uri) self.assertEqual(mdb._db_uri, uri) @@ -33,7 +33,7 @@ async def test_uri_without_port(self): assert database is not None self.assertEqual(mdb.sanitized_uri, "mongodb://db.example.com/") - async def test_uri_with_username_and_password(self): + async def test_uri_with_username_and_password(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:1111/testdb" mdb = AsyncMongoDB(uri, db_name="testdb") conn = mdb.get_connection() @@ -47,7 +47,7 @@ async def test_uri_with_username_and_password(self): mdb.__repr__(), "" ) - async def test_uri_with_replicaset(self): + async def test_uri_with_replicaset(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com,db2.example.com:27017,db3.example.com:1234/?replicaSet=rs9" mdb = AsyncMongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9") @@ -56,7 +56,7 @@ async def test_uri_with_replicaset(self): "mongodb://john:s3cr3t@db.example.com,db2.example.com,db3.example.com:1234/testdb?replicaset=rs9", ) - async def test_uri_with_options(self): + async def test_uri_with_options(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:27017/?ssl=true&replicaSet=rs9" mdb = AsyncMongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9&tls=true") @@ -71,25 +71,25 @@ async def asyncSetUp(self) -> None: self.num_objs = 10 await self.db.collection.insert_many([{"x": i} for i in range(self.num_objs)]) - async def test_db_count(self): + async def test_db_count(self) -> None: self.assertEqual(self.num_objs, await self.db.db_count()) - async def test_db_count_limit(self): + async def test_db_count_limit(self) -> None: self.assertEqual(1, await self.db.db_count(limit=1)) self.assertEqual(2, await self.db.db_count(limit=2)) - async def test_db_count_spec(self): + async def test_db_count_spec(self) -> None: self.assertEqual(1, await self.db.db_count(spec={"x": 3})) - async def test_get_documents_by_filter_skip(self): + async def test_get_documents_by_filter_skip(self) -> None: docs = await self.db._get_documents_by_filter(spec={}, skip=2) self.assertEqual(8, len(docs)) - async def test_get_documents_by_filter_limit(self): + async def test_get_documents_by_filter_limit(self) -> None: docs = await self.db._get_documents_by_filter(spec={}, limit=1) self.assertEqual(1, len(docs)) - async def test_get_documents_by_aggregate(self): + async def test_get_documents_by_aggregate(self) -> None: match = { "x": 3, } diff --git a/src/eduid/userdb/tests/test_credentials.py b/src/eduid/userdb/tests/test_credentials.py index 28bb015e0..4aa0297e9 100644 --- a/src/eduid/userdb/tests/test_credentials.py +++ b/src/eduid/userdb/tests/test_credentials.py @@ -51,7 +51,7 @@ def _keyid(key: dict[str, str]): class TestCredentialList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # make pytest always show full diffs self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) @@ -59,14 +59,14 @@ def setUp(self): self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) self.four = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict, _four_dict]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) self.assertEqual(4, len(self.four.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) expected = [_one_dict] @@ -74,22 +74,23 @@ def test_to_list_of_dicts(self): assert obtained == expected, "Credential list with one password not as expected" - def test_find(self): + def test_find(self) -> None: match = self.two.find("222222222222222222222222") - self.assertIsInstance(match, Password) + assert isinstance(match, Password) self.assertEqual(match.credential_id, "222222222222222222222222") self.assertEqual(match.salt, "secondPasswordElement") self.assertEqual(match.created_by, "test") - def test_filter(self): + def test_filter(self) -> None: match = self.four.filter(U2F) assert len(match) == 1 token = match[0] assert token.key == _keyid(_four_dict) assert token.public_key == "foo" - def test_add(self): + def test_add(self) -> None: second = self.two.find(str(ObjectId("222222222222222222222222"))) + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -97,8 +98,9 @@ def test_add(self): assert obtained == expected, "List of credentials with added credential different than expected" - def test_add_duplicate(self): + def test_add_duplicate(self) -> None: dup = self.two.find(str(ObjectId("222222222222222222222222"))) + assert dup with pytest.raises(ValidationError) as exc_info: self.two.add(dup) @@ -113,15 +115,15 @@ def test_add_duplicate(self): ], ), f"Wrong error message: {exc_info.value.errors()}" - def test_add_password(self): - this = CredentialList.from_list_of_dicts([_one_dict, _two_dict] + [_three_dict]) + def test_add_password(self) -> None: + this = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) expected = self.three.to_list_of_dicts() obtained = this.to_list_of_dicts() assert obtained == expected, "List of credentials with added password different than expected" - def test_remove(self): + def test_remove(self) -> None: self.three.remove(ElementKey(str(ObjectId("333333333333333333333333")))) now_two = self.three @@ -130,11 +132,11 @@ def test_remove(self): assert obtained == expected, "List of credentials with removed credential different than expected" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.remove(ElementKey(str(ObjectId("55002741d00690878ae9b603")))) - def test_generated(self): + def test_generated(self) -> None: match = self.three.find("222222222222222222222222") assert isinstance(match, Password) assert match.is_generated is False @@ -142,7 +144,7 @@ def test_generated(self): assert isinstance(match, Password) assert match.is_generated is True - def test_external_credential(self): + def test_external_credential(self) -> None: _id = ElementKey(str(ObjectId())) # A SwedenConnectCredential as stored in the database data = {"framework": "SWECONN", "level": "loa3", "credential_id": _id} diff --git a/src/eduid/userdb/tests/test_db.py b/src/eduid/userdb/tests/test_db.py index fef175352..6ca516097 100644 --- a/src/eduid/userdb/tests/test_db.py +++ b/src/eduid/userdb/tests/test_db.py @@ -9,7 +9,7 @@ class TestMongoDB(TestCase): - def test_full_uri(self): + def test_full_uri(self) -> None: # full specified uri uri = "mongodb://db.example.com:1111/testdb" mdb = db.MongoDB(uri, db_name="testdb") @@ -19,21 +19,21 @@ def test_full_uri(self): self.assertEqual(mdb._db_uri, uri) self.assertEqual(mdb._database_name, "testdb") - def test_uri_without_path_component(self): + def test_uri_without_path_component(self) -> None: uri = "mongodb://db.example.com:1111" mdb = db.MongoDB(uri, db_name="testdb") mdb.get_database() self.assertEqual(mdb._db_uri, uri + "/testdb") self.assertEqual(mdb._database_name, "testdb") - def test_uri_without_port(self): + def test_uri_without_port(self) -> None: uri = "mongodb://db.example.com/" mdb = db.MongoDB(uri) self.assertEqual(mdb._db_uri, uri) mdb.get_database("testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://db.example.com/") - def test_uri_with_username_and_password(self): + def test_uri_with_username_and_password(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:1111/testdb" mdb = db.MongoDB(uri, db_name="testdb") conn = mdb.get_connection() @@ -45,7 +45,7 @@ def test_uri_with_username_and_password(self): self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com:1111/testdb") self.assertEqual(mdb.__repr__(), "") - def test_uri_with_replicaset(self): + def test_uri_with_replicaset(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com,db2.example.com:27017,db3.example.com:1234/?replicaSet=rs9" mdb = db.MongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9") @@ -54,53 +54,53 @@ def test_uri_with_replicaset(self): "mongodb://john:s3cr3t@db.example.com,db2.example.com,db3.example.com:1234/testdb?replicaset=rs9", ) - def test_uri_with_options(self): + def test_uri_with_options(self) -> None: uri = "mongodb://john:s3cr3t@db.example.com:27017/?ssl=true&replicaSet=rs9" mdb = db.MongoDB(uri, db_name="testdb") self.assertEqual(mdb.sanitized_uri, "mongodb://john:secret@db.example.com/testdb?replicaset=rs9&tls=true") class TestDB(MongoTestCase): - def setUp(self): + def setUp(self) -> None: # type: ignore[override] _users = UserFixtures() self._am_users = [_users.new_unverified_user_example, _users.mocked_user_standard_2, _users.new_user_example] super().setUp(am_users=self._am_users) - def test_db_count(self): + def test_db_count(self) -> None: self.assertEqual(len(self._am_users), self.amdb.db_count()) - def test_db_count_limit(self): + def test_db_count_limit(self) -> None: self.assertEqual(1, self.amdb.db_count(limit=1)) self.assertEqual(2, self.amdb.db_count(limit=2)) - def test_db_count_spec(self): + def test_db_count_spec(self) -> None: self.assertEqual(1, self.amdb.db_count(spec={"_id": ObjectId("012345678901234567890123")})) - def test_get_documents_by_filter_skip(self): + def test_get_documents_by_filter_skip(self) -> None: docs = self.amdb._get_documents_by_filter(spec={}, skip=2) self.assertEqual(1, len(docs)) - def test_get_documents_by_filter_limit(self): + def test_get_documents_by_filter_limit(self) -> None: docs = self.amdb._get_documents_by_filter(spec={}, limit=1) self.assertEqual(1, len(docs)) - def test_get_verified_users_count_NIN(self): + def test_get_verified_users_count_NIN(self) -> None: count = self.amdb.get_verified_users_count(identity_type=IdentityType.NIN) assert count == 1 - def test_get_verified_users_count_EIDAS(self): + def test_get_verified_users_count_EIDAS(self) -> None: count = self.amdb.get_verified_users_count(identity_type=IdentityType.EIDAS) assert count == 1 - def test_get_verified_users_count_SVIPE(self): + def test_get_verified_users_count_SVIPE(self) -> None: count = self.amdb.get_verified_users_count(identity_type=IdentityType.SVIPE) assert count == 1 - def test_get_verified_users_count_None(self): + def test_get_verified_users_count_None(self) -> None: count = self.amdb.get_verified_users_count() assert count == 1 - def test_get_documents_by_aggregate(self): + def test_get_documents_by_aggregate(self) -> None: match = { "eduPersonPrincipalName": "hubba-bubba", } diff --git a/src/eduid/userdb/tests/test_element.py b/src/eduid/userdb/tests/test_element.py index 7f2c83a28..96d13166d 100644 --- a/src/eduid/userdb/tests/test_element.py +++ b/src/eduid/userdb/tests/test_element.py @@ -5,14 +5,14 @@ class TestElements(TestCase): - def test_create_element(self): + def test_create_element(self) -> None: elem = Element(created_by="test") assert elem.created_by == "test" assert isinstance(elem.created_ts, datetime) assert isinstance(elem.modified_ts, datetime) - def test_create_element_with_created_ts(self): + def test_create_element_with_created_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", created_ts=now) @@ -20,7 +20,7 @@ def test_create_element_with_created_ts(self): assert elem.created_ts == now assert isinstance(elem.modified_ts, datetime) - def test_create_element_with_modified_ts(self): + def test_create_element_with_modified_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", modified_ts=now) @@ -28,7 +28,7 @@ def test_create_element_with_modified_ts(self): assert elem.modified_ts == now assert isinstance(elem.modified_ts, datetime) - def test_create_element_with_created_and_modified_ts(self): + def test_create_element_with_created_and_modified_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", modified_ts=now, created_ts=now) @@ -36,7 +36,7 @@ def test_create_element_with_created_and_modified_ts(self): assert elem.created_ts == now assert elem.modified_ts == now - def test_element_reset_modified_ts(self): + def test_element_reset_modified_ts(self) -> None: now = datetime.utcnow() elem = Element(created_by="test", modified_ts=now, created_ts=now) @@ -47,7 +47,7 @@ def test_element_reset_modified_ts(self): class TestVerifiedElements(TestCase): - def test_create_verified_element(self): + def test_create_verified_element(self) -> None: elem = VerifiedElement(created_by="test") assert elem.created_by == "test" @@ -58,7 +58,7 @@ def test_create_verified_element(self): assert elem.verified_by is None assert elem.verified_ts is None - def test_modify_verified_element(self): + def test_modify_verified_element(self) -> None: elem = VerifiedElement(created_by="test") now = datetime.utcnow() @@ -74,7 +74,7 @@ def test_modify_verified_element(self): assert elem.verified_by == "test" assert elem.verified_ts == now - def test_create_full_verified_element(self): + def test_create_full_verified_element(self) -> None: now = datetime.utcnow() elem = VerifiedElement( @@ -91,7 +91,7 @@ def test_create_full_verified_element(self): class TestPrimaryElements(TestCase): - def test_create_primary_element(self): + def test_create_primary_element(self) -> None: elem = PrimaryElement(created_by="test") assert elem.created_by == "test" @@ -104,7 +104,7 @@ def test_create_primary_element(self): assert elem.is_primary is False - def test_modify_primary_element(self): + def test_modify_primary_element(self) -> None: elem = PrimaryElement(created_by="test") now = datetime.utcnow() @@ -124,7 +124,7 @@ def test_modify_primary_element(self): assert elem.is_primary is True - def test_create_full_primary_element(self): + def test_create_full_primary_element(self) -> None: now = datetime.utcnow() elem = PrimaryElement( @@ -147,7 +147,7 @@ def test_create_full_primary_element(self): assert elem.is_primary is True - def test_unverify_primary_element(self): + def test_unverify_primary_element(self) -> None: now = datetime.utcnow() elem = PrimaryElement( diff --git a/src/eduid/userdb/tests/test_event.py b/src/eduid/userdb/tests/test_event.py index 8dfa05196..413a19ad4 100644 --- a/src/eduid/userdb/tests/test_event.py +++ b/src/eduid/userdb/tests/test_event.py @@ -13,6 +13,7 @@ import eduid.userdb.exceptions from eduid.common.testing_base import normalised_data from eduid.userdb import PhoneNumber +from eduid.userdb.element import ElementKey from eduid.userdb.event import EventList from eduid.userdb.tou import ToUEvent @@ -55,42 +56,43 @@ def from_list_of_dicts(cls: type[SomeEventList], items: list[dict[str, Any]]) -> class TestEventList(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = SomeEventList() self.one = SomeEventList.from_list_of_dicts([_one_dict]) self.two = SomeEventList.from_list_of_dicts([_one_dict, _two_dict]) self.three = SomeEventList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): SomeEventList(elements="bad input data") with pytest.raises(ValidationError): SomeEventList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) _one_dict_copy = deepcopy(_one_dict) # Update id to event_id before comparing dicts _one_dict_copy["event_id"] = _one_dict_copy.pop("id") self.assertEqual([_one_dict_copy], self.one.to_list_of_dicts()) - def test_find(self): + def test_find(self) -> None: match = self.one.find(self.one.to_list()[0].key) + assert match self.assertIsInstance(match, ToUEvent) self.assertEqual(match.version, _one_dict["version"]) - def test_add(self): + def test_add(self) -> None: second = self.two.to_list()[-1] self.one.add(second) self.assertEqual(self.one.to_list_of_dicts(), self.two.to_list_of_dicts()) - def test_add_duplicate_key(self): + def test_add_duplicate_key(self) -> None: data = deepcopy(_two_dict) data["version"] = "other version" dup = ToUEvent.from_dict(data) @@ -108,26 +110,26 @@ def test_add_duplicate_key(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_add_event(self): + def test_add_event(self) -> None: third = self.three.to_list_of_dicts()[-1] this = SomeEventList.from_list_of_dicts([_one_dict, _two_dict, third]) self.assertEqual(this.to_list_of_dicts(), self.three.to_list_of_dicts()) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: new = PhoneNumber(number="+4612345678") with pytest.raises(ValidationError): - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] - def test_remove(self): + def test_remove(self) -> None: self.three.remove(self.three.to_list()[-1].key) now_two = self.three self.assertEqual(self.two.to_list_of_dicts(), now_two.to_list_of_dicts()) - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.remove("+46709999999") + self.one.remove(ElementKey("+46709999999")) - def test_unknown_event_type(self): + def test_unknown_event_type(self) -> None: e1 = { "event_type": "unknown_event", "id": str(bson.ObjectId()), @@ -151,7 +153,7 @@ def test_unknown_event_type(self): ], ), f"Wrong error message: {exc_info.value.errors()}" - def test_modified_ts_addition(self): + def test_modified_ts_addition(self) -> None: _event_no_modified_ts = { "event_type": "tou_event", "version": "1", @@ -171,10 +173,10 @@ def test_modified_ts_addition(self): else: self.assertIsInstance(event["modified_ts"], datetime) assert event["modified_ts"] == event["created_ts"] - for event in el.to_list(): - self.assertIsInstance(event.modified_ts, datetime) + for event2 in el.to_list(): + self.assertIsInstance(event2.modified_ts, datetime) - def test_update_modified_ts(self): + def test_update_modified_ts(self) -> None: _event_modified_ts = { "event_type": "tou_event", "version": "1", diff --git a/src/eduid/userdb/tests/test_exceptions.py b/src/eduid/userdb/tests/test_exceptions.py index 997489d17..0a01a2999 100644 --- a/src/eduid/userdb/tests/test_exceptions.py +++ b/src/eduid/userdb/tests/test_exceptions.py @@ -6,6 +6,6 @@ class TestEduIDUserDBError(TestCase): - def test_repr(self): + def test_repr(self) -> None: ex = eduid.userdb.exceptions.EduIDUserDBError("test") self.assertIsInstance(str(ex), str) diff --git a/src/eduid/userdb/tests/test_group_management.py b/src/eduid/userdb/tests/test_group_management.py index bdf349f6b..b74abfdae 100644 --- a/src/eduid/userdb/tests/test_group_management.py +++ b/src/eduid/userdb/tests/test_group_management.py @@ -13,12 +13,12 @@ class TestResetGroupInviteStateDB(MongoTestCase): user: User - def setUp(self): + def setUp(self) -> None: # type: ignore[override] super().setUp() self.user = UserFixtures().mocked_user_standard self.invite_state_db = GroupManagementInviteStateDB(self.tmp_db.uri) - def test_invite_state(self): + def test_invite_state(self) -> None: # Member group_scim_id = str(uuid4()) invite_state = GroupInviteState( @@ -31,6 +31,7 @@ def test_invite_state(self): invite = self.invite_state_db.get_state( group_scim_id=group_scim_id, email_address="johnsmith@example.com", role=GroupRole.MEMBER ) + assert invite self.assertEqual(group_scim_id, invite.group_scim_id) self.assertEqual("johnsmith@example.com", invite.email_address) self.assertEqual(GroupRole.MEMBER, invite.role) @@ -47,11 +48,12 @@ def test_invite_state(self): invite = self.invite_state_db.get_state( group_scim_id=group_scim_id, email_address="johnsmith@example.com", role=GroupRole.OWNER ) + assert invite self.assertEqual(group_scim_id, invite.group_scim_id) self.assertEqual("johnsmith@example.com", invite.email_address) self.assertEqual(GroupRole.OWNER, invite.role) - def test_save_duplicate(self): + def test_save_duplicate(self) -> None: group_scim_id = str(uuid4()) invite_state1 = GroupInviteState( group_scim_id=group_scim_id, diff --git a/src/eduid/userdb/tests/test_identities.py b/src/eduid/userdb/tests/test_identities.py index 31da98eeb..2769c9fe1 100644 --- a/src/eduid/userdb/tests/test_identities.py +++ b/src/eduid/userdb/tests/test_identities.py @@ -48,37 +48,39 @@ class TestIdentityList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # Make pytest show full diffs self.empty = IdentityList() self.one = IdentityList.from_list_of_dicts([_one_dict]) self.two = IdentityList.from_list_of_dicts([_one_dict, _two_dict]) self.three = IdentityList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): IdentityList(elements="bad input data") with pytest.raises(ValidationError): IdentityList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) - def test_find(self): + def test_find(self) -> None: match = self.one.find("nin") - self.assertIsInstance(match, NinIdentity) + assert match + assert isinstance(match, NinIdentity) self.assertEqual(match.number, "197801011234") assert match.is_verified is True assert match.verified_ts is None - def test_add(self): + def test_add(self) -> None: second = self.two.find("svipe") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -105,8 +107,9 @@ def test_add_duplicate(self) -> None: ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_nin(self): + def test_add_nin(self) -> None: third = self.three.find("eidas") + assert third this = IdentityList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -114,13 +117,14 @@ def test_add_nin(self): assert normalised_data(obtained) == normalised_data(expected), "List with added nin has unexpected data" - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: """Test adding a phone number to the nin-list. Specifically phone, since pydantic can coerce it into a nin since they both have the 'number' field. """ new = PhoneNumber(number="+4612345678") + assert new with pytest.raises(ValidationError) as exc_info: - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] assert normalised_data(exc_info.value.errors(), exclude_keys=["input", "url"]) == normalised_data( [ { @@ -132,7 +136,7 @@ def test_add_wrong_type(self): ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_remove(self): + def test_remove(self) -> None: self.three.remove(ElementKey("eidas")) now_two = self.three @@ -141,26 +145,27 @@ def test_remove(self): assert normalised_data(obtained) == normalised_data(expected), "List with removed NIN has unexpected data" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.remove(ElementKey("+46709999999")) class TestIdentity(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = IdentityList() self.one = IdentityList.from_list_of_dicts([_one_dict]) self.two = IdentityList.from_list_of_dicts([_one_dict, _two_dict]) self.three = IdentityList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the Nin. """ nin = self.two.nin + assert nin self.assertEqual(IdentityType.NIN.value, nin.key) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -168,58 +173,67 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(IdentityList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("nin") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("svipe") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("eidas") + assert this this.verified_by = "unit test" this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("nin") + assert this now = utc_now() this.verified_ts = now self.assertEqual(this.verified_ts, now) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("svipe") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_modify_created_by(self): + def test_modify_created_by(self) -> None: this = self.three.find("eidas") + assert this this.created_by = "unit test" assert this.created_by == "unit test" - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("nin") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_ts_bool(self): + def test_ts_bool(self) -> None: # check that we can't set created_ts or modified_ts to a bool but that we # can read those from db to fix them this = self.three.find("nin") + assert this with self.assertRaises(ValidationError): - this.created_ts = True + this.created_ts = True # type: ignore[assignment] with self.assertRaises(ValidationError): - this.modified_ts = True + this.modified_ts = True # type: ignore[assignment] this_dict = this.to_dict() this_dict["created_ts"] = True this_dict["modified_ts"] = True assert NinIdentity.from_dict(this_dict) is not None - def test_get_missing_proofing_method(self): + def test_get_missing_proofing_method(self) -> None: this = self.three.find("nin") + assert this this.verified_by = "foo" assert this.get_missing_proofing_method() is None this.verified_by = "bankid" diff --git a/src/eduid/userdb/tests/test_idp_user.py b/src/eduid/userdb/tests/test_idp_user.py index b8e0b9ec0..a61310665 100644 --- a/src/eduid/userdb/tests/test_idp_user.py +++ b/src/eduid/userdb/tests/test_idp_user.py @@ -9,7 +9,7 @@ class TestIdpUser(TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.saml_attribute_settings = SAMLAttributeSettings( default_eppn_scope="example.com", @@ -21,7 +21,7 @@ def setUp(self): authn_context_class=EduidAuthnContextClass.PASSWORD_PT, ) - def test_idp_user_to_attributes_all(self): + def test_idp_user_to_attributes_all(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) attributes = idp_user.to_saml_attributes(settings=self.saml_attribute_settings) @@ -37,6 +37,7 @@ def test_idp_user_to_attributes_all(self): continue self.assertIsNotNone(attributes.get(key), f"{key} is unexpectedly None") + assert idp_user.ladok expected = { "c": "se", "cn": "John Smith", @@ -60,7 +61,7 @@ def test_idp_user_to_attributes_all(self): } assert normalised_data(expected) == normalised_data(attributes), f"expected did not match {attributes}" - def test_idp_user_chosen_given_name(self): + def test_idp_user_chosen_given_name(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) idp_user.given_name = "John Jack" idp_user.chosen_given_name = "Jack" @@ -79,13 +80,13 @@ def test_idp_user_chosen_given_name(self): assert attributes["displayName"] == "John Jack Smith" assert attributes["cn"] == "John Jack Smith" - def test_idp_user_display_name(self): + def test_idp_user_display_name(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) attributes = idp_user.to_saml_attributes(settings=self.saml_attribute_settings) assert attributes["displayName"] == "John Smith" assert attributes["cn"] == "John Smith" - def test_idp_user_display_name_no_given_name(self): + def test_idp_user_display_name_no_given_name(self) -> None: idp_user = IdPUser.from_dict(UserFixtures().mocked_user_standard.to_dict()) idp_user.given_name = None idp_user.chosen_given_name = None diff --git a/src/eduid/userdb/tests/test_ladok.py b/src/eduid/userdb/tests/test_ladok.py index 56f48760d..0364da02a 100644 --- a/src/eduid/userdb/tests/test_ladok.py +++ b/src/eduid/userdb/tests/test_ladok.py @@ -10,7 +10,7 @@ class LadokTest(TestCase): def setUp(self) -> None: self.external_uuid = uuid4() - def test_create_ladok(self): + def test_create_ladok(self) -> None: university = University( ladok_name="AB", name=UniversityName(sv="Lärosätesnamn", en="University Name"), created_by="test created_by" ) diff --git a/src/eduid/userdb/tests/test_logs.py b/src/eduid/userdb/tests/test_logs.py index f8145a220..36524e194 100644 --- a/src/eduid/userdb/tests/test_logs.py +++ b/src/eduid/userdb/tests/test_logs.py @@ -30,15 +30,15 @@ class TestProofingLog(TestCase): user: User - def setUp(self): + def setUp(self) -> None: self.tmp_db = MongoTemporaryInstance.get_instance() self.proofing_log_db = ProofingLog(db_uri=self.tmp_db.uri) self.user = UserFixtures().mocked_user_standard - def tearDown(self): + def tearDown(self) -> None: self.proofing_log_db._drop_whole_collection() - def test_id_proofing_data(self): + def test_id_proofing_data(self) -> None: proofing_element = ProofingLogElement( eppn=self.user.eppn, created_by="test", proofing_method="test", proofing_version="test" ) @@ -52,7 +52,7 @@ def test_id_proofing_data(self): self.assertIsNotNone(hit["created_ts"]) self.assertEqual(hit["proofing_method"], "test") - def test_teleadress_proofing(self): + def test_teleadress_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -80,7 +80,7 @@ def test_teleadress_proofing(self): self.assertEqual(hit["proofing_method"], "TeleAdress") self.assertEqual(hit["proofing_version"], "test") - def test_teleadress_proofing_relation(self): + def test_teleadress_proofing_relation(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -111,7 +111,7 @@ def test_teleadress_proofing_relation(self): self.assertEqual(hit["proofing_method"], "TeleAdress") self.assertEqual(hit["proofing_version"], "test") - def test_letter_proofing(self): + def test_letter_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -140,7 +140,7 @@ def test_letter_proofing(self): self.assertEqual(hit["proofing_method"], "letter") self.assertEqual(hit["proofing_version"], "test") - def test_mail_address_proofing(self): + def test_mail_address_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -165,7 +165,7 @@ def test_mail_address_proofing(self): self.assertEqual(hit["proofing_method"], "e-mail") self.assertEqual(hit["mail_address"], "some_mail_address") - def test_phone_number_proofing(self): + def test_phone_number_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -191,7 +191,7 @@ def test_phone_number_proofing(self): self.assertEqual(hit["phone_number"], "some_phone_number") self.assertEqual(hit["proofing_version"], "test") - def test_se_leg_proofing(self): + def test_se_leg_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -222,7 +222,7 @@ def test_se_leg_proofing(self): self.assertEqual(hit["proofing_method"], "se-leg") self.assertEqual(hit["proofing_version"], "test") - def test_se_leg_proofing_freja(self): + def test_se_leg_proofing_freja(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -254,7 +254,7 @@ def test_se_leg_proofing_freja(self): self.assertEqual(hit["proofing_method"], "se-leg") self.assertEqual(hit["proofing_version"], "test") - def test_ladok_proofing(self): + def test_ladok_proofing(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -283,7 +283,7 @@ def test_ladok_proofing(self): self.assertEqual(hit["ladok_name"], "AB") self.assertEqual(hit["proofing_version"], "test") - def test_blank_string_proofing_data(self): + def test_blank_string_proofing_data(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", @@ -305,20 +305,20 @@ def test_blank_string_proofing_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_boolean_false_proofing_data(self): + def test_boolean_false_proofing_data(self) -> None: data = { "eppn": self.user.eppn, "created_by": "test", "proofing_version": "test", "reference": "reference id", } - proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=0) + proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=0) # type: ignore[arg-type] self.assertTrue(self.proofing_log_db.save(proofing_element)) - proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=False) + proofing_element = PhoneNumberProofing.model_construct(**data, phone_number=False) # type: ignore[arg-type] self.assertTrue(self.proofing_log_db.save(proofing_element)) - def test_deregistered_proofing_data(self): + def test_deregistered_proofing_data(self) -> None: proofing_element = NinNavetProofingLogElement( eppn=self.user.eppn, created_by="test", @@ -343,11 +343,11 @@ def test_deregistered_proofing_data(self): class TestUserChangeLog(TestCase): - def setUp(self): + def setUp(self) -> None: self.tmp_db = MongoTemporaryInstance.get_instance() self.user_log_db = UserChangeLog(db_uri=self.tmp_db.uri) - def tearDown(self): + def tearDown(self) -> None: self.user_log_db._drop_whole_collection() def _insert_log_fixtures(self): @@ -375,7 +375,7 @@ def _insert_log_fixtures(self): assert res[0]["eduPersonPrincipalName"] == "hubba-bubba" assert res[1]["eduPersonPrincipalName"] == "hubba-bubba" - def test_get_by_eppn(self): + def test_get_by_eppn(self) -> None: self._insert_log_fixtures() res_1 = self.user_log_db.get_by_eppn("hubba-bubba") diff --git a/src/eduid/userdb/tests/test_mail.py b/src/eduid/userdb/tests/test_mail.py index 7c8a5bc95..d5a42d182 100644 --- a/src/eduid/userdb/tests/test_mail.py +++ b/src/eduid/userdb/tests/test_mail.py @@ -11,6 +11,7 @@ from eduid.common.misc.timeutil import utc_now from eduid.common.testing_base import normalised_data from eduid.userdb import PhoneNumber +from eduid.userdb.element import ElementKey from eduid.userdb.mail import MailAddress, MailAddressList __author__ = "ft" @@ -35,37 +36,39 @@ class TestMailAddressList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None self.empty = MailAddressList() self.one = MailAddressList.from_list_of_dicts([_one_dict]) self.two = MailAddressList.from_list_of_dicts([_one_dict, _two_dict]) self.three = MailAddressList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): MailAddressList(elements="bad input data") with pytest.raises(ValidationError): MailAddressList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_empty_to_list_of_dicts(self): + def test_empty_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) - def test_find(self): + def test_find(self) -> None: match = self.one.find("ft@one.example.org") + assert match self.assertIsInstance(match, MailAddress) self.assertEqual(match.email, "ft@one.example.org") self.assertEqual(match.is_verified, True) self.assertEqual(match.verified_ts, None) - def test_add(self): + def test_add(self) -> None: second = self.two.find("ft@two.example.org") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -73,8 +76,10 @@ def test_add(self): assert obtained == expected, "Wrong data after adding mail address to list" - def test_add_duplicate(self): + def test_add_duplicate(self) -> None: + assert self.two.primary dup = self.two.find(self.two.primary.email) + assert dup with pytest.raises(ValidationError) as exc_info: self.two.add(dup) @@ -89,8 +94,9 @@ def test_add_duplicate(self): ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_mailaddress(self): + def test_add_mailaddress(self) -> None: third = self.three.find("ft@three.example.org") + assert third this = MailAddressList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -98,20 +104,20 @@ def test_add_mailaddress(self): assert obtained == expected, "Wrong data in mail address list" - def test_add_another_primary(self): + def test_add_another_primary(self) -> None: new = eduid.userdb.mail.address_from_dict( {"email": "ft@primary.example.org", "verified": True, "primary": True} ) with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.one.add(new) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: new = PhoneNumber(number="+4612345678") with pytest.raises(ValidationError): - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] - def test_remove(self): - self.three.remove("ft@three.example.org") + def test_remove(self) -> None: + self.three.remove(ElementKey("ft@three.example.org")) now_two = self.three expected = self.two.to_list_of_dicts() @@ -119,51 +125,58 @@ def test_remove(self): assert obtained == expected, "Wrong data after removing email from list" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.remove("foo@no-such-address.example.org") + self.one.remove(ElementKey("foo@no-such-address.example.org")) - def test_remove_primary(self): + def test_remove_primary(self) -> None: + assert self.two.primary with pytest.raises( eduid.userdb.element.PrimaryElementViolation, match="Removing the primary element is not allowed" ): self.two.remove(self.two.primary.key) - def test_remove_primary_single(self): - self.one.remove(self.one.primary.email) + def test_remove_primary_single(self) -> None: + assert self.one.primary + self.one.remove(ElementKey(self.one.primary.email)) now_empty = self.one self.assertEqual([], now_empty.to_list()) - def test_primary(self): + def test_primary(self) -> None: match = self.one.primary + assert match self.assertEqual(match.email, "ft@one.example.org") - def test_empty_primary(self): + def test_empty_primary(self) -> None: self.assertEqual(None, self.empty.primary) - def test_set_primary_to_same(self): + def test_set_primary_to_same(self) -> None: match = self.one.primary - self.one.set_primary(match.email) + assert match + self.one.set_primary(ElementKey(match.email)) match = self.two.primary - self.two.set_primary(match.email) + assert match + self.two.set_primary(ElementKey(match.email)) - def test_set_unknown_as_primary(self): + def test_set_unknown_as_primary(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.set_primary("foo@no-such-address.example.org") + self.one.set_primary(ElementKey("foo@no-such-address.example.org")) - def test_set_unverified_as_primary(self): + def test_set_unverified_as_primary(self) -> None: with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): - self.three.set_primary("ft@three.example.org") + self.three.set_primary(ElementKey("ft@three.example.org")) - def test_change_primary(self): + def test_change_primary(self) -> None: match = self.two.primary + assert match self.assertEqual(match.email, "ft@one.example.org") - self.two.set_primary("ft@two.example.org") + self.two.set_primary(ElementKey("ft@two.example.org")) updated = self.two.primary + assert updated self.assertEqual(updated.email, "ft@two.example.org") - def test_bad_input_two_primary(self): + def test_bad_input_two_primary(self) -> None: one = copy.deepcopy(_one_dict) two = copy.deepcopy(_two_dict) one["primary"] = True @@ -171,7 +184,7 @@ def test_bad_input_two_primary(self): with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): MailAddressList.from_list_of_dicts([one, two]) - def test_bad_input_unverified_primary(self): + def test_bad_input_unverified_primary(self) -> None: one = copy.deepcopy(_one_dict) one["verified"] = False with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): @@ -179,20 +192,21 @@ def test_bad_input_unverified_primary(self): class TestMailAddress(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = MailAddressList() self.one = MailAddressList.from_list_of_dicts([_one_dict]) self.two = MailAddressList.from_list_of_dicts([_one_dict, _two_dict]) self.three = MailAddressList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the MailAddress. """ address = self.two.primary + assert address self.assertEqual(address.key, address.email) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -204,7 +218,7 @@ def test_parse_cycle(self): assert cycled == expected - def test_unknown_input_data(self): + def test_unknown_input_data(self) -> None: one = copy.deepcopy(_one_dict) one["foo"] = "bar" with pytest.raises(ValidationError) as exc_info: @@ -219,7 +233,7 @@ def test_unknown_input_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_bad_input_type(self): + def test_bad_input_type(self) -> None: one = copy.deepcopy(_one_dict) one["email"] = False with pytest.raises(ValidationError) as exc_info: @@ -237,46 +251,54 @@ def test_bad_input_type(self): ] ), f"Wrong error message: {exc_info.value.errors()}" - def test_changing_is_verified_on_primary(self): + def test_changing_is_verified_on_primary(self) -> None: this = self.one.primary + assert this with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): this.is_verified = False - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_by = "unit test" this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_verified_ts(self): + def test_verified_ts(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_ts = utc_now() self.assertIsInstance(this.verified_ts, datetime.datetime) - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.verified_ts = utc_now() - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("ft@three.example.org") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("ft@three.example.org") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_uppercase_email_address(self): + def test_uppercase_email_address(self) -> None: address = "UPPERCASE@example.com" mail_address = MailAddress(email=address) self.assertEqual(address.lower(), mail_address.email) diff --git a/src/eduid/userdb/tests/test_nin.py b/src/eduid/userdb/tests/test_nin.py index a327be99e..8e0967917 100644 --- a/src/eduid/userdb/tests/test_nin.py +++ b/src/eduid/userdb/tests/test_nin.py @@ -34,37 +34,39 @@ class TestNinList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # Make pytest show full diffs self.empty = NinList() self.one = NinList.from_list_of_dicts([_one_dict]) self.two = NinList.from_list_of_dicts([_one_dict, _two_dict]) self.three = NinList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): NinList(elements="bad input data") with pytest.raises(ValidationError): NinList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: self.assertEqual([], self.empty.to_list(), list) self.assertIsInstance(self.one.to_list(), list) self.assertEqual(1, len(self.one.to_list())) - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: self.assertEqual([], self.empty.to_list_of_dicts(), list) - def test_find(self): + def test_find(self) -> None: match = self.one.find("197801011234") + assert match self.assertIsInstance(match, Nin) self.assertEqual(match.number, "197801011234") self.assertEqual(match.is_verified, True) self.assertEqual(match.verified_ts, None) - def test_add(self): + def test_add(self) -> None: second = self.two.find("197802022345") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() @@ -91,8 +93,9 @@ def test_add_duplicate(self) -> None: ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_nin(self): + def test_add_nin(self) -> None: third = self.three.find("197803033456") + assert third this = NinList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -100,18 +103,18 @@ def test_add_nin(self): assert obtained == expected, "List with added nin has unexpected data" - def test_add_another_primary(self): + def test_add_another_primary(self) -> None: new = eduid.userdb.nin.nin_from_dict({"number": "+46700000009", "verified": True, "primary": True}) with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.one.add(new) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: """Test adding a phone number to the nin-list. Specifically phone, since pydantic can coerce it into a nin since they both have the 'number' field. """ new = PhoneNumber(number="+4612345678") with pytest.raises(ValidationError) as exc_info: - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] assert normalised_data(exc_info.value.errors(), exclude_keys=["input", "url"]) == normalised_data( [ { @@ -123,7 +126,7 @@ def test_add_wrong_type(self): ], ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_remove(self): + def test_remove(self) -> None: self.three.remove(ElementKey("197803033456")) now_two = self.three @@ -132,49 +135,56 @@ def test_remove(self): assert obtained == expected, "List with removed NIN has unexpected data" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.remove(ElementKey("+46709999999")) - def test_remove_primary(self): + def test_remove_primary(self) -> None: + assert self.two.primary with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.two.remove(self.two.primary.key) - def test_remove_primary_single(self): + def test_remove_primary_single(self) -> None: + assert self.one.primary self.one.remove(self.one.primary.key) now_empty = self.one self.assertEqual([], now_empty.to_list()) - def test_primary(self): + def test_primary(self) -> None: match = self.one.primary + assert match self.assertEqual(match.number, "197801011234") - def test_empty_primary(self): + def test_empty_primary(self) -> None: self.assertEqual(None, self.empty.primary) - def test_set_primary_to_same(self): + def test_set_primary_to_same(self) -> None: match = self.one.primary + assert match self.one.set_primary(match.key) match = self.two.primary + assert match self.two.set_primary(match.key) - def test_set_unknown_as_primary(self): + def test_set_unknown_as_primary(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): self.one.set_primary(ElementKey("+46709999999")) - def test_set_unverified_as_primary(self): + def test_set_unverified_as_primary(self) -> None: with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.three.set_primary(ElementKey("197803033456")) - def test_change_primary(self): + def test_change_primary(self) -> None: match = self.two.primary + assert match self.assertEqual(match.number, "197801011234") self.two.set_primary(ElementKey("197802022345")) updated = self.two.primary + assert updated self.assertEqual(updated.number, "197802022345") - def test_bad_input_two_primary(self): + def test_bad_input_two_primary(self) -> None: one = copy.deepcopy(_one_dict) two = copy.deepcopy(_two_dict) one["primary"] = True @@ -182,7 +192,7 @@ def test_bad_input_two_primary(self): with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): NinList.from_list_of_dicts([one, two]) - def test_bad_input_unverified_primary(self): + def test_bad_input_unverified_primary(self) -> None: one = copy.deepcopy(_one_dict) one["verified"] = False with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): @@ -190,20 +200,21 @@ def test_bad_input_unverified_primary(self): class TestNin(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = NinList() self.one = NinList.from_list_of_dicts([_one_dict]) self.two = NinList.from_list_of_dicts([_one_dict, _two_dict]) self.three = NinList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the Nin. """ address = self.two.primary + assert address self.assertEqual(address.key, address.number) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -211,44 +222,52 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(NinList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_changing_is_verified_on_primary(self): + def test_changing_is_verified_on_primary(self) -> None: this = self.one.primary + assert this with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): this.is_verified = False - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("197803033456") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("197803033456") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("197803033456") + assert this this.verified_by = "unit test" this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("197803033456") + assert this now = datetime.datetime.utcnow() this.verified_ts = now self.assertEqual(this.verified_ts, now) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("197803033456") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_modify_created_by(self): + def test_modify_created_by(self) -> None: this = self.three.find("197803033456") + assert this this.created_by = "unit test" assert this.created_by == "unit test" - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("197803033456") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) diff --git a/src/eduid/userdb/tests/test_orcid.py b/src/eduid/userdb/tests/test_orcid.py index 596f70bbf..3069fb846 100644 --- a/src/eduid/userdb/tests/test_orcid.py +++ b/src/eduid/userdb/tests/test_orcid.py @@ -37,7 +37,8 @@ class TestOrcid(unittest.TestCase): maxDiff = None - def test_id_token(self): + def test_id_token(self) -> None: + assert isinstance(token_response["id_token"], dict) id_token_data = token_response["id_token"] id_token_data["created_by"] = "test" id_token_1 = OidcIdToken.from_dict(id_token_data) @@ -65,9 +66,10 @@ def test_id_token(self): assert dict_1 == dict_2 with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - OidcIdToken.from_dict(None) + OidcIdToken.from_dict(None) # type: ignore[arg-type] - def test_oidc_authz(self): + def test_oidc_authz(self) -> None: + assert isinstance(token_response["id_token"], dict) id_token_data = token_response["id_token"] id_token_data["created_by"] = "test" id_token = OidcIdToken.from_dict(token_response["id_token"]) @@ -96,9 +98,10 @@ def test_oidc_authz(self): assert dict_1 == dict_2 with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - OidcAuthorization.from_dict(None) + OidcAuthorization.from_dict(None) # type: ignore[arg-type] - def test_orcid(self): + def test_orcid(self) -> None: + assert isinstance(token_response["id_token"], dict) token_response["id_token"]["created_by"] = "test" token_response["created_by"] = "test" oidc_authz = OidcAuthorization.from_dict(token_response) @@ -135,4 +138,4 @@ def test_orcid(self): ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" with pytest.raises(eduid.userdb.exceptions.UserDBValueError): - Orcid.from_dict(None) + Orcid.from_dict(None) # type: ignore[arg-type] diff --git a/src/eduid/userdb/tests/test_password.py b/src/eduid/userdb/tests/test_password.py index 867bfe602..48e267d88 100644 --- a/src/eduid/userdb/tests/test_password.py +++ b/src/eduid/userdb/tests/test_password.py @@ -3,7 +3,7 @@ from bson.objectid import ObjectId -from eduid.userdb.credentials import CredentialList +from eduid.userdb.credentials import CredentialList, Password __author__ = "lundberg" @@ -17,20 +17,22 @@ class TestPassword(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) self.two = CredentialList.from_list_of_dicts([_one_dict, _two_dict]) self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by CredentialList) works for the Password. """ password = self.one.find(str(ObjectId("55002741d00690878ae9b600"))) + assert password + assert isinstance(password, Password) self.assertEqual(password.key, password.credential_id) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -38,11 +40,13 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(CredentialList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find(str(ObjectId("55002741d00690878ae9b600"))) + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find(str(ObjectId("55002741d00690878ae9b600"))) + assert this self.assertIsInstance(this.created_ts, datetime.datetime) diff --git a/src/eduid/userdb/tests/test_phone.py b/src/eduid/userdb/tests/test_phone.py index ae29bc7a5..6314b86ec 100644 --- a/src/eduid/userdb/tests/test_phone.py +++ b/src/eduid/userdb/tests/test_phone.py @@ -10,6 +10,7 @@ from eduid.common.misc.timeutil import utc_now from eduid.common.testing_base import normalised_data from eduid.userdb import MailAddress +from eduid.userdb.element import ElementKey from eduid.userdb.phone import PhoneNumber, PhoneNumberList __author__ = "ft" @@ -40,26 +41,26 @@ class TestPhoneNumberList(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = PhoneNumberList() self.one = PhoneNumberList.from_list_of_dicts([_one_dict]) self.two = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict]) self.three = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) self.four = PhoneNumberList.from_list_of_dicts([_three_dict, _four_dict]) - def test_init_bad_data(self): + def test_init_bad_data(self) -> None: with pytest.raises(ValidationError): PhoneNumberList(elements="bad input data") with pytest.raises(ValidationError): PhoneNumberList(elements=["bad input data"]) - def test_to_list(self): + def test_to_list(self) -> None: assert self.empty.to_list_of_dicts() == [] assert isinstance(self.one.to_list(), list) assert len(self.one.to_list()) == 1 - def test_to_list_of_dicts(self): + def test_to_list_of_dicts(self) -> None: assert self.empty.to_list_of_dicts() == [] one_dict_list = self.one.to_list_of_dicts() @@ -67,23 +68,27 @@ def test_to_list_of_dicts(self): assert one_dict_list == expected - def test_find(self): + def test_find(self) -> None: match = self.one.find("+46700000001") + assert match self.assertIsInstance(match, PhoneNumber) self.assertEqual(match.number, "+46700000001") self.assertEqual(match.is_verified, True) self.assertEqual(match.verified_ts, None) - def test_add(self): + def test_add(self) -> None: second = self.two.find("+46700000002") + assert second self.one.add(second) expected = self.two.to_list_of_dicts() got = self.one.to_list_of_dicts() assert got == expected, "Adding a phone number to a list results in wrong data" - def test_add_duplicate(self): + def test_add_duplicate(self) -> None: + assert self.two.primary dup = self.two.find(self.two.primary.number) + assert dup with pytest.raises(ValidationError) as exc_info: self.two.add(dup) @@ -98,8 +103,9 @@ def test_add_duplicate(self): ] ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_add_phonenumber(self): + def test_add_phonenumber(self) -> None: third = self.three.find("+46700000003") + assert third this = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict, third.to_dict()]) expected = self.three.to_list_of_dicts() @@ -107,15 +113,15 @@ def test_add_phonenumber(self): assert got == expected, "Phone number list contains wrong data" - def test_add_another_primary(self): + def test_add_another_primary(self) -> None: new = PhoneNumber(number="+46700000009", is_verified=True, is_primary=True) with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): self.one.add(new) - def test_add_wrong_type(self): + def test_add_wrong_type(self) -> None: new = MailAddress(email="ft@example.org") with pytest.raises(ValidationError) as exc_info: - self.one.add(new) + self.one.add(new) # type: ignore[arg-type] assert normalised_data(exc_info.value.errors(), exclude_keys=["input", "url"]) == normalised_data( [ { @@ -127,8 +133,8 @@ def test_add_wrong_type(self): ] ), f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['input', 'url'])}" - def test_remove(self): - self.three.remove("+46700000003") + def test_remove(self) -> None: + self.three.remove(ElementKey("+46700000003")) now_two = self.three expected = self.two.to_list_of_dicts() @@ -136,41 +142,45 @@ def test_remove(self): assert got == expected, "Phone list has wrong data after removing phone" - def test_remove_unknown(self): + def test_remove_unknown(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.remove("+46709999999") + self.one.remove(ElementKey("+46709999999")) - def test_remove_primary(self): + def test_remove_primary(self) -> None: + assert self.two.primary with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): - self.two.remove(self.two.primary.number) + self.two.remove(ElementKey(self.two.primary.number)) - def test_remove_primary_single(self): - self.one.remove(self.one.primary.number) + def test_remove_primary_single(self) -> None: + assert self.one.primary + self.one.remove(ElementKey(self.one.primary.number)) now_empty = self.one assert now_empty.to_list() == [] - def test_remove_all_mix(self): + def test_remove_all_mix(self) -> None: # First, remove all numbers except the primary for mobile in self.three.to_list(): if not mobile.is_primary: self.three.remove(mobile.key) # Now, remove the primary number (which can't be removed until it is the last element) + assert self.three.primary self.three.remove(self.three.primary.key) assert self.three.to_list() == [] - def test_remove_all_no_verified(self): + def test_remove_all_no_verified(self) -> None: verified = self.four.verified if verified: for mobile in verified: if not mobile.is_primary: - self.four.remove(mobile.number) - self.four.remove(self.four.primary.number) + self.four.remove(ElementKey(mobile.number)) + assert self.four.primary + self.four.remove(ElementKey(self.four.primary.number)) for mobile in self.four.to_list(): - self.four.remove(mobile.number) + self.four.remove(ElementKey(mobile.number)) self.assertEqual([], self.four.to_list()) - def test_unverify_all(self): + def test_unverify_all(self) -> None: verified = self.three.verified for mobile in verified: @@ -180,36 +190,41 @@ def test_unverify_all(self): verified_now = self.three.verified assert verified_now == [] - def test_primary(self): + def test_primary(self) -> None: match = self.one.primary + assert match self.assertEqual(match.number, "+46700000001") - def test_empty_primary(self): + def test_empty_primary(self) -> None: self.assertEqual(None, self.empty.primary) - def test_set_primary_to_same(self): + def test_set_primary_to_same(self) -> None: match = self.one.primary - self.one.set_primary(match.number) + assert match + self.one.set_primary(ElementKey(match.number)) match = self.two.primary - self.two.set_primary(match.number) + assert match + self.two.set_primary(ElementKey(match.number)) - def test_set_unknown_as_primary(self): + def test_set_unknown_as_primary(self) -> None: with self.assertRaises(eduid.userdb.exceptions.UserDBValueError): - self.one.set_primary("+46709999999") + self.one.set_primary(ElementKey("+46709999999")) - def test_set_unverified_as_primary(self): + def test_set_unverified_as_primary(self) -> None: with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): - self.three.set_primary("+46700000003") + self.three.set_primary(ElementKey("+46700000003")) - def test_change_primary(self): + def test_change_primary(self) -> None: match = self.two.primary + assert match self.assertEqual(match.number, "+46700000001") - self.two.set_primary("+46700000002") + self.two.set_primary(ElementKey("+46700000002")) updated = self.two.primary + assert updated self.assertEqual(updated.number, "+46700000002") - def test_bad_input_two_primary(self): + def test_bad_input_two_primary(self) -> None: one = copy.deepcopy(_one_dict) two = copy.deepcopy(_two_dict) one["primary"] = True @@ -217,7 +232,7 @@ def test_bad_input_two_primary(self): with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): PhoneNumberList.from_list_of_dicts([one, two]) - def test_unverified_primary(self): + def test_unverified_primary(self) -> None: one = copy.deepcopy(_one_dict) one["verified"] = False with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): @@ -225,22 +240,23 @@ def test_unverified_primary(self): class TestPhoneNumber(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = PhoneNumberList() self.one = PhoneNumberList.from_list_of_dicts([_one_dict]) self.two = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict]) self.three = PhoneNumberList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by PrimaryElementList) works for the PhoneNumber. """ address = self.two.primary + assert address self.assertEqual(address.key, address.number) - def test_create_phone_number(self): - one = copy.deepcopy(_one_dict) - one = PhoneNumber.from_dict(one) + def test_create_phone_number(self) -> None: + one_copy = copy.deepcopy(_one_dict) + one = PhoneNumber.from_dict(one_copy) # remove added timestamp one_dict = one.to_dict() @@ -248,7 +264,7 @@ def test_create_phone_number(self): assert _one_dict["verified"] == one_dict["verified"], "Created phone has wrong is_verified" assert _one_dict["number"] == one_dict["number"], "Created phone has wrong number" - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -256,7 +272,7 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(PhoneNumberList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_unknown_input_data(self): + def test_unknown_input_data(self) -> None: one = copy.deepcopy(_one_dict) one["foo"] = "bar" with pytest.raises(ValidationError) as exc_info: @@ -271,39 +287,46 @@ def test_unknown_input_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_changing_is_verified_on_primary(self): + def test_changing_is_verified_on_primary(self) -> None: this = self.one.primary + assert this with self.assertRaises(eduid.userdb.element.PrimaryElementViolation): this.is_verified = False - def test_changing_is_verified(self): + def test_changing_is_verified(self) -> None: this = self.three.find("+46700000003") + assert this this.is_verified = False # was False already this.is_verified = True - def test_verified_by(self): + def test_verified_by(self) -> None: this = self.three.find("+46700000003") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") - def test_modify_verified_by(self): + def test_modify_verified_by(self) -> None: this = self.three.find("+46700000003") + assert this this.verified_by = "unit test" self.assertEqual(this.verified_by, "unit test") this.verified_by = "test unit" self.assertEqual(this.verified_by, "test unit") - def test_modify_verified_ts(self): + def test_modify_verified_ts(self) -> None: this = self.three.find("+46700000003") + assert this now = utc_now() this.verified_ts = now self.assertEqual(this.verified_ts, now) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find("+46700000003") + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find("+46700000003") + assert this self.assertIsInstance(this.created_ts, datetime.datetime) diff --git a/src/eduid/userdb/tests/test_profile.py b/src/eduid/userdb/tests/test_profile.py index 7af2553f0..8b8d250e0 100644 --- a/src/eduid/userdb/tests/test_profile.py +++ b/src/eduid/userdb/tests/test_profile.py @@ -13,7 +13,7 @@ class ProfileTest(TestCase): - def test_create_profile(self): + def test_create_profile(self) -> None: profile = Profile( owner="test owner", profile_schema="test schema", @@ -28,7 +28,7 @@ def test_create_profile(self): self.assertIn(key, profile.profile_data) self.assertEqual(value, profile.profile_data[key]) - def test_profile_list(self): + def test_profile_list(self) -> None: profile = Profile( owner="test owner 1", profile_schema="test schema", @@ -48,12 +48,12 @@ def test_profile_list(self): self.assertIsNotNone(profile_list.find("test owner 1")) self.assertIsNotNone(profile_list.find("test owner 2")) - def test_empty_profile_list(self): + def test_empty_profile_list(self) -> None: profile_list = ProfileList() self.assertIsNotNone(profile_list) self.assertEqual(profile_list.count, 0) - def test_profile_list_owner_conflict(self): + def test_profile_list_owner_conflict(self) -> None: profile = Profile( owner="test owner 1", profile_schema="test schema", diff --git a/src/eduid/userdb/tests/test_proofing.py b/src/eduid/userdb/tests/test_proofing.py index c40cdd2e0..81c56527b 100644 --- a/src/eduid/userdb/tests/test_proofing.py +++ b/src/eduid/userdb/tests/test_proofing.py @@ -73,7 +73,7 @@ def _test_create_letterproofingstate(self, state: LetterProofingState, nin_expec sorted(_proofing_letter_expected_keys), ) - def test_create_letterproofingstate_with_ninproofingelement_from_dict(self): + def test_create_letterproofingstate_with_ninproofingelement_from_dict(self) -> None: """ """ state = LetterProofingState( eppn=EPPN, @@ -102,7 +102,7 @@ def test_create_letterproofingstate_with_ninproofingelement_from_dict(self): self._test_create_letterproofingstate(state, _nin_expected_keys) - def test_create_letterproofingstate_with_ninproofingelement_from_dict_with_created_ts(self): + def test_create_letterproofingstate_with_ninproofingelement_from_dict_with_created_ts(self) -> None: """ """ state = LetterProofingState( eppn=EPPN, @@ -130,7 +130,7 @@ def test_create_letterproofingstate_with_ninproofingelement_from_dict_with_creat self._test_create_letterproofingstate(state, _nin_expected_keys) - def test_create_letterproofingstate(self): + def test_create_letterproofingstate(self) -> None: """ """ state = LetterProofingState( eppn=EPPN, @@ -155,7 +155,7 @@ def test_create_letterproofingstate(self): self._test_create_letterproofingstate(state, _nin_expected_keys) - def test_create_oidcproofingstate(self): + def test_create_oidcproofingstate(self) -> None: """ { 'eduPersonPrincipalName': 'foob-arra', @@ -183,7 +183,7 @@ def test_create_oidcproofingstate(self): ["_id", "eduPersonPrincipalName", "modified_ts", "nin", "nonce", "state", "token"], ) - def test_proofing_state_expiration(self): + def test_proofing_state_expiration(self) -> None: state = ProofingState(id=None, eppn=EPPN, modified_ts=datetime.now(tz=None)) self.assertFalse(state.is_expired(1)) diff --git a/src/eduid/userdb/tests/test_resetpw.py b/src/eduid/userdb/tests/test_resetpw.py index cba508ab3..6062110cf 100644 --- a/src/eduid/userdb/tests/test_resetpw.py +++ b/src/eduid/userdb/tests/test_resetpw.py @@ -1,17 +1,20 @@ from datetime import timedelta from eduid.userdb.reset_password import ResetPasswordEmailAndPhoneState, ResetPasswordEmailState, ResetPasswordStateDB +from eduid.userdb.reset_password.element import CodeElement from eduid.userdb.testing import MongoTestCase class TestResetPasswordStateDB(MongoTestCase): - def setUp(self): + def setUp(self) -> None: # type: ignore[override] super().setUp() self.resetpw_db = ResetPasswordStateDB(self.tmp_db.uri, "eduid_reset_password") - def test_email_state(self): + def test_email_state(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) self.resetpw_db.save(email_state, is_in_database=False) @@ -25,9 +28,11 @@ def test_email_state(self): self.assertTrue(state.email_code.is_expired(timedelta(0))) self.assertFalse(state.email_code.is_expired(timedelta(1))) - def test_email_state_get_by_code(self): + def test_email_state_get_by_code(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) self.resetpw_db.save(email_state, is_in_database=False) @@ -39,9 +44,11 @@ def test_email_state_get_by_code(self): self.assertEqual(state.eppn, "hubba-bubba") self.assertEqual(state.generated_password, False) - def test_email_state_generated_pw(self): + def test_email_state_generated_pw(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) email_state.generated_password = True @@ -52,9 +59,11 @@ def test_email_state_generated_pw(self): self.assertEqual(state.email_address, "johnsmith@example.com") self.assertEqual(state.generated_password, True) - def test_email_state_extra_security(self): + def test_email_state_extra_security(self) -> None: email_state = ResetPasswordEmailState( - eppn="hubba-bubba", email_address="johnsmith@example.com", email_code="dummy-code" + eppn="hubba-bubba", + email_address="johnsmith@example.com", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), ) email_state.extra_security = {"phone_numbers": [{"number": "+99999999999", "primary": True, "verified": True}]} @@ -66,19 +75,20 @@ def test_email_state_extra_security(self): self.assertEqual(state.email_address, "johnsmith@example.com") self.assertEqual(state.extra_security["phone_numbers"][0]["number"], "+99999999999") - def test_email_and_phone_state(self): + def test_email_and_phone_state(self) -> None: email_state = ResetPasswordEmailAndPhoneState( eppn="hubba-bubba", email_address="johnsmith@example.com", - email_code="dummy-code", + email_code=CodeElement.parse(application="test", code_or_element="dummy-code"), phone_number="+99999999999", - phone_code="dummy-phone-code", + phone_code=CodeElement.parse(application="test", code_or_element="dummy-phone-code"), ) self.resetpw_db.save(email_state, is_in_database=False) state = self.resetpw_db.get_state_by_eppn("hubba-bubba") assert state is not None + assert isinstance(state, ResetPasswordEmailAndPhoneState) self.assertEqual(state.email_address, "johnsmith@example.com") self.assertEqual(state.email_code.code, "dummy-code") self.assertEqual(state.phone_number, "+99999999999") diff --git a/src/eduid/userdb/tests/test_signup_invite.py b/src/eduid/userdb/tests/test_signup_invite.py index 5df717f11..684c439a8 100644 --- a/src/eduid/userdb/tests/test_signup_invite.py +++ b/src/eduid/userdb/tests/test_signup_invite.py @@ -6,7 +6,7 @@ class TestSignupInvite(TestCase): - def test_scim_invite(self): + def test_scim_invite(self) -> None: invite = Invite( invite_type=InviteType.SCIM, invite_reference=SCIMReference(data_owner="test_data_owner", scim_id=uuid4()), diff --git a/src/eduid/userdb/tests/test_signup_user.py b/src/eduid/userdb/tests/test_signup_user.py index 322a93efd..3483a243b 100644 --- a/src/eduid/userdb/tests/test_signup_user.py +++ b/src/eduid/userdb/tests/test_signup_user.py @@ -7,23 +7,23 @@ class TestSignupUser(TestCase): - def setUp(self): + def setUp(self) -> None: self.user = UserFixtures().new_signup_user_example self.user_data = self.user.to_dict() - def test_proper_user(self): + def test_proper_user(self) -> None: self.assertEqual(self.user.user_id, self.user_data["_id"]) self.assertEqual(self.user.eppn, self.user_data["eduPersonPrincipalName"]) - def test_proper_new_user(self): + def test_proper_new_user(self) -> None: user = SignupUser(user_id=self.user.user_id, eppn=self.user.eppn) self.assertEqual(user.user_id, self.user.user_id) self.assertEqual(user.eppn, self.user.eppn) - def test_missing_id(self): + def test_missing_id(self) -> None: user = SignupUser(eppn=self.user.eppn) self.assertNotEqual(user.user_id, self.user.user_id) - def test_missing_eppn(self): + def test_missing_eppn(self) -> None: with self.assertRaises(ValidationError): - SignupUser(user_id=self.user.user_id) + SignupUser(user_id=self.user.user_id) # type: ignore[call-arg] diff --git a/src/eduid/userdb/tests/test_support_models.py b/src/eduid/userdb/tests/test_support_models.py index 7bfeae679..1b6b21062 100644 --- a/src/eduid/userdb/tests/test_support_models.py +++ b/src/eduid/userdb/tests/test_support_models.py @@ -5,24 +5,24 @@ class TestSupportUsers(TestCase): - def setUp(self): + def setUp(self) -> None: self.users = UserFixtures() - def test_support_user(self): + def test_support_user(self) -> None: user = models.SupportUserFilter(self.users.new_user_example.to_dict()) self.assertNotIn("_id", user) self.assertNotIn("letter_proofing_data", user) for password in user["passwords"]: self.assertNotIn("salt", password) - def test_support_signup_user(self): + def test_support_signup_user(self) -> None: user = models.SupportSignupUserFilter(self.users.new_signup_user_example.to_dict()) self.assertNotIn("_id", user) self.assertNotIn("letter_proofing_data", user) for password in user["passwords"]: self.assertNotIn("salt", password) - def test_support_completed_signup_user(self): + def test_support_completed_signup_user(self) -> None: user = models.SupportSignupUserFilter(self.users.new_completed_signup_user_example.to_dict()) self.assertNotIn("_id", user) self.assertNotIn("letter_proofing_data", user) @@ -32,12 +32,14 @@ def test_support_completed_signup_user(self): The assertion is here only for good measure to make sure that the right example data is being used. """ - self.assertTrue(len(user.get("pending_mail_address")) == 0) + pending = user.get("pending_mail_address") + assert pending is not None + self.assertTrue(len(pending) == 0) for password in user["passwords"]: self.assertNotIn("salt", password) - def test_support_user_authn_info(self): + def test_support_user_authn_info(self) -> None: raw_data = { "_id": "5c5b027c20d6b6000db13187", "fail_count": {"201902": 1, "201903": 0}, diff --git a/src/eduid/userdb/tests/test_tou.py b/src/eduid/userdb/tests/test_tou.py index b70a34292..428738135 100644 --- a/src/eduid/userdb/tests/test_tou.py +++ b/src/eduid/userdb/tests/test_tou.py @@ -8,6 +8,7 @@ from eduid.userdb.actions.tou import ToUUser from eduid.userdb.credentials import CredentialList +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.event import Event, EventList from eduid.userdb.exceptions import UserMissingData from eduid.userdb.fixtures.users import UserFixtures @@ -46,20 +47,20 @@ class TestToUEvent(TestCase): - def setUp(self): - self.empty = EventList() + def setUp(self) -> None: + self.empty: EventList = EventList() self.one = ToUList.from_list_of_dicts([_one_dict]) self.two = ToUList.from_list_of_dicts([_one_dict, _two_dict]) self.three = ToUList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by ElementList) works for the ToUEvent. """ event = self.two.to_list()[0] self.assertEqual(event.key, event.event_id) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -68,22 +69,24 @@ def test_parse_cycle(self): new_list = ToUList.from_list_of_dicts(this_dict) assert new_list.to_list_of_dicts() == this.to_list_of_dicts() - def test_created_by(self): + def test_created_by(self) -> None: this = Event.from_dict(dict(created_by=None, event_type="test_event")) this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_event_type(self): + def test_event_type(self) -> None: this = self.one.to_list()[0] self.assertEqual(this.event_type, "tou_event") - def test_reaccept_tou(self): + def test_reaccept_tou(self) -> None: three_years = timedelta(days=3 * 365) one_day = timedelta(days=1) # set modified_ts to both sides of three years ago _two_dict["modified_ts"] = utc_now() - three_years + one_day _three_dict["modified_ts"] = utc_now() - three_years - one_day + assert isinstance(_two_dict["modified_ts"], datetime) assert _two_dict["modified_ts"] + three_years > utc_now() + assert isinstance(_three_dict["modified_ts"], datetime) assert _three_dict["modified_ts"] + three_years < utc_now() # check if the TOU needs to be accepted with an interval of three years @@ -95,16 +98,16 @@ def test_reaccept_tou(self): class TestTouUser(TestCase): user: User - def setUp(self): + def setUp(self) -> None: self.user = UserFixtures().new_user_example - def test_proper_user(self): + def test_proper_user(self) -> None: userdata = self.user.to_dict() userdata["tou"] = [copy.deepcopy(_one_dict)] user = ToUUser.from_dict(data=userdata) self.assertEqual(user.tou.to_list_of_dicts()[0]["version"], "1") - def test_proper_new_user(self): + def test_proper_new_user(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList.from_list_of_dicts([one]) userdata = self.user.to_dict() @@ -114,31 +117,31 @@ def test_proper_new_user(self): user = ToUUser(user_id=userid, eppn=eppn, tou=tou, credentials=passwords) self.assertEqual(user.tou.to_list_of_dicts()[0]["version"], "1") - def test_proper_new_user_no_id(self): + def test_proper_new_user_no_id(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList(elements=[ToUEvent.from_dict(one)]) userdata = self.user.to_dict() passwords = CredentialList.from_list_of_dicts(userdata["passwords"]) with self.assertRaises(ValidationError): - ToUUser(tou=tou, credentials=passwords) + ToUUser(tou=tou, credentials=passwords) # type: ignore[call-arg] - def test_proper_new_user_no_eppn(self): + def test_proper_new_user_no_eppn(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList.from_list_of_dicts([one]) userdata = self.user.to_dict() userid = userdata.pop("_id") passwords = CredentialList.from_list_of_dicts(userdata["passwords"]) with self.assertRaises(ValidationError): - ToUUser(user_id=userid, tou=tou, credentials=passwords) + ToUUser(user_id=userid, tou=tou, credentials=passwords) # type: ignore[call-arg] - def test_missing_eppn(self): + def test_missing_eppn(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUList.from_list_of_dicts([one]) with self.assertRaises(UserMissingData): - ToUUser.from_dict(data=dict(tou=tou, userid=self.user.user_id)) + ToUUser.from_dict(data=TUserDbDocument({"tou": tou, "userid": self.user.user_id})) - def test_missing_userid(self): + def test_missing_userid(self) -> None: one = copy.deepcopy(_one_dict) tou = ToUEvent.from_dict(one) with self.assertRaises(UserMissingData): - ToUUser.from_dict(data=dict(tou=[tou], eppn=self.user.eppn)) + ToUUser.from_dict(data=TUserDbDocument({"tou": [tou], "eppn": self.user.eppn})) diff --git a/src/eduid/userdb/tests/test_u2f.py b/src/eduid/userdb/tests/test_u2f.py index ffad00c48..0f5cb4160 100644 --- a/src/eduid/userdb/tests/test_u2f.py +++ b/src/eduid/userdb/tests/test_u2f.py @@ -43,17 +43,19 @@ def _keyid(key: dict[str, str]): class TestU2F(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) self.two = CredentialList.from_list_of_dicts([_one_dict, _two_dict]) self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by CredentialList) works for the credential. """ this = self.one.find(_keyid(_one_dict)) + assert this + assert isinstance(this, U2F) self.assertEqual( this.key, _keyid( @@ -64,7 +66,7 @@ def test_key(self): ), ) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -72,7 +74,7 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(CredentialList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_unknown_input_data(self): + def test_unknown_input_data(self) -> None: one = copy.deepcopy(_one_dict) one["foo"] = "bar" with pytest.raises(ValidationError) as exc_info: @@ -86,17 +88,20 @@ def test_unknown_input_data(self): } ], f"Wrong error message: {normalised_data(exc_info.value.errors(), exclude_keys=['url'])}" - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_proofing_method(self): + def test_proofing_method(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_method = CredentialProofingMethod.SWAMID_AL2_MFA_HI self.assertEqual(this.proofing_method, CredentialProofingMethod.SWAMID_AL2_MFA_HI) this.proofing_method = CredentialProofingMethod.SWAMID_AL3_MFA @@ -104,8 +109,9 @@ def test_proofing_method(self): this.proofing_method = None self.assertEqual(this.proofing_method, None) - def test_proofing_version(self): + def test_proofing_version(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_version = "TEST" self.assertEqual(this.proofing_version, "TEST") this.proofing_version = "TEST2" @@ -113,10 +119,12 @@ def test_proofing_version(self): this.proofing_version = None self.assertEqual(this.proofing_version, None) - def test_swamid_al2_hi_to_swamid_al3_migration(self): + def test_swamid_al2_hi_to_swamid_al3_migration(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_method = CredentialProofingMethod.SWAMID_AL2_MFA_HI this.is_verified = True load_save_cred_list = CredentialList.from_list_of_dicts([this.to_dict()]) load_save_cred = load_save_cred_list.find(_keyid(_three_dict)) + assert load_save_cred self.assertEqual(load_save_cred.proofing_method, CredentialProofingMethod.SWAMID_AL3_MFA) diff --git a/src/eduid/userdb/tests/test_user.py b/src/eduid/userdb/tests/test_user.py index e01ac391d..5c6c7f1a4 100644 --- a/src/eduid/userdb/tests/test_user.py +++ b/src/eduid/userdb/tests/test_user.py @@ -8,6 +8,7 @@ from eduid.userdb import NinIdentity, OidcAuthorization, OidcIdToken, Orcid from eduid.userdb.credentials import U2F, CredentialList, CredentialProofingMethod, Password +from eduid.userdb.db.base import TUserDbDocument from eduid.userdb.exceptions import EduIDUserDBError, UserHasNotCompletedSignup, UserIsRevoked from eduid.userdb.fixtures.identity import verified_nin_identity from eduid.userdb.fixtures.users import UserFixtures @@ -27,96 +28,100 @@ def _keyid(kh: str): class TestNewUser(unittest.TestCase): - def setUp(self): - self.data1 = { - "_id": ObjectId("547357c3d00690878ae9b620"), - "eduPersonPrincipalName": "guvat-nalif", - "givenName": "User", - "chosen_given_name": "User", - "legal_name": "User One", - "mail": "user@example.net", - "mailAliases": [ - { - "added_timestamp": datetime.fromisoformat("2014-12-18T11:25:19.804000"), - "email": "user@example.net", - "verified": True, - "primary": True, - } - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-11-24T16:22:49.188000"), - "credential_id": "54735b588a7d2a2c4ec3e7d0", - "salt": "$NDNv1H1$315d7$32$32$", - "created_by": "dashboard", - "is_generated": False, - } - ], - "identities": [verified_nin_identity.to_dict()], - "subject": "physical person", - "surname": "One", - "eduPersonEntitlement": ["http://foo.example.org"], - "preferredLanguage": "en", - } + def setUp(self) -> None: + self.data1 = TUserDbDocument( + { + "_id": ObjectId("547357c3d00690878ae9b620"), + "eduPersonPrincipalName": "guvat-nalif", + "givenName": "User", + "chosen_given_name": "User", + "legal_name": "User One", + "mail": "user@example.net", + "mailAliases": [ + { + "added_timestamp": datetime.fromisoformat("2014-12-18T11:25:19.804000"), + "email": "user@example.net", + "verified": True, + "primary": True, + } + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-11-24T16:22:49.188000"), + "credential_id": "54735b588a7d2a2c4ec3e7d0", + "salt": "$NDNv1H1$315d7$32$32$", + "created_by": "dashboard", + "is_generated": False, + } + ], + "identities": [verified_nin_identity.to_dict()], + "subject": "physical person", + "surname": "One", + "eduPersonEntitlement": ["http://foo.example.org"], + "preferredLanguage": "en", + } + ) - self.data2 = { - "_id": ObjectId("549190b5d00690878ae9b622"), - "displayName": "Some \xf6ne", - "eduPersonPrincipalName": "birub-gagoz", - "givenName": "Some", - "mail": "some.one@gmail.com", - "mailAliases": [ - {"email": "someone+test1@gmail.com", "verified": True}, - { - "added_timestamp": datetime.fromisoformat("2014-12-17T14:35:14.728000"), - "email": "some.one@gmail.com", - "verified": True, - }, - ], - "phone": [ - { - "created_ts": datetime.fromisoformat("2014-12-18T09:11:35.078000"), - "number": "+46702222222", - "primary": True, - "verified": True, - } - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2015-02-11T13:58:42.327000"), - "id": ObjectId("54db60128a7d2a26e8690cda"), - "salt": "$NDNv1H1$db011fc$32$32$", - "is_generated": False, - "source": "dashboard", - }, - { - "version": "U2F_V2", - "app_id": "unit test", - "keyhandle": "U2F SWAMID AL3", - "public_key": "foo", - "verified": True, - "proofing_method": CredentialProofingMethod.SWAMID_AL3_MFA, - "proofing_version": "testing", - }, - ], - "profiles": [ - { - "created_by": "test application", - "created_ts": datetime.fromisoformat("2020-02-04T17:42:33.696751"), - "owner": "test owner 1", - "schema": "test schema", - "profile_data": { - "a_string": "I am a string", - "an_int": 3, - "a_list": ["eins", 2, "drei"], - "a_map": {"some": "data"}, + self.data2 = TUserDbDocument( + { + "_id": ObjectId("549190b5d00690878ae9b622"), + "displayName": "Some \xf6ne", + "eduPersonPrincipalName": "birub-gagoz", + "givenName": "Some", + "mail": "some.one@gmail.com", + "mailAliases": [ + {"email": "someone+test1@gmail.com", "verified": True}, + { + "added_timestamp": datetime.fromisoformat("2014-12-17T14:35:14.728000"), + "email": "some.one@gmail.com", + "verified": True, }, - } - ], - "preferredLanguage": "sv", - "surname": "\xf6ne", - "subject": "physical person", - } + ], + "phone": [ + { + "created_ts": datetime.fromisoformat("2014-12-18T09:11:35.078000"), + "number": "+46702222222", + "primary": True, + "verified": True, + } + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2015-02-11T13:58:42.327000"), + "id": ObjectId("54db60128a7d2a26e8690cda"), + "salt": "$NDNv1H1$db011fc$32$32$", + "is_generated": False, + "source": "dashboard", + }, + { + "version": "U2F_V2", + "app_id": "unit test", + "keyhandle": "U2F SWAMID AL3", + "public_key": "foo", + "verified": True, + "proofing_method": CredentialProofingMethod.SWAMID_AL3_MFA, + "proofing_version": "testing", + }, + ], + "profiles": [ + { + "created_by": "test application", + "created_ts": datetime.fromisoformat("2020-02-04T17:42:33.696751"), + "owner": "test owner 1", + "schema": "test schema", + "profile_data": { + "a_string": "I am a string", + "an_int": 3, + "a_list": ["eins", 2, "drei"], + "a_map": {"some": "data"}, + }, + } + ], + "preferredLanguage": "sv", + "surname": "\xf6ne", + "subject": "physical person", + } + ) self._setup_user1() self._setup_user2() @@ -225,28 +230,28 @@ def _setup_user2(self): subject=SubjectType("physical person"), ) - def test_user_id(self): + def test_user_id(self) -> None: self.assertEqual(self.user1.user_id, self.data1["_id"]) - def test_eppn(self): + def test_eppn(self) -> None: self.assertEqual(self.user1.eppn, self.data1["eduPersonPrincipalName"]) - def test_given_name(self): + def test_given_name(self) -> None: self.assertEqual(self.user2.given_name, self.data2["givenName"]) - def test_chosen_given_name(self): + def test_chosen_given_name(self) -> None: self.assertEqual(self.user1.chosen_given_name, self.data1["chosen_given_name"]) - def test_surname(self): + def test_surname(self) -> None: self.assertEqual(self.user2.surname, self.data2["surname"]) - def test_legal_name(self): + def test_legal_name(self) -> None: self.assertEqual(self.user1.legal_name, self.data1["legal_name"]) - def test_mail_addresses(self): + def test_mail_addresses(self) -> None: self.assertEqual(self.user1.mail_addresses.primary.email, self.data1["mailAliases"][0]["email"]) - def test_passwords(self): + def test_passwords(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -260,7 +265,7 @@ def test_passwords(self): assert obtained == expected - def test_unknown_attributes(self): + def test_unknown_attributes(self) -> None: """ Test parsing a document with unknown data in it. """ @@ -270,16 +275,18 @@ def test_unknown_attributes(self): with self.assertRaises(ValidationError): User.from_dict(data) - def test_incomplete_signup_user(self): + def test_incomplete_signup_user(self) -> None: """ Test parsing the incomplete documents left in the central userdb by older Signup application. """ - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "vohon-mufus", - "mail": "olle@example.org", - "mailAliases": [{"email": "olle@example.org", "verified": False}], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "vohon-mufus", + "mail": "olle@example.org", + "mailAliases": [{"email": "olle@example.org", "verified": False}], + } + ) with self.assertRaises(UserHasNotCompletedSignup): User.from_dict(data) data["subject"] = "physical person" # later signup added this attribute @@ -304,108 +311,122 @@ def test_incomplete_signup_user(self): assert obtained == expected - def test_revoked_user(self): + def test_revoked_user(self) -> None: """ Test ability to identify revoked users. """ - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "binib-mufus", - "revoked_ts": datetime.fromisoformat("2015-05-26T08:33:56.826000"), - "passwords": [], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "binib-mufus", + "revoked_ts": datetime.fromisoformat("2015-05-26T08:33:56.826000"), + "passwords": [], + } + ) with self.assertRaises(UserIsRevoked): User.from_dict(data) - def test_user_with_no_primary_mail(self): + def test_user_with_no_primary_mail(self) -> None: mail = "yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "lutol-bafim", - "mailAliases": [{"email": mail, "verified": True}], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "lutol-bafim", + "mailAliases": [{"email": mail, "verified": True}], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(mail, user.mail_addresses.primary.email) - def test_user_with_indirectly_verified_primary_mail(self): + def test_user_with_indirectly_verified_primary_mail(self) -> None: """ If a user has passwords set, the 'mail' attribute will be considered indirectly verified. """ mail = "yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "lutol-bafim", - "mail": mail, - "mailAliases": [{"email": mail, "verified": False}], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "lutol-bafim", + "mail": mail, + "mailAliases": [{"email": mail, "verified": False}], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(mail, user.mail_addresses.primary.email) - def test_user_with_indirectly_verified_primary_mail_and_explicit_primary_mail(self): + def test_user_with_indirectly_verified_primary_mail_and_explicit_primary_mail(self) -> None: """ If a user has manage to verify a mail address in the new style with the same address still set in old style mail property. Do not make old mail address primary if a primary all ready exists. """ old_mail = "yahoo@example.com" new_mail = "not_yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "lutol-bafim", - "mail": old_mail, - "mailAliases": [ - {"email": old_mail, "verified": True, "primary": False}, - {"email": new_mail, "verified": True, "primary": True}, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "lutol-bafim", + "mail": old_mail, + "mailAliases": [ + {"email": old_mail, "verified": True, "primary": False}, + {"email": new_mail, "verified": True, "primary": True}, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(new_mail, user.mail_addresses.primary.email) - def test_user_with_csrf_junk_in_mail_address(self): + def test_user_with_csrf_junk_in_mail_address(self) -> None: """ For a long time, Dashboard leaked CSRF tokens into the mail address dicts. """ mail = "yahoo@example.com" - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "test-test", - "mailAliases": [{"email": mail, "verified": True, "csrf": "6ae1d4e95305b72318a683883e70e3b8e302cd75"}], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), - "credential_id": str(ObjectId()), - "salt": "salt", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "test-test", + "mailAliases": [{"email": mail, "verified": True, "csrf": "6ae1d4e95305b72318a683883e70e3b8e302cd75"}], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-09-04T08:57:07.362000"), + "credential_id": str(ObjectId()), + "salt": "salt", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.mail_addresses.primary self.assertEqual(mail, user.mail_addresses.primary.email) - def test_to_dict(self): + def test_to_dict(self) -> None: """ Test that User objects can be recreated. """ @@ -414,7 +435,7 @@ def test_to_dict(self): d2 = u2.to_dict() self.assertEqual(d1, d2) - def test_modified_ts(self): + def test_modified_ts(self) -> None: """ Test the modified_ts property. """ @@ -428,148 +449,158 @@ def test_modified_ts(self): self.user1.modified_ts = datetime.utcnow() self.assertNotEqual(_time2, self.user1.modified_ts) - def test_two_unverified_non_primary_phones(self): + def test_two_unverified_non_primary_phones(self) -> None: """ Test that the first entry in the `phone' list is chosen as primary when none are verified. """ number1 = "+9112345678" number2 = "+9123456789" - data = { - "_id": ObjectId(), - "displayName": "xxx yyy", - "eduPersonPrincipalName": "pohig-test", - "givenName": "xxx", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number1, - "primary": False, - "verified": False, - }, - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number2, - "primary": False, - "verified": False, - }, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "credential_id": str(ObjectId()), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - "preferredLanguage": "en", - "surname": "yyy", - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "displayName": "xxx yyy", + "eduPersonPrincipalName": "pohig-test", + "givenName": "xxx", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number1, + "primary": False, + "verified": False, + }, + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number2, + "primary": False, + "verified": False, + }, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "credential_id": str(ObjectId()), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + "preferredLanguage": "en", + "surname": "yyy", + } + ) user = User.from_dict(data) self.assertEqual(user.phone_numbers.primary, None) - def test_two_non_primary_phones(self): + def test_two_non_primary_phones(self) -> None: """ Test that the first verified number is chosen as primary, if there is a verified number. """ number1 = "+9112345678" number2 = "+9123456789" - data = { - "_id": ObjectId(), - "displayName": "xxx yyy", - "eduPersonPrincipalName": "pohig-test", - "givenName": "xxx", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number1, - "primary": False, - "verified": False, - }, - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": number2, - "primary": False, - "verified": True, - }, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "credential_id": str(ObjectId()), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - "preferredLanguage": "en", - "surname": "yyy", - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "displayName": "xxx yyy", + "eduPersonPrincipalName": "pohig-test", + "givenName": "xxx", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number1, + "primary": False, + "verified": False, + }, + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": number2, + "primary": False, + "verified": True, + }, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "credential_id": str(ObjectId()), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + "preferredLanguage": "en", + "surname": "yyy", + } + ) user = User.from_dict(data) + assert user.phone_numbers.primary self.assertEqual(user.phone_numbers.primary.number, number2) - def test_primary_non_verified_phone(self): + def test_primary_non_verified_phone(self) -> None: """ Test that if a non verified phone number is primary, due to earlier error, then that primary flag is removed. """ - data = { - "_id": ObjectId(), - "displayName": "xxx yyy", - "eduPersonPrincipalName": "pohig-test", - "givenName": "xxx", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - { - "csrf": "47d42078719b8377db622c3ff85b94840b483c92", - "number": "+9112345678", - "primary": True, - "verified": False, - } - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "credential_id": str(ObjectId()), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - "preferredLanguage": "en", - "surname": "yyy", - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "displayName": "xxx yyy", + "eduPersonPrincipalName": "pohig-test", + "givenName": "xxx", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + { + "csrf": "47d42078719b8377db622c3ff85b94840b483c92", + "number": "+9112345678", + "primary": True, + "verified": False, + } + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "credential_id": str(ObjectId()), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + "preferredLanguage": "en", + "surname": "yyy", + } + ) user = User.from_dict(data) for number in user.phone_numbers.to_list(): self.assertEqual(number.is_primary, False) - def test_primary_non_verified_phone2(self): + def test_primary_non_verified_phone2(self) -> None: """ Test that if a non verified phone number is primary, due to earlier error, then that primary flag is removed. """ - data = { - "_id": ObjectId(), - "eduPersonPrincipalName": "pohig-test", - "mail": "test@gmail.com", - "mailAliases": [{"email": "test@gmail.com", "verified": True}], - "phone": [ - {"number": "+11111111111", "primary": True, "verified": False}, - {"number": "+22222222222", "primary": False, "verified": True}, - ], - "passwords": [ - { - "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), - "id": ObjectId(), - "salt": "$NDNv1H1$foo$32$32$", - "source": "dashboard", - } - ], - } + data = TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "pohig-test", + "mail": "test@gmail.com", + "mailAliases": [{"email": "test@gmail.com", "verified": True}], + "phone": [ + {"number": "+11111111111", "primary": True, "verified": False}, + {"number": "+22222222222", "primary": False, "verified": True}, + ], + "passwords": [ + { + "created_ts": datetime.fromisoformat("2014-06-29T17:52:37.830000"), + "id": ObjectId(), + "salt": "$NDNv1H1$foo$32$32$", + "source": "dashboard", + } + ], + } + ) user = User.from_dict(data) + assert user.phone_numbers.primary self.assertEqual(user.phone_numbers.primary.number, "+22222222222") - def test_user_tou_no_created_ts(self): + def test_user_tou_no_created_ts(self) -> None: """ Basic test for user ToU. """ @@ -587,7 +618,7 @@ def test_user_tou_no_created_ts(self): # attr set to True, and therefore the to_dict method will wipe out the created_ts key self.assertFalse(user.tou.has_accepted("1", reaccept_interval=94608000)) # reaccept_interval seconds (3 years) - def test_user_tou(self): + def test_user_tou(self) -> None: """ Basic test for user ToU. """ @@ -605,7 +636,7 @@ def test_user_tou(self): self.assertTrue(user.tou.has_accepted("1", reaccept_interval=94608000)) # reaccept_interval seconds (3 years) self.assertFalse(user.tou.has_accepted("2", reaccept_interval=94608000)) # reaccept_interval seconds (3 years) - def test_locked_identity_load(self): + def test_locked_identity_load(self) -> None: created_ts = datetime.fromisoformat("2013-09-02T10:23:25") locked_identity = { "created_by": "test", @@ -624,7 +655,7 @@ def test_locked_identity_load(self): assert user.locked_identity.nin.number == "197801012345" assert user.locked_identity.nin.is_verified is True - def test_locked_identity_load_legacy_format(self): + def test_locked_identity_load_legacy_format(self) -> None: created_ts = datetime.fromisoformat("2013-09-02T10:23:25") locked_identity = { "created_by": "test", @@ -642,7 +673,7 @@ def test_locked_identity_load_legacy_format(self): assert user.locked_identity.nin.number == "197801012345" assert user.locked_identity.nin.is_verified is True - def test_locked_identity_set(self): + def test_locked_identity_set(self) -> None: user = User.from_dict(self.data1) locked_nin = NinIdentity( number="197801012345", @@ -658,14 +689,14 @@ def test_locked_identity_set(self): assert user.locked_identity.nin.number == "197801012345" assert user.locked_identity.nin.is_verified is True - def test_locked_identity_set_not_verified(self): + def test_locked_identity_set_not_verified(self) -> None: locked_identity = {"created_by": "test", "identity_type": IdentityType.NIN.value, "number": "197801012345"} user = User.from_dict(self.data1) locked_nin = NinIdentity(number=locked_identity["number"], created_by=locked_identity["created_by"]) with pytest.raises(ValidationError): user.locked_identity.add(locked_nin) - def test_locked_identity_to_dict(self): + def test_locked_identity_to_dict(self) -> None: user = User.from_dict(self.data1) locked_nin = NinIdentity( number="197801012345", @@ -688,7 +719,7 @@ def test_locked_identity_to_dict(self): assert new_user.locked_identity.nin.number == "197801012345" assert new_user.locked_identity.nin.is_verified is True - def test_locked_identity_remove(self): + def test_locked_identity_remove(self) -> None: user = User.from_dict(self.data1) locked_nin = NinIdentity( number="197801012345", @@ -699,7 +730,7 @@ def test_locked_identity_remove(self): with self.assertRaises(EduIDUserDBError): user.locked_identity.remove(locked_nin.key) - def test_orcid(self): + def test_orcid(self) -> None: id_token = { "aud": ["APP_ID"], "auth_time": 1526389879, @@ -727,7 +758,8 @@ def test_orcid(self): user.orcid = orcid_element old_user = User.from_dict(user.to_dict()) - self.assertIsNotNone(old_user.orcid) + assert old_user + assert old_user.orcid self.assertIsInstance(old_user.orcid.created_by, str) self.assertIsInstance(old_user.orcid.created_ts, datetime) self.assertIsInstance(old_user.orcid.id, str) @@ -735,25 +767,26 @@ def test_orcid(self): self.assertIsInstance(old_user.orcid.oidc_authz.id_token, OidcIdToken) new_user = User.from_dict(user.to_dict()) - self.assertIsNotNone(new_user.orcid) + assert new_user + assert new_user.orcid self.assertIsInstance(new_user.orcid.created_by, str) self.assertIsInstance(new_user.orcid.created_ts, datetime) self.assertIsInstance(new_user.orcid.id, str) self.assertIsInstance(new_user.orcid.oidc_authz, OidcAuthorization) self.assertIsInstance(new_user.orcid.oidc_authz.id_token, OidcIdToken) - def test_profiles(self): + def test_profiles(self) -> None: self.assertIsNotNone(self.user1.profiles) self.assertEqual(self.user1.profiles.count, 0) self.assertIsNotNone(self.user2.profiles) self.assertEqual(self.user2.profiles.count, 1) - def test_user_verified_credentials(self): + def test_user_verified_credentials(self) -> None: ver = [x for x in self.user2.credentials.to_list() if x.is_verified] keys = [x.key for x in ver] self.assertEqual(keys, [_keyid("U2F SWAMID AL3" + "foo")]) - def test_user_unverified_credential(self): + def test_user_unverified_credential(self) -> None: cred = [x for x in self.user2.credentials.to_list() if x.is_verified][0] self.assertEqual(cred.proofing_method, CredentialProofingMethod.SWAMID_AL3_MFA) _dict1 = cred.to_dict() @@ -766,63 +799,67 @@ def test_user_unverified_credential(self): self.assertFalse("proofing_method" in _dict2) self.assertFalse("proofing_version" in _dict2) - def test_both_mobile_and_phone(self): + def test_both_mobile_and_phone(self) -> None: """Test user that has both 'mobile' and 'phone'""" phone = [ {"number": "+4673123", "primary": True, "verified": True}, {"created_by": "phone", "number": "+4670999", "primary": False, "verified": False}, ] user = User.from_dict( - data={ - "_id": ObjectId(), - "eduPersonPrincipalName": "test-test", - "passwords": [], - "mobile": [{"mobile": "+4673123", "primary": True, "verified": True}], - "phone": phone, - } + data=TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "test-test", + "passwords": [], + "mobile": [{"mobile": "+4673123", "primary": True, "verified": True}], + "phone": phone, + } + ) ) out = user.to_dict()["phone"] assert phone == out, "The phone objects differ when using both phone and mobile" - def test_both_sn_and_surname(self): + def test_both_sn_and_surname(self) -> None: """Test user that has both 'sn' and 'surname'""" user = User.from_dict( - data={ - "_id": ObjectId(), - "eduPersonPrincipalName": "test-test", - "passwords": [], - "surname": "Right", - "sn": "Wrong", - } + data=TUserDbDocument( + { + "_id": ObjectId(), + "eduPersonPrincipalName": "test-test", + "passwords": [], + "surname": "Right", + "sn": "Wrong", + } + ) ) self.assertEqual("Right", user.to_dict()["surname"]) - def test_terminated_user(self): + def test_terminated_user(self) -> None: data = self.user1.to_dict() data["terminated"] = utc_now() user = User.from_dict(data) assert user.terminated is not None assert isinstance(user.terminated, datetime) is True - def test_terminated_user_false(self): + def test_terminated_user_false(self) -> None: # users can have terminated set to False due to a bug in the past data = self.user1.to_dict() data["terminated"] = False user = User.from_dict(data) assert user.terminated is None - def test_rebuild_user1(self): + def test_rebuild_user1(self) -> None: data = self.user1.to_dict() new_user1 = User.from_dict(data) self.assertEqual(new_user1.eppn, "guvat-nalif") - def test_rebuild_user2(self): + def test_rebuild_user2(self) -> None: data = self.user2.to_dict() new_user2 = User.from_dict(data) self.assertEqual(new_user2.eppn, "birub-gagoz") - def test_mail_addresses_from_dict(self): + def test_mail_addresses_from_dict(self) -> None: """ Test that we get back a correct list of dicts for old-style userdb data. """ @@ -851,7 +888,7 @@ def test_mail_addresses_from_dict(self): assert to_dict_output == mailAliases_list - def test_phone_numbers_from_dict(self): + def test_phone_numbers_from_dict(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -867,7 +904,7 @@ def test_phone_numbers_from_dict(self): to_dict_result = phone_numbers.to_list_of_dicts() assert to_dict_result == phone_list - def test_passwords_from_dict(self): + def test_passwords_from_dict(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -902,7 +939,7 @@ def test_passwords_from_dict(self): assert to_dict_result == expected - def test_phone_numbers(self): + def test_phone_numbers(self) -> None: """ Test that we get back a dict identical to the one we put in for old-style userdb data. """ @@ -918,7 +955,7 @@ def test_phone_numbers(self): assert obtained == expected - def test_user_meta(self): + def test_user_meta(self) -> None: version = ObjectId() _utc_now = utc_now() user_dict = self.user1.to_dict() @@ -938,21 +975,21 @@ def test_user_meta(self): } assert user_dict2["meta"] == expected - def test_user_meta_version(self): + def test_user_meta_version(self) -> None: assert self.user1.meta.is_in_database is False assert self.user1.meta.version is None self.user1.meta.new_version() assert self.user1.meta.is_in_database is False assert isinstance(self.user1.meta.version, ObjectId) is True - def test_user_meta_modified_ts(self): + def test_user_meta_modified_ts(self) -> None: assert self.user1.meta.modified_ts is None # TODO: remove below check when removing user.modified_ts # verify that user.modified_ts is synced with meta.modified_ts self.user1.modified_ts = utc_now() assert self.user1.meta.modified_ts == self.user1.modified_ts - def test_letter_proofing_data_to_list(self): + def test_letter_proofing_data_to_list(self) -> None: letter_proofing = { "created_by": "eduid-idproofing-letter", "created_ts": datetime(2015, 12, 18, 12, 0, 46), @@ -981,14 +1018,14 @@ def test_letter_proofing_data_to_list(self): user = User.from_dict(user_dict) assert user.to_dict()["letter_proofing_data"] == [letter_proofing] - def test_nins_and_identities_on_user(self): + def test_nins_and_identities_on_user(self) -> None: user_dict = UserFixtures().mocked_user_standard.to_dict() assert user_dict["identities"] != [] user_dict = User.from_dict(user_dict).to_dict() assert user_dict.get("nins") is None assert user_dict.get("identities") is not None - def test_empty_nins_list(self): + def test_empty_nins_list(self) -> None: user_dict = UserFixtures().mocked_user_standard.to_dict() del user_dict["identities"] user_dict["nins"] = [] diff --git a/src/eduid/userdb/tests/test_userdb.py b/src/eduid/userdb/tests/test_userdb.py index d38940091..72f593206 100644 --- a/src/eduid/userdb/tests/test_userdb.py +++ b/src/eduid/userdb/tests/test_userdb.py @@ -17,16 +17,18 @@ class TestUserDB(MongoTestCase): - def setUp(self, *args: Any, **kwargs: Any): + def setUp(self, *args: Any, **kwargs: Any) -> None: self.user = UserFixtures().mocked_user_standard super().setUp(am_users=[self.user], **kwargs) - def test_get_user_by_id(self): + def test_get_user_by_id(self) -> None: """Test get_user_by_id""" res = self.amdb.get_user_by_id(self.user.user_id) + assert res assert self.user.user_id == res.user_id res = self.amdb.get_user_by_id(str(self.user.user_id)) + assert res assert self.user.user_id == res.user_id res = self.amdb.get_user_by_id(str(bson.ObjectId())) @@ -35,7 +37,7 @@ def test_get_user_by_id(self): res = self.amdb.get_user_by_id("not-a-valid-object-id") assert res is None - def test_get_user_by_nin(self): + def test_get_user_by_nin(self) -> None: """Test get_user_by_nin""" test_user = self.amdb.get_user_by_id(self.user.user_id) assert test_user is not None @@ -46,25 +48,28 @@ def test_get_user_by_nin(self): assert res is not None assert test_user.given_name == res.given_name - def test_remove_user_by_id(self): + def test_remove_user_by_id(self) -> None: """Test removing a user from the database NOTE: remove_user_by_id() should be moved to SignupUserDb """ test_user = self.amdb.get_user_by_id(self.user.user_id) + assert test_user + assert test_user.identities.nin res = self.amdb.get_users_by_nin(test_user.identities.nin.number) assert normalised_data(res[0].to_dict()) == normalised_data(test_user.to_dict()) self.amdb.remove_user_by_id(test_user.user_id) res = self.amdb.get_users_by_nin(test_user.identities.nin.number) assert res == [] - def test_get_user_by_eppn(self): + def test_get_user_by_eppn(self) -> None: """Test user lookup using eppn""" test_user = self.amdb.get_user_by_id(self.user.user_id) + assert test_user res = self.amdb.get_user_by_eppn(test_user.eppn) assert test_user.user_id == res.user_id - def test_get_user_by_eppn_not_found(self): + def test_get_user_by_eppn_not_found(self) -> None: """Test user lookup using unknown""" with pytest.raises(UserDoesNotExist): self.amdb.get_user_by_eppn("abc123") @@ -73,7 +78,7 @@ def test_get_user_by_eppn_not_found(self): class UserMissingMeta(MongoTestCase): user: User - def setUp(self): + def setUp(self) -> None: # type: ignore[override] self.user = UserFixtures().mocked_user_standard super().setUp(am_users=[self.user]) @@ -90,7 +95,7 @@ def _remove_meta_from_user_in_db(self, user: User) -> None: user_doc.pop("meta") self.amdb._coll.replace_one({"_id": user.user_id}, user_doc) - def test_update_user_new(self): + def test_update_user_new(self) -> None: db_user = self.amdb.get_user_by_id(self.user.user_id) assert db_user is not None logger.debug(f"Loaded user.meta from database:\n{format_dict_for_debug(db_user.meta.dict())}") @@ -98,19 +103,20 @@ def test_update_user_new(self): db_user.given_name = "test" self.amdb.save(user=db_user) - def test_update_user_old(self): + def test_update_user_old(self) -> None: db_user = self.amdb.get_user_by_id(self.user.user_id) + assert db_user db_user.given_name = "test" self.amdb.save(user=db_user) class UpdateUser(MongoTestCase): - def setUp(self, *args: Any, **kwargs: Any): + def setUp(self, *args: Any, **kwargs: Any) -> None: _users = UserFixtures() self.user = _users.mocked_user_standard super().setUp(am_users=[self.user, _users.mocked_user_standard_2], **kwargs) - def test_stale_user_meta_version(self): + def test_stale_user_meta_version(self) -> None: test_user = self.amdb.get_user_by_eppn(self.user.eppn) test_user.given_name = "new_given_name" test_user.meta.new_version() @@ -118,8 +124,9 @@ def test_stale_user_meta_version(self): with self.assertRaises(UserOutOfSync): self.amdb.save(test_user) - def test_ok(self): + def test_ok(self) -> None: test_user = self.amdb.get_user_by_id(self.user.user_id) + assert test_user test_user.given_name = "new_given_name" old_meta_version = test_user.meta.version @@ -129,13 +136,14 @@ def test_ok(self): assert res.success is True db_user = self.amdb.get_user_by_id(test_user.user_id) + assert db_user assert db_user.meta.version != old_meta_version assert db_user.modified_ts != old_modified_ts assert db_user.given_name == "new_given_name" class TestUserDB_mail(MongoTestCase): - def setUp(self, *args: Any, **kwargs: Any): + def setUp(self, *args: Any, **kwargs: Any) -> None: super().setUp(*args, **kwargs) data1: TUserDbDocument = TUserDbDocument( { @@ -165,16 +173,19 @@ def setUp(self, *args: Any, **kwargs: Any): self.amdb.save(self.user1) self.amdb.save(self.user2) - def test_get_user_by_mail(self): + def test_get_user_by_mail(self) -> None: test_user = self.amdb.get_user_by_id(self.user1.user_id) + assert test_user + assert test_user.mail_addresses.primary res = self.amdb.get_user_by_mail(test_user.mail_addresses.primary.email) + assert res assert test_user.user_id == res.user_id - def test_get_user_by_mail_unknown(self): + def test_get_user_by_mail_unknown(self) -> None: """Test searching for unknown e-mail address""" assert self.amdb.get_user_by_mail("abc123@example.edu") is None - def test_get_user_by_mail_multiple(self): + def test_get_user_by_mail_multiple(self) -> None: res = self.amdb.get_users_by_mail("test@gmail.com") ids = [x.user_id for x in res] assert ids == [self.user1.user_id] @@ -185,7 +196,7 @@ def test_get_user_by_mail_multiple(self): class TestUserDB_phone(MongoTestCase): - def setUp(self, *args: Any, **kwargs: Any): + def setUp(self, *args: Any, **kwargs: Any) -> None: super().setUp(*args, **kwargs) data1: TUserDbDocument = TUserDbDocument( { @@ -218,24 +229,28 @@ def setUp(self, *args: Any, **kwargs: Any): self.amdb.save(self.user1) self.amdb.save(self.user2) - def test_get_user_by_phone(self): + def test_get_user_by_phone(self) -> None: test_user = self.amdb.get_user_by_id(self.user1.user_id) + assert test_user + assert test_user.phone_numbers.primary res = self.amdb.get_user_by_phone(test_user.phone_numbers.primary.number) + assert res assert res.user_id == test_user.user_id res = self.amdb.get_user_by_phone("+22222222222") + assert res assert res.user_id == test_user.user_id assert self.amdb.get_user_by_phone("+33333333333") is None - res = self.amdb.get_users_by_phone("+33333333333", include_unconfirmed=True) - assert [x.user_id for x in res] == [self.user2.user_id] + res_list = self.amdb.get_users_by_phone("+33333333333", include_unconfirmed=True) + assert [x.user_id for x in res_list] == [self.user2.user_id] - def test_get_user_by_phone_unknown(self): + def test_get_user_by_phone_unknown(self) -> None: """Test searching for unknown e-phone address""" assert self.amdb.get_user_by_phone("abc123@example.edu") is None - def test_get_user_by_phone_multiple(self): + def test_get_user_by_phone_multiple(self) -> None: res = self.amdb.get_users_by_phone("+11111111111") ids = [x.user_id for x in res] assert ids == [self.user1.user_id] @@ -247,7 +262,7 @@ def test_get_user_by_phone_multiple(self): class TestUserDB_nin(MongoTestCase): # TODO: Keep for a while to make sure the conversion to identities work as expected - def setUp(self, *args: Any, **kwargs: Any): + def setUp(self, *args: Any, **kwargs: Any) -> None: super().setUp(*args, **kwargs) data1: TUserDbDocument = TUserDbDocument( { @@ -291,37 +306,44 @@ def setUp(self, *args: Any, **kwargs: Any): self.amdb.save(self.user2) self.amdb.save(self.user3) - def test_get_user_by_nin(self): + def test_get_user_by_nin(self) -> None: test_user = self.amdb.get_user_by_id(self.user1.user_id) + assert test_user + assert test_user.identities.nin res = self.amdb.get_user_by_nin(test_user.identities.nin.number) + assert res assert res.user_id == test_user.user_id, "alpha" res = self.amdb.get_user_by_nin("11111111111") + assert res assert res.user_id == test_user.user_id, "beta" res = self.amdb.get_user_by_nin("22222222222") + assert res assert res.user_id == self.user2.user_id, "gamma" assert self.amdb.get_user_by_nin("33333333333") is None, "delta" - res = self.amdb.get_users_by_nin("33333333333", include_unconfirmed=True) - assert [x.user_id for x in res] == [self.user3.user_id], "epsilon" + res_list = self.amdb.get_users_by_nin("33333333333", include_unconfirmed=True) + assert [x.user_id for x in res_list] == [self.user3.user_id], "epsilon" - def test_get_user_by_nin_unknown(self): + def test_get_user_by_nin_unknown(self) -> None: """Test searching for unknown e-nin address""" assert self.amdb.get_user_by_nin("77777777777") is None - def test_get_user_by_nin_multiple(self): + def test_get_user_by_nin_multiple(self) -> None: # create another user with nin 33333333333, this one verified - data4 = { - "_id": bson.ObjectId(), - "eduPersonPrincipalName": "nin-test4", - "mail": "anka@example.com", - "nins": [ - {"number": "33333333333", "primary": True, "verified": True}, - ], - "passwords": [signup_password.to_dict()], - } + data4 = TUserDbDocument( + { + "_id": bson.ObjectId(), + "eduPersonPrincipalName": "nin-test4", + "mail": "anka@example.com", + "nins": [ + {"number": "33333333333", "primary": True, "verified": True}, + ], + "passwords": [signup_password.to_dict()], + } + ) user4 = User.from_dict(data4) self.amdb.save(user4) diff --git a/src/eduid/userdb/tests/test_webauthn.py b/src/eduid/userdb/tests/test_webauthn.py index 3b790be4e..58513a11a 100644 --- a/src/eduid/userdb/tests/test_webauthn.py +++ b/src/eduid/userdb/tests/test_webauthn.py @@ -2,7 +2,7 @@ from hashlib import sha256 from unittest import TestCase -from eduid.userdb.credentials import CredentialList, CredentialProofingMethod +from eduid.userdb.credentials import CredentialList, CredentialProofingMethod, Webauthn __author__ = "lundberg" @@ -28,17 +28,19 @@ def _keyid(key: dict[str, str]): class TestWebauthn(TestCase): - def setUp(self): + def setUp(self) -> None: self.empty = CredentialList() self.one = CredentialList.from_list_of_dicts([_one_dict]) self.two = CredentialList.from_list_of_dicts([_one_dict, _two_dict]) self.three = CredentialList.from_list_of_dicts([_one_dict, _two_dict, _three_dict]) - def test_key(self): + def test_key(self) -> None: """ Test that the 'key' property (used by CredentialList) works for the credential. """ this = self.one.find(_keyid(_one_dict)) + assert this + assert isinstance(this, Webauthn) self.assertEqual( this.key, _keyid( @@ -49,7 +51,7 @@ def test_key(self): ), ) - def test_parse_cycle(self): + def test_parse_cycle(self) -> None: """ Tests that we output something we parsed back into the same thing we output. """ @@ -57,17 +59,20 @@ def test_parse_cycle(self): this_dict = this.to_list_of_dicts() self.assertEqual(CredentialList.from_list_of_dicts(this_dict).to_list_of_dicts(), this.to_list_of_dicts()) - def test_created_by(self): + def test_created_by(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.created_by = "unit test" self.assertEqual(this.created_by, "unit test") - def test_created_ts(self): + def test_created_ts(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this self.assertIsInstance(this.created_ts, datetime.datetime) - def test_proofing_method(self): + def test_proofing_method(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_method = CredentialProofingMethod.SWAMID_AL2_MFA_HI self.assertEqual(this.proofing_method, CredentialProofingMethod.SWAMID_AL2_MFA_HI) this.proofing_method = CredentialProofingMethod.SWAMID_AL3_MFA @@ -75,8 +80,9 @@ def test_proofing_method(self): this.proofing_method = None self.assertEqual(this.proofing_method, None) - def test_proofing_version(self): + def test_proofing_version(self) -> None: this = self.three.find(_keyid(_three_dict)) + assert this this.proofing_version = "TEST" self.assertEqual(this.proofing_version, "TEST") this.proofing_version = "TEST2" diff --git a/src/eduid/userdb/userdb.py b/src/eduid/userdb/userdb.py index de5d8fea6..1d80375d8 100644 --- a/src/eduid/userdb/userdb.py +++ b/src/eduid/userdb/userdb.py @@ -182,7 +182,7 @@ def get_users_by_nin(self, nin: str, include_unconfirmed: bool = False) -> list[ def get_users_by_identity( self, identity_type: IdentityType, key: str, value: str, include_unconfirmed: bool = False - ): + ) -> list[UserVar]: match = {"identity_type": identity_type.value, key: value, "verified": True} if include_unconfirmed: del match["verified"] diff --git a/src/eduid/userdb/util.py b/src/eduid/userdb/util.py index a319a2983..cd61d4292 100644 --- a/src/eduid/userdb/util.py +++ b/src/eduid/userdb/util.py @@ -14,13 +14,13 @@ class UTC(datetime.tzinfo): """UTC""" - def utcoffset(self, dt: datetime.datetime | None): + def utcoffset(self, dt: datetime.datetime | None) -> datetime.timedelta: return datetime.timedelta(0) - def tzname(self, dt: datetime.datetime | None): + def tzname(self, dt: datetime.datetime | None) -> str: return "UTC" - def dst(self, dt: datetime.datetime | None): + def dst(self, dt: datetime.datetime | None) -> datetime.timedelta: return datetime.timedelta(0) diff --git a/src/eduid/vccs/client/tests/test_client.py b/src/eduid/vccs/client/tests/test_client.py index d85154618..712f691c2 100644 --- a/src/eduid/vccs/client/tests/test_client.py +++ b/src/eduid/vccs/client/tests/test_client.py @@ -42,7 +42,7 @@ def _get_random_bytes(self, num_bytes: int): class TestVCCSClient(unittest.TestCase): - def test_password_factor(self): + def test_password_factor(self) -> None: """ Test creating a VCCSPasswordFactor instance. """ @@ -57,7 +57,7 @@ def test_password_factor(self): }, ) - def test_utf8_password_factor(self): + def test_utf8_password_factor(self) -> None: """ Test creating a VCCSPasswordFactor instance. """ @@ -72,22 +72,22 @@ def test_utf8_password_factor(self): }, ) - def test_OATH_factor_auth(self): + def test_OATH_factor_auth(self) -> None: """ Test creating a VCCSOathFactor instance. """ aead = "aa" * 20 - o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, user_code="123456") + o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, user_code=123456) self.assertEqual( o.to_dict("auth"), { "type": "oath-hotp", "credential_id": 4712, - "user_code": "123456", + "user_code": 123456, }, ) - def test_OATH_factor_add(self): + def test_OATH_factor_add(self) -> None: """ Test creating a VCCSOathFactor instance for an add_creds request. """ @@ -106,24 +106,24 @@ def test_OATH_factor_add(self): }, ) - def test_missing_parts_of_OATH_factor(self): + def test_missing_parts_of_OATH_factor(self) -> None: """ Test creating a VCCSOathFactor instance with missing parts. """ aead = "aa" * 20 - o = VCCSOathFactor("oath-hotp", 4712, user_code="123456") + o = VCCSOathFactor("oath-hotp", 4712, user_code=123456) # missing AEAD with self.assertRaises(ValueError): o.to_dict("add_creds") - o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, key_handle=0x1234, user_code="123456") + o = VCCSOathFactor("oath-hotp", 4712, nonce="010203040506", aead=aead, key_handle=0x1234, user_code=123456) # with AEAD o should be OK self.assertEqual(type(o.to_dict("add_creds")), dict) # unknown to_dict 'action' should raise with self.assertRaises(ValueError): o.to_dict("bad_action") - def test_authenticate1(self): + def test_authenticate1(self) -> None: """ Test parsing of successful authentication response. """ @@ -137,7 +137,7 @@ def test_authenticate1(self): f = VCCSPasswordFactor("password", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertTrue(c.authenticate("ft@example.net", [f])) - def test_authenticate1_utf8(self): + def test_authenticate1_utf8(self) -> None: """ Test parsing of successful authentication response with a password in UTF-8. """ @@ -151,7 +151,7 @@ def test_authenticate1_utf8(self): f = VCCSPasswordFactor("passwordåäöхэж", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertTrue(c.authenticate("ft@example.net", [f])) - def test_authenticate2(self): + def test_authenticate2(self) -> None: """ Test unknown response version """ @@ -165,7 +165,7 @@ def test_authenticate2(self): with self.assertRaises(AssertionError): c.authenticate("ft@example.net", [f]) - def test_authenticate2_utf8(self): + def test_authenticate2_utf8(self) -> None: """ Test unknown response version with a password in UTF-8. """ @@ -179,7 +179,7 @@ def test_authenticate2_utf8(self): with self.assertRaises(AssertionError): c.authenticate("ft@example.net", [f]) - def test_add_creds1(self): + def test_add_creds1(self) -> None: """ Test parsing of successful add_creds response. """ @@ -207,7 +207,7 @@ def test_add_creds1(self): } self.assertEqual(expected, values) - def test_add_creds1_utf8(self): + def test_add_creds1_utf8(self) -> None: """ Test parsing of successful add_creds response with a password in UTF-8. """ @@ -235,7 +235,7 @@ def test_add_creds1_utf8(self): } self.assertEqual(expected, values) - def test_add_creds2(self): + def test_add_creds2(self) -> None: """ Test parsing of unsuccessful add_creds response. """ @@ -249,7 +249,7 @@ def test_add_creds2(self): f = VCCSPasswordFactor("password", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertFalse(c.add_credentials("ft@example.net", [f])) - def test_add_creds2_utf8(self): + def test_add_creds2_utf8(self) -> None: """ Test parsing of unsuccessful add_creds response with a password in UTF-8. """ @@ -263,7 +263,7 @@ def test_add_creds2_utf8(self): f = VCCSPasswordFactor("passwordåäöхэж", "4711", "$NDNv1H1$aaaaaaaaaaaaaaaa$12$32$") self.assertFalse(c.add_credentials("ft@example.net", [f])) - def test_revoke_creds1(self): + def test_revoke_creds1(self) -> None: """ Test parsing of unsuccessful revoke_creds response. """ @@ -277,24 +277,24 @@ def test_revoke_creds1(self): r = VCCSRevokeFactor("4712", "testing revoke", "foobar") self.assertFalse(c.revoke_credentials("ft@example.net", [r])) - def test_revoke_creds2(self): + def test_revoke_creds2(self) -> None: """ Test revocation reason/reference bad types. """ - FakeVCCSClient(None) + FakeVCCSClient("Fake response not used in test") with self.assertRaises(TypeError): - VCCSRevokeFactor(4712, 1234, "foobar") + VCCSRevokeFactor(4712, 1234, "foobar") # type: ignore[arg-type] with self.assertRaises(TypeError): - VCCSRevokeFactor(4712, "foobar", 2345) + VCCSRevokeFactor(4712, "foobar", 2345) # type: ignore[arg-type] - def test_unknown_salt_version(self): + def test_unknown_salt_version(self) -> None: """Test unknown salt version""" with self.assertRaises(ValueError): VCCSPasswordFactor("anything", "4711", "$NDNvFOO$aaaaaaaaaaaaaaaa$12$32$") - def test_generate_salt1(self): + def test_generate_salt1(self) -> None: """Test salt generation.""" f = VCCSPasswordFactor("anything", "4711") self.assertEqual(len(f.salt), 80) @@ -303,7 +303,7 @@ def test_generate_salt1(self): self.assertEqual(rounds, 32) self.assertEqual(len(random), length) - def test_generate_salt2(self): + def test_generate_salt2(self) -> None: """Test salt generation with fake RNG.""" f = FakeVCCSPasswordFactor("anything", "4711") diff --git a/src/eduid/vccs/server/endpoints/misc.py b/src/eduid/vccs/server/endpoints/misc.py index 3345c5836..210fa0e15 100644 --- a/src/eduid/vccs/server/endpoints/misc.py +++ b/src/eduid/vccs/server/endpoints/misc.py @@ -43,6 +43,6 @@ class HMACResponse(BaseModel): @misc_router.get("/hmac/{keyhandle}/{data}", response_model=HMACResponse) -async def hmac(request: Request, keyhandle: int, data: bytes): +async def hmac(request: Request, keyhandle: int, data: bytes) -> HMACResponse: hmac = await request.app.state.hasher.hmac_sha1(key_handle=keyhandle, data=data) return HMACResponse(keyhandle=keyhandle, hmac=hmac.hex()) diff --git a/src/eduid/vccs/server/hasher.py b/src/eduid/vccs/server/hasher.py index c87161c4d..f93770eb3 100644 --- a/src/eduid/vccs/server/hasher.py +++ b/src/eduid/vccs/server/hasher.py @@ -6,7 +6,7 @@ from binascii import unhexlify from collections.abc import Mapping from hashlib import sha1 -from typing import Any +from typing import Any, Literal import pyhsm import yaml @@ -20,10 +20,10 @@ class NoOpLock: def __init__(self): pass - async def acquire(self): + async def acquire(self) -> None: pass - async def release(self): + async def release(self) -> None: pass @@ -37,23 +37,23 @@ def unlock(self, password: str) -> None: def info(self) -> Any: raise NotImplementedError("Subclass should implement info") - async def hmac_sha1(self, key_handle: int | None, data: bytes): + async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: raise NotImplementedError("Subclass should implement safe_hmac_sha1") - def unsafe_hmac_sha1(self, key_handle: int | None, _data: bytes): + def unsafe_hmac_sha1(self, key_handle: int | None, _data: bytes) -> bytes: raise NotImplementedError("Subclass should implement hmac_sha1") - def load_temp_key(self, nonce: str, _key_handle: int, _aead: bytes): + def load_temp_key(self, nonce: str, _key_handle: int, _aead: bytes) -> bool: raise NotImplementedError("Subclass should implement load_temp_key") - async def safe_random(self, byte_count: int): + async def safe_random(self, byte_count: int) -> bytes: raise NotImplementedError("Subclass should implement safe_random") - async def lock_acquire(self): + async def lock_acquire(self) -> Literal[True] | None: return await self.lock.acquire() - async def lock_release(self): - return self.lock.release() + def lock_release(self) -> None: + self.lock.release() class VCCSYHSMHasher(VCCSHasher): @@ -78,14 +78,14 @@ async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: try: return self.unsafe_hmac_sha1(key_handle, data) finally: - await self.lock_release() + self.lock_release() def unsafe_hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: if key_handle is None: key_handle = pyhsm.defines.YSM_TEMP_KEY_HANDLE return self._yhsm.hmac_sha1(key_handle, data).get_hash() - def load_temp_key(self, nonce: str, key_handle: int, aead: bytes): + def load_temp_key(self, nonce: str, key_handle: int, aead: bytes) -> bool: return self._yhsm.load_temp_key(nonce, key_handle, aead) async def safe_random(self, byte_count: int) -> bytes: @@ -101,7 +101,7 @@ async def safe_random(self, byte_count: int) -> bytes: xored = bytes([a ^ b for (a, b) in zip(from_hsm, from_os)]) return xored finally: - await self.lock_release() + self.lock_release() class VCCSSoftHasher(VCCSHasher): @@ -135,7 +135,7 @@ async def hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: try: return self.unsafe_hmac_sha1(key_handle, data) finally: - await self.lock_release() + self.lock_release() def unsafe_hmac_sha1(self, key_handle: int | None, data: bytes) -> bytes: if key_handle is None: @@ -158,7 +158,9 @@ async def safe_random(self, byte_count: int) -> bytes: return os.urandom(byte_count) -def hasher_from_string(name: str, lock: Lock | NoOpLock | None = None, debug: bool = False): +def hasher_from_string( + name: str, lock: Lock | NoOpLock | None = None, debug: bool = False +) -> VCCSSoftHasher | VCCSYHSMHasher: """ Create a hasher instance from a name. Name can currently only be a name of a YubiHSM device, such as '/dev/ttyACM0'. diff --git a/src/eduid/vccs/server/log.py b/src/eduid/vccs/server/log.py index dce4acea2..ba96a5c14 100644 --- a/src/eduid/vccs/server/log.py +++ b/src/eduid/vccs/server/log.py @@ -1,11 +1,12 @@ import logging import sys +from loguru import Logger from loguru import logger as loguru_logger class InterceptHandler(logging.Handler): - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: # Get corresponding Loguru level if it exists level: str | int try: @@ -22,7 +23,7 @@ def emit(self, record: logging.LogRecord): loguru_logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) -def init_logging(): +def init_logging() -> Logger: loguru_logger.remove() fmt = ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <7} | {module: <11}:" diff --git a/src/eduid/vccs/server/password.py b/src/eduid/vccs/server/password.py index bbc8c20b0..c1ece43dd 100644 --- a/src/eduid/vccs/server/password.py +++ b/src/eduid/vccs/server/password.py @@ -10,7 +10,7 @@ async def authenticate_password( cred: PasswordCredential, factor: RequestFactor, user_id: str, hasher: VCCSYHSMHasher, kdf: NDNKDF -): +) -> bool: res = False H2 = await calculate_cred_hash(user_id=user_id, H1=factor.H1, cred=cred, hasher=hasher, kdf=kdf) # XXX need to log successful login in credential_store to be able to ban diff --git a/src/eduid/vccs/server/run.py b/src/eduid/vccs/server/run.py index c1bc04273..6fa32a892 100644 --- a/src/eduid/vccs/server/run.py +++ b/src/eduid/vccs/server/run.py @@ -51,7 +51,7 @@ def __init__(self, test_config: Mapping[str, Any] | None = None) -> None: @app.on_event("startup") -async def startup_event(): +async def startup_event() -> None: """ Uvicorn mucks with the logging config on startup, particularly the access log. Rein it in. """ @@ -75,7 +75,7 @@ async def startup_event(): @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): +async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: request.app.logger.warning(f"Failed parsing request: {exc}") return JSONResponse({"errors": exc.errors()}, status_code=HTTP_422_UNPROCESSABLE_ENTITY) diff --git a/src/eduid/vccs/server/tests/test_db.py b/src/eduid/vccs/server/tests/test_db.py index bd7493863..120cbeefe 100644 --- a/src/eduid/vccs/server/tests/test_db.py +++ b/src/eduid/vccs/server/tests/test_db.py @@ -6,7 +6,7 @@ class TestCredential(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.data = { "_id": ObjectId("54042b7a9b3f2299bb9d5546"), "credential": { @@ -24,11 +24,11 @@ def setUp(self): "revision": 1, } - def test_from_dict(self): + def test_from_dict(self) -> None: cred = PasswordCredential.from_dict(self.data) assert cred.key_handle == 8192 - def test_to_dict_from_dict(self): + def test_to_dict_from_dict(self) -> None: cred1 = PasswordCredential.from_dict(self.data) cred2 = PasswordCredential.from_dict(cred1.to_dict()) assert cred1.to_dict() == cred2.to_dict() diff --git a/src/eduid/webapp/authn/tests/test_authn.py b/src/eduid/webapp/authn/tests/test_authn.py index 23af76efb..d9bca97e5 100644 --- a/src/eduid/webapp/authn/tests/test_authn.py +++ b/src/eduid/webapp/authn/tests/test_authn.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from flask import Blueprint +from flask import Blueprint, Response from saml2.s_utils import deflate_and_base64_encode from werkzeug.exceptions import NotFound from werkzeug.http import dump_cookie @@ -251,22 +251,24 @@ class AuthnAPITestCase(AuthnAPITestBase): app: AuthnApp - def setUp(self, **kwargs: Any): # type: ignore[override] + def setUp(self, **kwargs: Any) -> None: # type: ignore[override] super().setUp(users=["hubba-bubba", "hubba-fooo"], **kwargs) - def test_login_authn(self): + def test_login_authn(self) -> None: self.authn("/authenticate", FrontendAction.LOGIN) - def test_chpass_authn(self): + def test_chpass_authn(self) -> None: self.authn("/authenticate", FrontendAction.CHANGE_PW_AUTHN) - def test_terminate_authn(self): + def test_terminate_authn(self) -> None: self.authn("/authenticate", FrontendAction.TERMINATE_ACCOUNT_AUTHN) - def test_login_assertion_consumer_service(self): - for accr in EduidAuthnContextClass: - if accr == EduidAuthnContextClass.NOT_IMPLEMENTED: + def test_login_assertion_consumer_service(self) -> None: + for context_class in EduidAuthnContextClass: + if context_class == EduidAuthnContextClass.NOT_IMPLEMENTED: accr = None + else: + accr = context_class eppn = "hubba-bubba" res = self.acs("/authenticate", eppn, frontend_action=FrontendAction.LOGIN, accr=accr) assert res.session.common.eppn == "hubba-bubba" @@ -275,7 +277,7 @@ def test_login_assertion_consumer_service(self): if accr: assert authn.asserted_authn_ctx == accr.value - def test_assertion_consumer_service(self): + def test_assertion_consumer_service(self) -> None: actions = [FrontendAction.LOGIN, FrontendAction.CHANGE_PW_AUTHN, FrontendAction.TERMINATE_ACCOUNT_AUTHN] for action in actions: res = self.acs("/authenticate", eppn=self.test_user.eppn, frontend_action=action) @@ -287,7 +289,7 @@ def test_assertion_consumer_service(self): age = utc_now() - authn.authn_instant assert 10 < age.total_seconds() < 15 - def test_frontend_state(self): + def test_frontend_state(self) -> None: eppn = "hubba-bubba" self.acs("/authenticate", eppn, FrontendAction.REMOVE_SECURITY_KEY_AUTHN, frontend_state="key_id_to_remove") @@ -343,12 +345,12 @@ def load_app(self, test_config: Mapping[str, Any]) -> AuthnTestApp: config = load_config(typ=AuthnConfig, app_name="testing", ns="webapp", test_config=test_config) return AuthnTestApp(config) - def test_no_cookie(self): + def test_no_cookie(self) -> None: with self.app.test_client() as c: resp = c.get("/") self.assertEqual(resp.status_code, 401) - def test_cookie(self): + def test_cookie(self) -> None: sessid = "fb1f42420b0109020203325d750185673df252de388932a3957f522a6c43aa47" self.redis_instance.conn.set(sessid, json.dumps({"v1": {"id": "0"}})) @@ -361,7 +363,7 @@ class NoAuthnAPITestCase(EduidAPITestCase): app: AuthnTestApp - def setUp(self): + def setUp(self) -> None: # type: ignore[override] super().setUp() test_views = Blueprint("testing", __name__) @@ -404,17 +406,17 @@ def load_app(self, test_config: Mapping[str, Any]) -> AuthnTestApp: config = load_config(typ=AuthnConfig, app_name="testing", ns="webapp", test_config=test_config) return AuthnTestApp(config) - def test_no_authn(self): + def test_no_authn(self) -> None: with self.app.test_client() as c: resp = c.get("/test") self.assertEqual(resp.status_code, 200) - def test_authn(self): + def test_authn(self) -> None: with self.app.test_client() as c: resp = c.get("/test2") self.assertEqual(resp.status_code, 401) - def test_no_authn_util(self): + def test_no_authn_util(self) -> None: no_authn_urls_before = [path for path in self.app.conf.no_authn_urls] no_authn_path = "/test3" no_authn_views(self.app.conf, [no_authn_path]) @@ -426,33 +428,35 @@ def test_no_authn_util(self): class LogoutRequestTests(AuthnAPITestBase): - def test_metadataview(self): + def test_metadataview(self) -> None: with self.app.test_client() as c: response = c.get("/saml2-metadata") self.assertEqual(response.status, "200 OK") - def test_logout_nologgedin(self): + def test_logout_nologgedin(self) -> None: eppn = "hubba-bubba" with self.app.test_request_context("/logout", method="GET"): # eppn is set in the IdP session.common.eppn = eppn response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn(self.app.conf.saml2_logout_redirect_url, response.headers["Location"]) - def test_logout_loggedin(self): + def test_logout_loggedin(self) -> None: res = self.acs(url="/authenticate", eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) cookie = self.dump_session_cookie(res.session.meta.cookie_val) with self.app.test_request_context("/logout", method="GET", headers={"Cookie": cookie}): response = self.app.dispatch_request() + assert isinstance(response, Response) logger.debug(f"Test called /logout, response {response}") self.assertEqual(response.status, "302 FOUND") self.assertIn( "https://idp.example.com/simplesaml/saml2/idp/SingleLogoutService.php", response.headers["location"] ) - def test_logout_service_startingSP(self): + def test_logout_service_startingSP(self) -> None: session_id = self.start_authenticate(eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) cookie = self.dump_session_cookie(session_id) @@ -466,11 +470,12 @@ def test_logout_service_startingSP(self): }, ): response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn("testing-relay-state", response.location) - def test_logout_service_startingSP_already_logout(self): + def test_logout_service_startingSP_already_logout(self) -> None: session_id = self.start_authenticate(eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) with self.app.test_request_context( @@ -482,11 +487,12 @@ def test_logout_service_startingSP_already_logout(self): }, ): response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn("testing-relay-state", response.location) - def test_logout_service_startingIDP(self): + def test_logout_service_startingIDP(self) -> None: res = self.acs("/authenticate", eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) cookie = self.dump_session_cookie(res.session.meta.cookie_val) @@ -500,6 +506,7 @@ def test_logout_service_startingIDP(self): }, ): response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") assert ( @@ -507,7 +514,7 @@ def test_logout_service_startingIDP(self): in response.location ) - def test_logout_service_startingIDP_no_subject_id(self): + def test_logout_service_startingIDP_no_subject_id(self) -> None: eppn = "hubba-bubba" res = self.acs("/authenticate", eppn=self.test_user.eppn, frontend_action=FrontendAction.LOGIN) session_id = res.session.meta.cookie_val @@ -540,6 +547,7 @@ def test_logout_service_startingIDP_no_subject_id(self): session.authn.name_id = None session.persist() # Explicit session.persist is needed when working within a test_request_context response = self.app.dispatch_request() + assert isinstance(response, Response) self.assertEqual(response.status, "302 FOUND") self.assertIn("testing-relay-state", response.location) diff --git a/src/eduid/webapp/bankid/tests/test_app.py b/src/eduid/webapp/bankid/tests/test_app.py index a4a7f7981..5aed3da63 100644 --- a/src/eduid/webapp/bankid/tests/test_app.py +++ b/src/eduid/webapp/bankid/tests/test_app.py @@ -463,7 +463,7 @@ def _call_endpoint_and_saml_acs( method=method, ) - def test_authenticate(self): + def test_authenticate(self) -> None: response = self.browser.get("/") self.assertEqual(response.status_code, 401) with self.session_cookie(self.browser, self.test_user.eppn) as browser: @@ -471,7 +471,7 @@ def test_authenticate(self): self._check_success_response(response, type_="GET_BANKID_SUCCESS") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_u2f_token_verify(self, mock_request_user_sync: MagicMock): + def test_u2f_token_verify(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_user.eppn @@ -491,7 +491,7 @@ def test_u2f_token_verify(self, mock_request_user_sync: MagicMock): self._verify_user_parameters(eppn, token_verified=True, num_proofings=1) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_webauthn_token_verify(self, mock_request_user_sync: MagicMock): + def test_webauthn_token_verify(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_user.eppn @@ -511,7 +511,7 @@ def test_webauthn_token_verify(self, mock_request_user_sync: MagicMock): self._verify_user_parameters(eppn, token_verified=True, num_proofings=1) - def test_mfa_token_verify_wrong_verified_nin(self): + def test_mfa_token_verify_wrong_verified_nin(self) -> None: eppn = self.test_user.eppn nin = self.test_user_wrong_nin credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -532,7 +532,7 @@ def test_mfa_token_verify_wrong_verified_nin(self): self._verify_user_parameters(eppn, identity=nin, identity_present=False) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_token_verify_no_verified_nin(self, mock_request_user_sync: MagicMock): + def test_mfa_token_verify_no_verified_nin(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -556,7 +556,7 @@ def test_mfa_token_verify_no_verified_nin(self, mock_request_user_sync: MagicMoc eppn, token_verified=True, num_proofings=2, identity_present=True, identity=nin, identity_verified=True ) - def test_mfa_token_verify_no_mfa_login(self): + def test_mfa_token_verify_no_mfa_login(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -581,7 +581,7 @@ def test_mfa_token_verify_no_mfa_login(self): ) self._verify_user_parameters(eppn) - def test_mfa_token_verify_no_mfa_token_in_session(self): + def test_mfa_token_verify_no_mfa_token_in_session(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "webauthn") @@ -600,7 +600,7 @@ def test_mfa_token_verify_no_mfa_token_in_session(self): self._verify_user_parameters(eppn) - def test_mfa_token_verify_aborted_auth(self): + def test_mfa_token_verify_aborted_auth(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -619,7 +619,7 @@ def test_mfa_token_verify_aborted_auth(self): self._verify_user_parameters(eppn) - def test_mfa_token_verify_cancel_auth(self): + def test_mfa_token_verify_cancel_auth(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "webauthn") @@ -640,7 +640,7 @@ def test_mfa_token_verify_cancel_auth(self): self._verify_user_parameters(eppn) - def test_mfa_token_verify_auth_fail(self): + def test_mfa_token_verify_auth_fail(self) -> None: eppn = self.test_user.eppn credential = self.add_security_key_to_user(eppn, "test", "u2f") @@ -663,7 +663,7 @@ def test_mfa_token_verify_auth_fail(self): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_webauthn_token_verify_backdoor(self, mock_request_user_sync: MagicMock): + def test_webauthn_token_verify_backdoor(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -688,7 +688,7 @@ def test_webauthn_token_verify_backdoor(self, mock_request_user_sync: MagicMock) self._verify_user_parameters(eppn, identity=nin, identity_verified=True, token_verified=True, num_proofings=2) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify(self, mock_request_user_sync: MagicMock): + def test_nin_verify(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -719,7 +719,7 @@ def test_nin_verify(self, mock_request_user_sync: MagicMock): assert doc["surname"] == "Älm" @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_login(self, mock_request_user_sync: MagicMock): + def test_mfa_login(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_user.eppn @@ -735,7 +735,7 @@ def test_mfa_login(self, mock_request_user_sync: MagicMock): self._verify_user_parameters(eppn, num_mfa_tokens=0, identity_verified=True, num_proofings=0) - def test_mfa_login_no_nin(self): + def test_mfa_login_no_nin(self) -> None: eppn = self.test_unverified_user_eppn self._verify_user_parameters(eppn, num_mfa_tokens=0, identity_verified=False, token_verified=False) @@ -751,7 +751,7 @@ def test_mfa_login_no_nin(self): self._verify_user_parameters(eppn, num_mfa_tokens=0, identity_verified=False, num_proofings=0) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_login_unverified_nin(self, mock_request_user_sync: MagicMock): + def test_mfa_login_unverified_nin(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -777,7 +777,7 @@ def test_mfa_login_unverified_nin(self, mock_request_user_sync: MagicMock): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_login_backdoor(self, mock_request_user_sync: MagicMock): + def test_mfa_login_backdoor(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -805,7 +805,7 @@ def test_mfa_login_backdoor(self, mock_request_user_sync: MagicMock): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify_backdoor(self, mock_request_user_sync: Any): + def test_nin_verify_backdoor(self, mock_request_user_sync: Any) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -828,7 +828,7 @@ def test_nin_verify_backdoor(self, mock_request_user_sync: Any): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify_no_backdoor_in_pro(self, mock_request_user_sync: MagicMock): + def test_nin_verify_no_backdoor_in_pro(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -856,7 +856,7 @@ def test_nin_verify_no_backdoor_in_pro(self, mock_request_user_sync: MagicMock): @unittest.skip("No support for magic cookie yet") @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_nin_verify_no_backdoor_misconfigured(self, mock_request_user_sync: MagicMock): + def test_nin_verify_no_backdoor_misconfigured(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync eppn = self.test_unverified_user_eppn @@ -883,7 +883,7 @@ def test_nin_verify_no_backdoor_misconfigured(self, mock_request_user_sync: Magi eppn, identity=self.test_user_nin, num_mfa_tokens=0, num_proofings=1, identity_verified=True ) - def test_nin_verify_already_verified(self): + def test_nin_verify_already_verified(self) -> None: # Verify that the test user has a verified NIN in the database already eppn = self.test_user.eppn nin = self.test_user_nin @@ -902,7 +902,7 @@ def test_nin_verify_already_verified(self): ) @patch("eduid.common.rpc.am_relay.AmRelay.request_user_sync") - def test_mfa_authentication_verified_user(self, mock_request_user_sync: MagicMock): + def test_mfa_authentication_verified_user(self, mock_request_user_sync: MagicMock) -> None: mock_request_user_sync.side_effect = self.request_user_sync user = self.app.central_userdb.get_user_by_eppn(self.test_user.eppn) @@ -927,7 +927,7 @@ def test_mfa_authentication_verified_user(self, mock_request_user_sync: MagicMoc cred = _creds[0] assert cred.level in self.app.conf.bankid_required_loa - def test_mfa_authentication_too_old_authn_instant(self): + def test_mfa_authentication_too_old_authn_instant(self) -> None: self.reauthn( endpoint="/mfa-authenticate", frontend_action=FrontendAction.LOGIN_MFA_AUTHN, @@ -936,7 +936,7 @@ def test_mfa_authentication_too_old_authn_instant(self): expect_error=True, ) - def test_mfa_authentication_wrong_nin(self): + def test_mfa_authentication_wrong_nin(self) -> None: user = self.app.central_userdb.get_user_by_eppn(self.test_user_eppn) assert user.identities.nin is not None assert user.identities.nin.is_verified is True, "User was expected to have a verified NIN" diff --git a/src/eduid/webapp/common/api/app.py b/src/eduid/webapp/common/api/app.py index 6e361ffa3..55ffa49f7 100644 --- a/src/eduid/webapp/common/api/app.py +++ b/src/eduid/webapp/common/api/app.py @@ -40,6 +40,7 @@ if TYPE_CHECKING: from _typeshed.wsgi import WSGIApplication + from werkzeug.middleware.profiler import ProfilerMiddleware DEBUG = os.environ.get("EDUID_APP_DEBUG", False) if DEBUG: @@ -192,7 +193,7 @@ def init_status_views(app: EduIDBaseApp, config: EduIDBaseAppConfig) -> None: return None -def init_app_profiling(app: WSGIApplication, config: EduIDBaseAppConfig): +def init_app_profiling(app: WSGIApplication, config: EduIDBaseAppConfig) -> ProfilerMiddleware: """ Setup profiling middleware for any app. """ diff --git a/src/eduid/webapp/common/api/debug.py b/src/eduid/webapp/common/api/debug.py index dbf1e7cff..c043c5dcd 100644 --- a/src/eduid/webapp/common/api/debug.py +++ b/src/eduid/webapp/common/api/debug.py @@ -29,7 +29,7 @@ def log_response(status: str, headers: list[tuple[str, str]], *args: Any): return self._app(environ, log_response) -def log_endpoints(app: Flask): +def log_endpoints(app: Flask) -> None: output: list[str] = [] with app.app_context(): for rule in app.url_map.iter_rules(): @@ -46,7 +46,7 @@ def log_endpoints(app: Flask): pprint.pprint(("ENDPOINT", line), stream=sys.stderr) -def dump_config(app: Flask): +def dump_config(app: Flask) -> None: pprint.pprint(("CONFIGURATION", "app.config"), stream=sys.stderr) try: config_items = asdict(app.config).items() # type: ignore[call-overload] @@ -57,7 +57,7 @@ def dump_config(app: Flask): pprint.pprint((key, value), stream=sys.stderr) -def init_app_debug(app: Flask): +def init_app_debug(app: Flask) -> Flask: app.wsgi_app = LoggingMiddleware(app.wsgi_app) # type: ignore[method-assign] dump_config(app) log_endpoints(app) diff --git a/src/eduid/webapp/common/api/exceptions.py b/src/eduid/webapp/common/api/exceptions.py index 60385395f..75faa9b92 100644 --- a/src/eduid/webapp/common/api/exceptions.py +++ b/src/eduid/webapp/common/api/exceptions.py @@ -73,7 +73,7 @@ def __init__(self, state: "ResetPasswordEmailState"): self.state = state -def init_exception_handlers(app: Flask): +def init_exception_handlers(app: Flask) -> Flask: # Init error handler for raised exceptions @app.errorhandler(400) def _handle_flask_http_exception(error: HTTPException): @@ -88,7 +88,7 @@ def _handle_flask_http_exception(error: HTTPException): return app -def init_sentry(app: Flask): +def init_sentry(app: Flask) -> Flask: if app.config.get("SENTRY_DSN"): try: from raven.contrib.flask import Sentry diff --git a/src/eduid/webapp/common/api/helpers.py b/src/eduid/webapp/common/api/helpers.py index ca99e86ae..bc89963a6 100644 --- a/src/eduid/webapp/common/api/helpers.py +++ b/src/eduid/webapp/common/api/helpers.py @@ -291,7 +291,7 @@ def send_mail( app: EduIDBaseApp, context: dict[str, Any] | None = None, reference: str | None = None, -): +) -> None: """ :param subject: subject text :param to_addresses: email addresses for the to field diff --git a/src/eduid/webapp/common/api/messages.py b/src/eduid/webapp/common/api/messages.py index 156792f12..2e3133a7d 100644 --- a/src/eduid/webapp/common/api/messages.py +++ b/src/eduid/webapp/common/api/messages.py @@ -146,7 +146,7 @@ def _make_payload( return res -def make_query_string(msg: TranslatableMsg, error: bool = True): +def make_query_string(msg: TranslatableMsg, error: bool = True) -> str: """ Make a query string to send a translatable message to the front in the URL of a GET request. diff --git a/src/eduid/webapp/common/api/request.py b/src/eduid/webapp/common/api/request.py index 91b8f9e16..40d3ee45c 100644 --- a/src/eduid/webapp/common/api/request.py +++ b/src/eduid/webapp/common/api/request.py @@ -15,7 +15,7 @@ """ import logging -from collections.abc import Callable +from collections.abc import Callable, Iterator from typing import Any, AnyStr from flask import abort @@ -38,7 +38,7 @@ def sanitize_input( untrusted_text: AnyStr, content_type: str | None = None, strip_characters: bool = False, - ): + ) -> str: try: return super().sanitize_input( untrusted_text=untrusted_text, content_type=content_type, strip_characters=strip_characters @@ -65,7 +65,7 @@ def __getitem__(self, key: Any): value = super().__getitem__(key) return self.sanitize_input(value) - def getlist(self, key: Any, type: Callable[[Any], Any] | None = None): + def getlist(self, key: Any, type: Callable[[Any], Any] | None = None) -> list: """ Return the list of items for a given key. If that key is not in the `MultiDict`, the return value will be an empty list. Just as `get` @@ -82,7 +82,7 @@ def getlist(self, key: Any, type: Callable[[Any], Any] | None = None): value_list = super().getlist(key, type=type) return [self.sanitize_input(v) for v in value_list] - def items(self, multi: bool = False): + def items(self, multi: bool = False) -> Iterator[tuple[Any, str]]: # type: ignore[override] """ Return an iterator of ``(key, value)`` pairs. @@ -99,7 +99,7 @@ def items(self, multi: bool = False): else: yield key, values[0] - def lists(self): + def lists(self) -> Iterator[tuple[Any, list[str]]]: """Return a list of ``(key, values)`` pairs, where values is the list of all values associated with the key.""" @@ -107,14 +107,14 @@ def lists(self): values = [self.sanitize_input(v) for v in values] yield key, values - def values(self): + def values(self) -> Iterator[str]: # type: ignore[override] """ Returns an iterator of the first value on every key's value list. """ for values in dict.values(self): yield self.sanitize_input(values[0]) - def listvalues(self): + def listvalues(self) -> Iterator[Iterator[Any]]: # type: ignore[override] """ Return an iterator of all values associated with a key. Zipping :meth:`keys` and this is the same as calling :meth:`lists`: @@ -126,7 +126,7 @@ def listvalues(self): for values in dict.values(self): yield (self.sanitize_input(v) for v in values) - def to_dict(self, flat: bool = True): + def to_dict(self, flat: bool = True) -> dict | dict[Any, list[str]]: """Return the contents as regular dict. If `flat` is `True` the returned dict will only have the first item present, if `flat` is `False` all values will be returned as lists. @@ -175,20 +175,20 @@ def get(self, key: str, default: str | None = None, type: type | None = None) -> :rtype: object """ try: - val = self.sanitize_input(self[key]) + val: Any = self.sanitize_input(self[key]) if type is not None: val = type(val) except (KeyError, ValueError): val = default return val - def values(self): + def values(self) -> list[str]: # type: ignore[override] """ sanitized values """ return [self.sanitize_input(v) for v in super(ImmutableTypeConversionDict, self).values()] - def items(self): + def items(self) -> list[tuple[Any, str]]: # type: ignore[override] """ Sanitized items """ @@ -237,7 +237,7 @@ def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.headers = SanitizedEnvironHeaders(environ=self.environ) - def get_data(self, *args: Any, **kwargs: Any): + def get_data(self, *args: Any, **kwargs: Any) -> str: # type: ignore[override] text = super().get_data(*args, **kwargs) if text: text = self.sanitize_input(untrusted_text=text, content_type=self.mimetype) diff --git a/src/eduid/webapp/common/api/schemas/csrf.py b/src/eduid/webapp/common/api/schemas/csrf.py index c8893dc95..669f36d4e 100644 --- a/src/eduid/webapp/common/api/schemas/csrf.py +++ b/src/eduid/webapp/common/api/schemas/csrf.py @@ -16,7 +16,7 @@ class CSRFRequestMixin(Schema): csrf_token = fields.String(required=True) @validates("csrf_token") - def validate_csrf_token(self, value: str, **kwargs: Any): + def validate_csrf_token(self, value: str, **kwargs: Any) -> None: custom_header = request.headers.get("X-Requested-With") if custom_header != "XMLHttpRequest": # TODO: move value to config current_app.logger.error("CSRF check: missing custom X-Requested-With header") @@ -26,7 +26,7 @@ def validate_csrf_token(self, value: str, **kwargs: Any): logger.debug(f"Validated CSRF token in session: {session.get_csrf_token()}") @post_load - def post_processing(self, in_data: Any, **kwargs: Any): + def post_processing(self, in_data: Any, **kwargs: Any) -> Any: # Remove token from data forwarded to views in_data = self.remove_csrf_token(in_data) return in_data @@ -41,7 +41,7 @@ class CSRFResponseMixin(Schema): csrf_token = fields.String(required=True) @pre_dump - def get_csrf_token(self, out_data: Any, **kwargs: Any): + def get_csrf_token(self, out_data: Any, **kwargs: Any) -> Any: # Generate a new csrf token for every response out_data["csrf_token"] = session.new_csrf_token() logger.debug(f'Generated new CSRF token in CSRFResponseMixin: {out_data["csrf_token"]}') diff --git a/src/eduid/webapp/common/api/schemas/password.py b/src/eduid/webapp/common/api/schemas/password.py index 11b5fa71d..c9daaa4e6 100644 --- a/src/eduid/webapp/common/api/schemas/password.py +++ b/src/eduid/webapp/common/api/schemas/password.py @@ -19,7 +19,7 @@ def __init__(self, *args: Any, **kwargs: Any): self.Meta.min_score = kwargs.pop("min_score") super().__init__(*args, **kwargs) - def validate_password(self, password: str, **kwargs: Any): + def validate_password(self, password: str, **kwargs: Any) -> None: """ :param password: New password diff --git a/src/eduid/webapp/common/api/schemas/validators.py b/src/eduid/webapp/common/api/schemas/validators.py index 700e27108..fb36b6f5a 100644 --- a/src/eduid/webapp/common/api/schemas/validators.py +++ b/src/eduid/webapp/common/api/schemas/validators.py @@ -7,7 +7,7 @@ __author__ = "lundberg" -def validate_nin(nin: str, **kwargs: Any): +def validate_nin(nin: str, **kwargs: Any) -> bool: """ :param nin: National Identity Number :type nin: string_types @@ -20,7 +20,7 @@ def validate_nin(nin: str, **kwargs: Any): raise ValidationError("nin needs to be formatted as 18|19|20yymmddxxxx") -def validate_email(email: str, **kwargs: Any): +def validate_email(email: str, **kwargs: Any) -> bool: """ :param email: E-mail address :type email: string_types diff --git a/src/eduid/webapp/common/api/testing.py b/src/eduid/webapp/common/api/testing.py index b9e484ba3..14281f998 100644 --- a/src/eduid/webapp/common/api/testing.py +++ b/src/eduid/webapp/common/api/testing.py @@ -133,7 +133,7 @@ def setUp( # type: ignore[override] self.redis_instance = RedisTemporaryInstance.get_instance() # settings config = deepcopy(TEST_CONFIG) - self.settings = self.update_config(config) + self.settings: dict[str, Any] = self.update_config(config) self.settings["redis_config"] = RedisConfig(host="localhost", port=self.redis_instance.port) assert isinstance(self.tmp_db, MongoTemporaryInstance) # please mypy self.settings["mongo_uri"] = self.tmp_db.uri @@ -154,7 +154,7 @@ def setUp( # type: ignore[override] logging.info(f"Copying test-user {self.test_user} to private_userdb {_private_userdb}") _private_userdb.save(_private_userdb.user_from_dict(data=data)) - def tearDown(self): + def tearDown(self) -> None: try: # Reset anything that looks like a BaseDB, for the next test class. for this in vars(self.app).values(): @@ -323,7 +323,7 @@ def set_authn_action( finish_url: str | None = None, force_mfa: bool = False, credentials_used: list[ElementKey] | None = None, - ): + ) -> None: if not finish_url: finish_url = "https://example.com/ext-return/{app_name}/{authn_id}" diff --git a/src/eduid/webapp/common/api/tests/test_backdoor.py b/src/eduid/webapp/common/api/tests/test_backdoor.py index 45e6b20cf..60d7eed54 100644 --- a/src/eduid/webapp/common/api/tests/test_backdoor.py +++ b/src/eduid/webapp/common/api/tests/test_backdoor.py @@ -14,8 +14,9 @@ @test_views.route("/get-code", methods=["GET"]) -def get_code(): +def get_code() -> str: current_app.logger.info("Endpoint get_code called") + assert isinstance(current_app, BackdoorTestApp) try: if check_magic_cookie(current_app.conf): eppn = request.args.get("eppn") @@ -79,13 +80,13 @@ def load_app(self, config: Mapping[str, Any]) -> BackdoorTestApp: app.session_interface = SessionFactory(app.conf) return app - def test_backdoor_get_code(self): + def test_backdoor_get_code(self) -> None: """""" with self.session_cookie_and_magic_cookie_anon(self.browser) as client: response = client.get(self.test_get_url) assert response.data == b"dummy-code-for-pepin-pepon" - def test_no_backdoor_in_pro(self): + def test_no_backdoor_in_pro(self) -> None: """""" self.app.conf.environment = EduidEnvironment("production") @@ -93,19 +94,19 @@ def test_no_backdoor_in_pro(self): response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_no_backdoor_without_cookie(self): + def test_no_backdoor_without_cookie(self) -> None: """""" with self.session_cookie_anon(self.browser) as client: response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_wrong_cookie_no_backdoor(self): + def test_wrong_cookie_no_backdoor(self) -> None: """""" with self.session_cookie_and_magic_cookie_anon(self.browser, magic_cookie_value="no-magic") as client: response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_no_magic_cookie_no_backdoor(self): + def test_no_magic_cookie_no_backdoor(self) -> None: """""" self.app.conf.magic_cookie = "" @@ -113,7 +114,7 @@ def test_no_magic_cookie_no_backdoor(self): response = client.get(self.test_get_url) self.assertEqual(response.status_code, 400) - def test_no_magic_cookie_name_no_backdoor(self): + def test_no_magic_cookie_name_no_backdoor(self) -> None: """""" self.app.conf.magic_cookie_name = "" diff --git a/src/eduid/webapp/common/api/tests/test_decorators.py b/src/eduid/webapp/common/api/tests/test_decorators.py index cc9218896..49097368b 100644 --- a/src/eduid/webapp/common/api/tests/test_decorators.py +++ b/src/eduid/webapp/common/api/tests/test_decorators.py @@ -49,7 +49,7 @@ def load_app(self, config: Mapping[str, Any]) -> DecoratorTestApp: app.session_interface = SessionFactory(app.conf) return app - def test_success_message(self): + def test_success_message(self) -> None: """Test that a simple success_message is turned into a well-formed Flux Standard Action response""" msg = success_response(message=TestsMsg.fst_test_msg) with self.app.test_request_context("/test/foo"): @@ -59,7 +59,7 @@ def test_success_message(self): "payload": {"message": "test.first_msg", "success": True}, } - def test_success_message_with_data(self): + def test_success_message_with_data(self) -> None: """Test that a success_message with data is turned into a well-formed Flux Standard Action response""" msg = success_response(payload={"working": True}, message=TestsMsg.fst_test_msg) with self.app.test_request_context("/test/foo"): @@ -69,7 +69,7 @@ def test_success_message_with_data(self): "payload": {"message": "test.first_msg", "success": True, "working": True}, } - def test_error_message(self): + def test_error_message(self) -> None: """Test that a simple success_message is turned into a well-formed Flux Standard Action response""" msg = error_response(message=TestsMsg.fst_test_msg) with self.app.test_request_context("/test/foo"): diff --git a/src/eduid/webapp/common/api/tests/test_inputs.py b/src/eduid/webapp/common/api/tests/test_inputs.py index f86feb046..f4dcbe913 100644 --- a/src/eduid/webapp/common/api/tests/test_inputs.py +++ b/src/eduid/webapp/common/api/tests/test_inputs.py @@ -1,9 +1,9 @@ import logging from collections.abc import Mapping -from typing import Any +from typing import Any, NoReturn from urllib.parse import unquote -from flask import Blueprint, make_response, request +from flask import Blueprint, Response, make_response, request from marshmallow import ValidationError, fields from werkzeug.http import dump_cookie @@ -21,7 +21,7 @@ __author__ = "lundberg" -def dont_validate(value: Any): +def dont_validate(value: Any) -> NoReturn: raise ValidationError(f"Problem with {value!r}") @@ -43,45 +43,51 @@ def _make_response(data: str): @test_views.route("/test-get-param", methods=["GET"]) -def get_param_view(): +def get_param_view() -> Response: param = request.args.get("test-param") + assert param return _make_response(param) @test_views.route("/test-post-param", methods=["POST"]) -def post_param_view(): +def post_param_view() -> Response: param = request.form.get("test-param") + assert param return _make_response(param) -@test_views.route("/test-post-json", methods=["POST"]) +@test_views.route("/test-post-json", methods=["POST"]) # type: ignore[arg-type] @UnmarshalWith(NonValidatingSchema) -def post_json_view(test_data: str): +def post_json_view(test_data: str) -> None: """never validates""" pass @test_views.route("/test-cookie") -def cookie_view(): +def cookie_view() -> Response: cookie = request.cookies.get("test-cookie") + assert cookie return _make_response(cookie) @test_views.route("/test-empty-session") -def empty_session_view(): +def empty_session_view() -> Response: cookie = request.cookies.get("sessid") + assert cookie is not None return _make_response(cookie) @test_views.route("/test-header") -def header_view(): +def header_view() -> Response: header = request.headers.get("X-TEST") + assert header return _make_response(header) @test_views.route("/test-values", methods=["GET", "POST"]) -def values_view(): +def values_view() -> Response: param = request.values.get("test-param") + assert param return _make_response(param) @@ -112,27 +118,27 @@ def load_app(self, test_config: Mapping[str, Any]) -> InputsTestApp: app.register_blueprint(test_views) return app - def test_get_param(self): + def test_get_param(self) -> None: """""" url = "/test-get-param?test-param=test-param" with self.app.test_request_context(url, method="GET"): response = self.app.dispatch_request() self.assertIn(b"test-param", response.data) - def test_get_param_script(self): + def test_get_param_script(self) -> None: """""" url = '/test-get-param?test-param=' with self.app.test_request_context(url, method="GET"): response = self.app.dispatch_request() self.assertNotIn(b"'}): response = self.app.dispatch_request() self.assertNotIn(b"') @@ -205,7 +211,7 @@ def test_cookie_script(self): response = self.app.dispatch_request() self.assertNotIn(b"' @@ -213,21 +219,21 @@ def test_header_script(self): response = self.app.dispatch_request() self.assertNotIn(b"'}): response = self.app.dispatch_request() self.assertNotIn(b"