Skip to content

Commit

Permalink
refactor: update typing to python 3.10+ syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgenengelsen committed Mar 15, 2023
1 parent 7ef87c0 commit 245c699
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 48 deletions.
16 changes: 8 additions & 8 deletions api/src/authentication/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional
from typing import Any

from pydantic import BaseModel

Expand Down Expand Up @@ -28,7 +28,7 @@ def validate(cls, v):
raise ValueError("invalid AccessLevel enum value ")

@classmethod
def __modify_schema__(cls, schema: Dict[str, Any]):
def __modify_schema__(cls, schema: dict[str, Any]):
"""
Add a custom field type to the class representing the Enum's field names
Ref: https://pydantic-docs.helpmanual.io/usage/schema/#modifying-schema-in-custom-fields
Expand All @@ -43,13 +43,13 @@ def __modify_schema__(cls, schema: Dict[str, Any]):
class User(BaseModel):
user_id: str # If using azure AD authentication, user_id is the oid field from the access token.
# If using another oauth provider, user_id will be from the "sub" attribute in the access token.
email: Optional[str] = None
full_name: Optional[str] = None
roles: List[str] = []
email: str | None = None
full_name: str | None = None
roles: list[str] = []
scope: AccessLevel = AccessLevel.WRITE

def __hash__(self):
return hash((type(self.user_id)))
return hash(type(self.user_id))


class ACL(BaseModel):
Expand All @@ -64,8 +64,8 @@ class ACL(BaseModel):
"""

owner: str
roles: Dict[str, AccessLevel] = {}
users: Dict[str, AccessLevel] = {}
roles: dict[str, AccessLevel] = {}
users: dict[str, AccessLevel] = {}
others: AccessLevel = AccessLevel.READ

