Skip to content

Commit

Permalink
Merge pull request #705 from SUNET/ylle_annotations
Browse files Browse the repository at this point in the history
Adding annotations for type safety
  • Loading branch information
johanlundberg authored Oct 4, 2024
2 parents 2c5cf8d + 80306de commit 149a1d0
Show file tree
Hide file tree
Showing 377 changed files with 4,308 additions and 3,332 deletions.
8 changes: 6 additions & 2 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ line-length = 120
target-version = "py310"

[lint]
select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA"]
select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN"]

ignore = ["E501"]
# ANN101 and ANN102 are depracated
ignore = ["E501", "ANN1"]

[lint.flake8-annotations]
allow-star-arg-any = true
5 changes: 3 additions & 2 deletions src/eduid/common/clients/amapi_client/amapi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
__author__ = "masv"

from eduid.common.models.amapi_user import (
UserBaseRequest,
UserUpdateEmailRequest,
UserUpdateLanguageRequest,
UserUpdateMetaCleanedRequest,
Expand All @@ -20,14 +21,14 @@


class AMAPIClient(GNAPClient):
def __init__(self, amapi_url: str, auth_data=GNAPClientAuthData, verify_tls: bool = True, **kwargs):
def __init__(self, amapi_url: str, auth_data: GNAPClientAuthData, verify_tls: bool = True, **kwargs: Any) -> None:
super().__init__(auth_data=auth_data, verify=verify_tls, **kwargs)
self.amapi_url = amapi_url

def _users_base_url(self) -> str:
return urlappend(self.amapi_url, "users")

def _put(self, base_path: str, user: str, endpoint: str, body: Any) -> httpx.Response:
def _put(self, base_path: str, user: str, endpoint: str, body: UserBaseRequest) -> httpx.Response:
return self.put(url=urlappend(base_path, f"{user}/{endpoint}"), content=body.json())

def update_user_email(self, user: str, body: UserUpdateEmailRequest) -> UserUpdateResponse:
Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/clients/amapi_client/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class MockedAMAPIMixin(MockedSyncAuthAPIMixin):
def start_mock_amapi(self, access_token_value: str | None = None):
def start_mock_amapi(self, access_token_value: str | None = None) -> None:
self.start_mock_auth_api(access_token_value=access_token_value)

self.mocked_users = respx.mock(base_url="http://localhost", assert_all_called=False)
Expand Down
3 changes: 2 additions & 1 deletion src/eduid/common/clients/gnap_client/async_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Any

import httpx

Expand All @@ -11,7 +12,7 @@


class AsyncGNAPClient(httpx.AsyncClient, GNAPBearerTokenMixin):
def __init__(self, auth_data: GNAPClientAuthData, **kwargs):
def __init__(self, auth_data: GNAPClientAuthData, **kwargs: Any) -> None:
if "event_hooks" not in kwargs:
kwargs["event_hooks"] = {"response": [self.raise_on_4xx_5xx], "request": [self._add_authz_header]}

Expand Down
3 changes: 2 additions & 1 deletion src/eduid/common/clients/gnap_client/sync_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Any

import httpx

Expand All @@ -11,7 +12,7 @@


class GNAPClient(httpx.Client, GNAPBearerTokenMixin):
def __init__(self, auth_data: GNAPClientAuthData, **kwargs):
def __init__(self, auth_data: GNAPClientAuthData, **kwargs: Any) -> None:
if "event_hooks" not in kwargs:
kwargs["event_hooks"] = {"response": [self.raise_on_4xx_5xx], "request": [self._add_authz_header]}

Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/clients/gnap_client/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class MockedSyncAuthAPIMixin:
def start_mock_auth_api(self, access_token_value: str | None = None):
def start_mock_auth_api(self, access_token_value: str | None = None) -> None:
if access_token_value is None:
access_token_value = "mock_jwt"
self.mocked_auth_api = respx.mock(base_url="http://localhost/auth", assert_all_called=False)
Expand Down
3 changes: 2 additions & 1 deletion src/eduid/common/clients/scim_client/scim_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Any
from uuid import UUID

import httpx
Expand All @@ -20,7 +21,7 @@ class SCIMError(Exception):


class SCIMClient(GNAPClient):
def __init__(self, scim_api_url: str, auth_data: GNAPClientAuthData, **kwargs):
def __init__(self, scim_api_url: str, auth_data: GNAPClientAuthData, **kwargs: Any) -> None:
super().__init__(auth_data=auth_data, **kwargs)
self.event_hooks["request"].append(self._add_accept_header)
self.scim_api_url = scim_api_url
Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/clients/scim_client/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class MockedScimAPIMixin(MockedSyncAuthAPIMixin):

put_user_response = post_user_response

def start_mocked_scim_api(self):
def start_mocked_scim_api(self) -> None:
self.start_mock_auth_api()

self.mocked_scim_api = respx.mock(base_url="http://localhost/scim", assert_all_called=False)
Expand Down
4 changes: 2 additions & 2 deletions src/eduid/common/config/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
class BadConfiguration(Exception):
def __init__(self, message):
def __init__(self, message: str) -> None:
Exception.__init__(self)
self.value = message

def __str__(self):
def __str__(self) -> str:
return self.value
10 changes: 5 additions & 5 deletions src/eduid/common/config/parsers/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from functools import wraps
from string import Template
from typing import Any
Expand All @@ -11,9 +11,9 @@
from eduid.common.config.parsers.exceptions import SecretKeyException


def decrypt(f):
def decrypt(f: Callable) -> Callable:
@wraps(f)
def decrypt_decorator(*args, **kwargs):
def decrypt_decorator(*args: Any, **kwargs: Any) -> Mapping[str, Any]:
config_dict = f(*args, **kwargs)
decrypted_config_dict = decrypt_config(config_dict)
return decrypted_config_dict
Expand Down Expand Up @@ -83,9 +83,9 @@ def decrypt_config(config_dict: Mapping[str, Any]) -> Mapping[str, Any]:
return new_config_dict


def interpolate(f):
def interpolate(f: Callable) -> Callable:
@wraps(f)
def interpolation_decorator(*args, **kwargs):
def interpolation_decorator(*args: Any, **kwargs: Any) -> dict[str, Any]:
config_dict = f(*args, **kwargs)
interpolated_config_dict = interpolate_config(config_dict)
for key in list(interpolated_config_dict.keys()):
Expand Down
4 changes: 2 additions & 2 deletions src/eduid/common/config/parsers/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@


class ParserException(Exception):
def __init__(self, message):
def __init__(self, message: str) -> None:
Exception.__init__(self)
self.value = message

def __str__(self):
def __str__(self) -> str:
return self.value


Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/config/parsers/yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class YamlConfigParser(BaseConfigParser):
def __init__(self, path: Path):
def __init__(self, path: Path) -> None:
self.path = path

@interpolate
Expand Down
4 changes: 2 additions & 2 deletions src/eduid/common/config/tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@


class TestInitConfig(unittest.TestCase):
def tearDown(self):
def tearDown(self) -> None:
os.environ.clear()

def test_YamlConfigParser(self):
def test_YamlConfigParser(self) -> None:
os.environ["EDUID_CONFIG_NS"] = "/test/ns/"
os.environ["EDUID_CONFIG_YAML"] = "/config.yaml"
parser = _choose_parser()
Expand Down
14 changes: 7 additions & 7 deletions src/eduid/common/config/tests/test_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ class TestInitConfig(unittest.TestCase):
def setUp(self) -> None:
self.data_dir = PurePath(__file__).with_name("data")

def tearDown(self):
def tearDown(self) -> None:
os.environ.clear()

def test_YamlConfig(self):
def test_YamlConfig(self) -> None:
os.environ["EDUID_CONFIG_NS"] = "/eduid/test/app_one"
os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common"
os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml")
Expand All @@ -49,7 +49,7 @@ def test_YamlConfig(self):
assert config_two.number == 10
assert config_two.only_default == 19

def test_YamlConfig_interpolation(self):
def test_YamlConfig_interpolation(self) -> None:
os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_interpolation"
os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common"
os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml")
Expand All @@ -58,7 +58,7 @@ def test_YamlConfig_interpolation(self):
assert config.number == 3
assert config.foo == "hi world"

def test_YamlConfig_missing_value(self):
def test_YamlConfig_missing_value(self) -> None:
os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_missing_value"
os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common"
os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml")
Expand All @@ -75,7 +75,7 @@ def test_YamlConfig_missing_value(self):
}
], f"Wrong error message: {exc_info.value.errors()}"

def test_YamlConfig_wrong_type(self):
def test_YamlConfig_wrong_type(self) -> None:
os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_wrong_type"
os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common"
os.environ["EDUID_CONFIG_YAML"] = str(self.data_dir / "test.yaml")
Expand All @@ -92,7 +92,7 @@ def test_YamlConfig_wrong_type(self):
}
], f"Wrong error message: {exc_info.value.errors()}"

def test_YamlConfig_unknown_data(self):
def test_YamlConfig_unknown_data(self) -> None:
"""Unknown data should not be rejected because it is an operational nightmare"""
os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_unknown_data"
os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common"
Expand All @@ -102,7 +102,7 @@ def test_YamlConfig_unknown_data(self):
assert config.number == 0xFF
assert config.foo == "bar"

def test_YamlConfig_mixed_case_keys(self):
def test_YamlConfig_mixed_case_keys(self) -> None:
"""For legacy reasons, all keys should be lowercased"""
os.environ["EDUID_CONFIG_NS"] = "/eduid/test/test_mixed_case_keys"
os.environ["EDUID_CONFIG_COMMON_NS"] = "/eduid/test/common"
Expand Down
10 changes: 6 additions & 4 deletions src/eduid/common/decorators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import inspect
import warnings
from collections.abc import Callable
from functools import wraps
from typing import Any