def dict(self, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion api/src/common/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from typing import Any, Callable
from collections.abc import Callable
from typing import Any

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
Expand Down
5 changes: 3 additions & 2 deletions api/src/common/responses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
import traceback
from collections.abc import Callable
from inspect import iscoroutinefunction
from typing import Any, Callable, Type, TypeVar
from typing import Any, TypeVar

from httpx import HTTPStatusError
from starlette import status
Expand Down Expand Up @@ -44,7 +45,7 @@
"""


def create_response(response_class: Type[TResponse]) -> Callable:
def create_response(response_class: type[TResponse]) -> Callable:
def func_wrapper(func) -> Callable:
@functools.wraps(func)
async def wrapper_decorator(*args, **kwargs) -> TResponse | Response | JSONResponse:
Expand Down
10 changes: 5 additions & 5 deletions api/src/data_providers/clients/ClientInterface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Dict, Generic, List, Optional, TypeVar
from typing import Generic, TypeVar

# Type definition for Model
M = TypeVar("M")
Expand Down Expand Up @@ -30,19 +30,19 @@ def update(self, id: K, instance: M) -> M:
pass

@abstractmethod
def insert_many(self, instances: List[M]):
def insert_many(self, instances: list[M]):
pass

@abstractmethod
def delete_many(self, filter: Dict):
def delete_many(self, filter: dict):
pass

@abstractmethod
def find(self, filter: Dict) -> M:
def find(self, filter: dict) -> M:
pass

@abstractmethod
def find_one(self, filter: Dict) -> Optional[M]:
def find_one(self, filter: dict) -> M | None:
pass

@abstractmethod
Expand Down
18 changes: 8 additions & 10 deletions api/src/data_providers/clients/mongodb/MongoDatabaseClient.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict, List, Optional

from pymongo.cursor import Cursor
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError
Expand All @@ -9,7 +7,7 @@
from data_providers.clients.ClientInterface import ClientInterface


class MongoDatabaseClient(ClientInterface[Dict, str]):
class MongoDatabaseClient(ClientInterface[dict, str]):
def __init__(self, collection_name: str, database_name: str, client: MongoClient):
database: Database = client[database_name]
self.database = database
Expand All @@ -27,7 +25,7 @@ def wipe_db(self):
def delete_collection(self):
self.collection.drop()

def create(self, document: Dict) -> Dict:
def create(self, document: dict) -> dict:
try:
result = self.collection.insert_one(document)
return self.get(str(result.inserted_id))
Expand All @@ -37,14 +35,14 @@ def create(self, document: Dict) -> Dict:
def list_collection(self) -> list[dict]:
return list(self.collection.find())

def get(self, uid: str) -> Dict:
def get(self, uid: str) -> dict:
document = self.collection.find_one(filter={"_id": uid})
if document is None:
raise NotFoundException
else:
return dict(document)

def update(self, uid: str, document: Dict) -> Dict:
def update(self, uid: str, document: dict) -> dict:
if self.collection.find_one(filter={"_id": uid}) is None:
raise NotFoundException(extra={"uid": uid})
self.collection.replace_one({"_id": uid}, document)
Expand All @@ -54,14 +52,14 @@ def delete(self, uid: str) -> bool:
result = self.collection.delete_one(filter={"_id": uid})
return result.deleted_count > 0

def find(self, filter: Dict) -> Cursor:
def find(self, filter: dict) -> Cursor:
return self.collection.find(filter=filter)

def find_one(self, filter: Dict) -> Optional[Dict]:
def find_one(self, filter: dict) -> dict | None:
return self.collection.find_one(filter=filter)

def insert_many(self, items: List[Dict]):
def insert_many(self, items: list[dict]):
return self.collection.insert_many(items)

def delete_many(self, filter: Dict):
def delete_many(self, filter: dict):
return self.collection.delete_many(filter)
8 changes: 3 additions & 5 deletions api/src/data_providers/repositories/TodoRepository.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from common.exceptions import NotFoundException
from data_providers.clients.ClientInterface import ClientInterface
from data_providers.repository_interfaces.TodoRepositoryInterface import (
Expand Down Expand Up @@ -36,17 +34,17 @@ def get(self, todo_item_id: str) -> TodoItem:
todo_item = self.client.get(todo_item_id)
return TodoItem.from_dict(todo_item)

def create(self, todo_item: TodoItem) -> Optional[TodoItem]:
def create(self, todo_item: TodoItem) -> TodoItem | None:
inserted_todo_item = self.client.create(to_dict(todo_item))
return TodoItem.from_dict(inserted_todo_item)

def get_all(self) -> list[TodoItem]:
todo_items = []
for item in self.client.list():
for item in self.client.list_collection():
todo_items.append(TodoItem.from_dict(item))
return todo_items

def find_one(self, filter: dict) -> Optional[TodoItem]:
def find_one(self, filter: dict) -> TodoItem | None:
todo_item = self.client.find_one(filter)
if todo_item:
return TodoItem.from_dict(todo_item)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from abc import ABCMeta, abstractmethod
from typing import Optional

from entities.TodoItem import TodoItem


class TodoRepositoryInterface(metaclass=ABCMeta):
@abstractmethod
def create(self, todo_item: TodoItem) -> Optional[TodoItem]:
def create(self, todo_item: TodoItem) -> TodoItem | None:
raise NotImplementedError

@abstractmethod
Expand Down
4 changes: 1 addition & 3 deletions api/src/features/todo/todo_feature.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from fastapi import APIRouter, Depends
from starlette.responses import JSONResponse

Expand Down Expand Up @@ -58,7 +56,7 @@ def delete_todo_by_id(
@create_response(JSONResponse)
def get_todo_all(
user: User = Depends(auth_with_jwt), todo_repository: TodoRepositoryInterface = Depends(get_todo_repository)
) -> List[GetTodoAllResponse]:
) -> list[GetTodoAllResponse]:
return get_todo_all_use_case(user_id=user.user_id, todo_repository=todo_repository)


Expand Down
4 changes: 1 addition & 3 deletions api/src/features/todo/use_cases/get_todo_all.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from pydantic import BaseModel

from data_providers.repository_interfaces.TodoRepositoryInterface import (
Expand All @@ -21,7 +19,7 @@ def from_entity(todo_item: TodoItem):
def get_todo_all_use_case(
user_id: str,
todo_repository: TodoRepositoryInterface,
) -> List[GetTodoAllResponse]:
) -> list[GetTodoAllResponse]:
return [
GetTodoAllResponse.from_entity(todo_item)
for todo_item in todo_repository.get_all()
Expand Down
4 changes: 1 addition & 3 deletions api/src/tests/unit/features/todo/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Dict

import pytest

from data_providers.repositories.TodoRepository import TodoRepository


@pytest.fixture(scope="function")
def todo_test_data() -> Dict[str, dict]:
def todo_test_data() -> dict[str, dict]:
return {
"dh2109": {"_id": "dh2109", "title": "item 1", "is_completed": False, "user_id": "xyz"},
"1417b8": {"_id": "1417b8", "title": "item 2", "is_completed": True, "user_id": "xyz"},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Dict

from data_providers.repository_interfaces.TodoRepositoryInterface import (
TodoRepositoryInterface,
)
from features.todo.use_cases.get_todo_all import get_todo_all_use_case


def test_get_todos_should_return_todos(todo_repository: TodoRepositoryInterface, todo_test_data: Dict[str, dict]):
def test_get_todos_should_return_todos(todo_repository: TodoRepositoryInterface, todo_test_data: dict[str, dict]):
todos = get_todo_all_use_case(user_id="xyz", todo_repository=todo_repository)
assert len(todos) == len(todo_test_data.keys())
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict

import pytest as pytest

from common.exceptions import NotFoundException
Expand All @@ -12,7 +10,7 @@
)


def test_get_todo_by_id_should_return_todo(todo_repository: TodoRepositoryInterface, todo_test_data: Dict[str, dict]):
def test_get_todo_by_id_should_return_todo(todo_repository: TodoRepositoryInterface, todo_test_data: dict[str, dict]):
id = "dh2109"
todo: GetTodoByIdResponse = get_todo_by_id_use_case(id, user_id="xyz", todo_repository=todo_repository)
assert todo.title == todo_test_data[id]["title"]
Expand Down

0 comments on commit 245c699

Please sign in to comment.