# https://stackoverflow.com/questions/2536307/how-do-i-deprecate-python-functions/40301488#40301488
def deprecated(reason):
def deprecated(reason: str | type | Callable) -> Callable:
"""
This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
Expand All @@ -20,14 +22,14 @@ def deprecated(reason):
# def old_function(x, y):
# pass

def decorator(func1):
def decorator(func1: Callable) -> Callable:
if inspect.isclass(func1):
fmt1 = "Call to deprecated class {name} ({reason})."
else:
fmt1 = "Call to deprecated function {name} ({reason})."

@wraps(func1)
def new_func1(*args, **kwargs):
def new_func1(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
fmt1.format(name=func1.__name__, reason=reason), category=DeprecationWarning, stacklevel=2
Expand Down Expand Up @@ -56,7 +58,7 @@ def new_func1(*args, **kwargs):
fmt2 = "Call to deprecated function {name}."

@wraps(func2)
def new_func2(*args, **kwargs):
def new_func2(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(fmt2.format(name=func2.__name__), category=DeprecationWarning, stacklevel=2)
warnings.simplefilter("default", DeprecationWarning)
Expand Down
9 changes: 5 additions & 4 deletions src/eduid/common/fastapi/context_request.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
from collections.abc import Callable
from dataclasses import asdict, dataclass
from typing import Any

from fastapi import Request, Response
from fastapi.routing import APIRoute


@dataclass
class Context:
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return asdict(self)


class ContextRequest(Request):
def __init__(self, context_class: type[Context], *args, **kwargs):
def __init__(self, context_class: type[Context], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.contextClass = context_class

@property
def context(self):
def context(self) -> Context:
try:
return self.state.context
except AttributeError:
Expand All @@ -26,7 +27,7 @@ def context(self):
return self.context

@context.setter
def context(self, context: Context):
def context(self, context: Context) -> None:
self.state.context = context


Expand Down
15 changes: 8 additions & 7 deletions src/eduid/common/fastapi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import uuid
from typing import Any

from fastapi import Request, status
from fastapi.exception_handlers import http_exception_handler
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
self,
status_code: int,
detail: str | None = None,
):
) -> None:
self._error_detail = ErrorDetail(detail=detail, status=status_code)
self._extra_headers: dict | None = None

Expand All @@ -69,41 +70,41 @@ def extra_headers(self) -> dict | None:
return self._extra_headers

@extra_headers.setter
def extra_headers(self, headers: dict):
def extra_headers(self, headers: dict) -> None:
self._extra_headers = headers


class BadRequest(HTTPErrorDetail):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(status_code=status.HTTP_400_BAD_REQUEST, **kwargs)
if not self.error_detail.detail:
self.error_detail.detail = "Bad Request"


class Unauthorized(HTTPErrorDetail):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, **kwargs)
if not self.error_detail.detail:
self.error_detail.detail = "Unauthorized request"


class NotFound(HTTPErrorDetail):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(status_code=status.HTTP_404_NOT_FOUND, **kwargs)
if not self.error_detail.detail:
self.error_detail.detail = "Resource not found"


class MethodNotAllowedMalformed(HTTPErrorDetail):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, **kwargs)
if not self.error_detail.detail:
allowed_methods = kwargs.get("allowed_methods")
self.error_detail.detail = f"The used HTTP method is not allowed. Allowed methods: {allowed_methods}"


class UnsupportedMediaTypeMalformed(HTTPErrorDetail):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs)
if not self.error_detail.detail:
self.error_detail.detail = "Request was made with an unsupported media type"
Loading

0 comments on commit 149a1d0

Please sign in to comment.