From 9397108b2724b6cbcf7905224abf80e030ad689e Mon Sep 17 00:00:00 2001 From: Pim Witlox Date: Mon, 4 Nov 2024 13:31:55 +0100 Subject: [PATCH] fixing serialization, not done yet --- horao/__init__.py | 15 +++--- horao/api/synchronization.py | 23 +++++--- horao/auth/__init__.py | 2 +- horao/auth/multi.py | 96 ++++++++++++++++++++++++++++++++++ horao/auth/peer.py | 85 ------------------------------ horao/conceptual/crdt.py | 11 ++++ horao/persistance/serialize.py | 84 +++++++++++++++-------------- tests/test_api.py | 15 +++--- tests/test_persistance.py | 33 ++++++++++-- 9 files changed, 209 insertions(+), 155 deletions(-) create mode 100644 horao/auth/multi.py delete mode 100644 horao/auth/peer.py diff --git a/horao/__init__.py b/horao/__init__.py index 926a70f..71383c3 100644 --- a/horao/__init__.py +++ b/horao/__init__.py @@ -19,6 +19,7 @@ from opentelemetry.sdk.metrics.export import ( PeriodicExportingMetricReader, # type: ignore ) +from starlette.authentication import AuthenticationBackend if os.getenv("OLTP_HTTP", "False") == "False": from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( @@ -45,8 +46,7 @@ import horao.api import horao.api.synchronization import horao.auth -from horao.auth.basic import BasicAuthBackend -from horao.auth.peer import PeerAuthBackend +from horao.auth.multi import MultiAuthBackend LoggingInstrumentor().instrument(set_logging_format=True) @@ -121,9 +121,10 @@ async def docs(request): return HTMLResponse(html) -def init() -> Starlette: +def init(authorization: AuthenticationBackend = None) -> Starlette: """ Initialize the API + authorization: optional authorization backend to overwrite default behavior (useful for testing) :return: app instance """ if os.getenv("DEBUG", "False") == "True": @@ -148,11 +149,13 @@ def init() -> Starlette: routes.append(Route("/docs", endpoint=docs, methods=["GET"])) middleware = [ Middleware(CORSMiddleware, allow_origins=[cors]), - Middleware(AuthenticationMiddleware, backend=PeerAuthBackend()), ] - if os.getenv("AUTH", "basic") == "basic": + if authorization: + logger.warning(f"Using custom authorization backend: {type(authorization)}") + middleware.append(Middleware(AuthenticationMiddleware, backend=authorization)) + else: middleware.append( - Middleware(AuthenticationMiddleware, backend=BasicAuthBackend()) + Middleware(AuthenticationMiddleware, backend=MultiAuthBackend()) ) app = Starlette( routes=routes, diff --git a/horao/api/synchronization.py b/horao/api/synchronization.py index 87cc799..586cea8 100644 --- a/horao/api/synchronization.py +++ b/horao/api/synchronization.py @@ -2,6 +2,7 @@ """All calls needed for synchronizing HORAO instances.""" import json import logging +import os from starlette.authentication import requires from starlette.requests import Request @@ -32,25 +33,33 @@ async def synchronize(request: Request) -> JSONResponse: """ logging.debug(f"Calling Synchronize ({request})") try: - data = await request.json() + data = await request.body() + logical_infrastructure = json.loads(data, cls=HoraoDecoder) except Exception as e: logging.error(f"Error parsing request: {e}") + if os.getenv("DEBUG", "False") == "True": + return JSONResponse( + status_code=400, content={"error": f"Error parsing request {str(e)}"} + ) return JSONResponse(status_code=400, content={"error": "Error parsing request"}) try: - logical_infrastructure = json.loads(data, cls=HoraoDecoder) session = init_session() for k, v in logical_infrastructure.infrastructure.items(): - local_dc = session.load(k.id) + local_dc = session.load(k.identity) if not local_dc: - session.save(k.id, k) + session.save(k.identity, k) else: local_dc.merge(k) - local_dc_content = session.load(f"{k.id}.content") + local_dc_content = session.load(f"{k.identity}.content") if not local_dc_content: - session.save(f"{k.id}.content", v) + session.save(f"{k.identity}.content", v) else: local_dc_content.merge(v) except Exception as e: logging.error(f"Error synchronizing: {e}") - return JSONResponse(status_code=500, content={"error": "Error synchronizing"}) + if os.getenv("DEBUG", "False") == "True": + return JSONResponse( + status_code=500, content={"error": f"Error synchronizing {str(e)}"} + ) + return JSONResponse(status_code=500, content={"error": f"Error synchronizing"}) return JSONResponse(status_code=200, content={"status": "is alive"}) diff --git a/horao/auth/__init__.py b/horao/auth/__init__.py index deefd3a..8837012 100644 --- a/horao/auth/__init__.py +++ b/horao/auth/__init__.py @@ -6,4 +6,4 @@ various implementations that can be used, but some are only meant for development purpose. """ from horao.auth.basic import BasicAuthBackend -from horao.auth.peer import PeerAuthBackend, Peer +from horao.auth.multi import MultiAuthBackend, Peer diff --git a/horao/auth/multi.py b/horao/auth/multi.py new file mode 100644 index 0000000..253ef55 --- /dev/null +++ b/horao/auth/multi.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*-# +"""Authorization for peers. + +Digest authentication using pre-shared key. +""" +import binascii +import logging +import os +from typing import Tuple, Union + +import jwt +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + AuthenticationError, + BaseUser, +) +from starlette.requests import HTTPConnection + + +class Peer(BaseUser): + def __init__(self, identity: str, token: str, payload, origin: str) -> None: + self.id = identity + self.token = token + self.payload = payload + self.origin = origin + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.origin + + @property + def identity(self) -> str: + return self.id + + def is_true(self) -> bool: + """ + Check if the identity matches the origin. + :return: bool + """ + return self.identity == self.origin + + def __str__(self) -> str: + return f"{self.origin} -> {self.identity}" + + +class MultiAuthBackend(AuthenticationBackend): + logger = logging.getLogger(__name__) + + def digest_authentication( + self, conn: HTTPConnection, token: str + ) -> Union[None, Tuple[AuthCredentials, BaseUser]]: + peer_match_source = False + for peer in os.getenv("PEERS").split(","): # type: ignore + if peer in conn.client.host: + self.logger.debug(f"Peer {peer} is trying to authenticate") + peer_match_source = True + if not peer_match_source and os.getenv("PEER_STRICT", "True") == "True": + raise AuthenticationError(f"access not allowed for {conn.client.host}") + payload = jwt.decode(token, os.getenv("PEER_SECRET"), algorithms=["HS256"]) # type: ignore + self.logger.debug(f"valid token for {payload['peer']}") + return AuthCredentials(["authenticated_peer"]), Peer( + identity=payload["peer"], + token=token, + payload=payload, + origin=conn.client.host, + ) + + async def authenticate( + self, conn: HTTPConnection + ) -> Union[None, Tuple[AuthCredentials, BaseUser]]: + if "Authorization" not in conn.headers: + return None + if "PEERS" not in os.environ: + return None + if "PEER_SECRET" not in os.environ: + return None + + auth = conn.headers["Authorization"] + try: + scheme, token = auth.split() + if scheme.lower() != "bearer": + return None + return self.digest_authentication(conn, token) + except ( + ValueError, + UnicodeDecodeError, + jwt.InvalidTokenError, + binascii.Error, + ) as exc: + self.logger.error(f"Invalid token for peer ({exc})") + raise AuthenticationError(f"access not allowed for {conn.client.host}") diff --git a/horao/auth/peer.py b/horao/auth/peer.py deleted file mode 100644 index 43a1eba..0000000 --- a/horao/auth/peer.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*-# -"""Authorization for peers. - -Digest authentication using pre-shared key. -""" -import binascii -import logging -import os -from typing import Tuple, Union - -import jwt -from starlette.authentication import ( - AuthCredentials, - AuthenticationBackend, - AuthenticationError, - BaseUser, -) - - -class Peer(BaseUser): - def __init__(self, id: str, token: str, payload, origin: str) -> None: - self.id = id - self.token = token - self.payload = payload - self.origin = origin - - @property - def is_authenticated(self) -> bool: - return True - - @property - def display_name(self) -> str: - return self.origin - - @property - def identity(self) -> str: - return self.id - - def is_true(self) -> bool: - return self.id == self.origin - - def __str__(self) -> str: - return f"{self.origin} -> {self.id}" - - -class PeerAuthBackend(AuthenticationBackend): - logger = logging.getLogger(__name__) - - async def authenticate(self, conn) -> Union[None, Tuple[AuthCredentials, BaseUser]]: - if "Authorization" not in conn.headers: - return None - if "PEERS" in os.environ: - return None - if "PEER_SECRET" not in os.environ: - return None - - auth = conn.headers["Authorization"] - try: - scheme, token = auth.split() - if scheme.lower() != "jwt": - return None - - peer_match_source = False - for peer in os.getenv("PEERS").split(","): # type: ignore - if peer in conn.client.host: - self.logger.debug(f"Peer {peer} is trying to authenticate") - peer_match_source = True - if not peer_match_source and os.getenv("PEER_STRICT", "True") == "True": - raise AuthenticationError(f"access not allowed for {conn.client.host}") - payload = jwt.decode(token, os.getenv("PEER_SECRET"), algorithms=["HS256"]) # type: ignore - self.logger.debug(f"valid token for {payload['peer']}") - return AuthCredentials(["authenticated_peer"]), Peer( - id=payload["peer"], - token=token, - payload=payload, - origin=conn.client.host, - ) - except ( - ValueError, - UnicodeDecodeError, - jwt.InvalidTokenError, - binascii.Error, - ) as exc: - self.logger.error(f"Invalid token for peer ({exc})") - raise AuthenticationError(f"access not allowed for {conn.client.host}") diff --git a/horao/conceptual/crdt.py b/horao/conceptual/crdt.py index 2e7db97..757b2f0 100644 --- a/horao/conceptual/crdt.py +++ b/horao/conceptual/crdt.py @@ -453,6 +453,17 @@ def invoke_listeners(self, state_update: Update) -> None: for listener in self.listeners: listener(state_update) + def __eq__(self, other): + if not isinstance(other, LastWriterWinsRegister): + return False + return ( + self.name == other.name + and self.clock == other.clock + and self.value == other.value + and self.last_update == other.last_update + and self.last_writer == other.last_writer + ) + class LastWriterWinsMap(CRDT): """Last Writer Wins Map CRDT.""" diff --git a/horao/persistance/serialize.py b/horao/persistance/serialize.py index fa6cf86..418caa5 100644 --- a/horao/persistance/serialize.py +++ b/horao/persistance/serialize.py @@ -2,7 +2,9 @@ """Serialize and Deserialize Horao objects to JSON""" import json from datetime import date, datetime +from json import JSONDecodeError +from networkx.algorithms.structuralholes import constraint from networkx.convert import from_dict_of_dicts, to_dict_of_dicts # type: ignore from horao.conceptual.claim import Reservation @@ -94,23 +96,20 @@ def default(self, obj): ), } if isinstance(obj, LastWriterWinsMap): - result = { + return { "type": "LastWriterWinsMap", "names": json.dumps(obj.names, cls=HoraoEncoder) if obj.names else None, + "registers": ( + json.dumps(obj.registers, cls=HoraoEncoder) + if obj.registers + else None + ), "listeners": ( json.dumps(obj.listeners, cls=HoraoEncoder) if obj.listeners else None ), } - registers = {} - if obj.registers: - for k, v in obj.registers.items(): - registers[json.dumps(k, cls=HoraoEncoder)] = json.dumps( - v, cls=HoraoEncoder - ) - result["registers"] = registers - return result if isinstance(obj, HardwareList): return { "type": "HardwareList", @@ -333,24 +332,19 @@ def default(self, obj): "hsn": obj.hsn if obj.hsn else False, } if isinstance(obj, LogicalInfrastructure): + tenants = list(set(list(obj.constraints.keys()) + list(obj.claims.keys()))) + data_centers = list(obj.infrastructure.keys()) + infrastructure = {k.name: v for k, v in obj.infrastructure.items()} + claims = {k.name: v for k, v in obj.claims.items()} + constraints = {k.name: v for k, v in obj.constraints.items()} result = { "type": "LogicalInfrastructure", - "infrastructure": {}, - "constraints": {}, - "claims": {}, + "tenants": json.dumps(tenants, cls=HoraoEncoder), + "data_centers": json.dumps(data_centers, cls=HoraoEncoder), + "infrastructure": json.dumps(infrastructure, cls=HoraoEncoder), + "constraints": json.dumps(constraints, cls=HoraoEncoder), + "claims": json.dumps(claims, cls=HoraoEncoder), } - for k, v in obj.infrastructure.items(): - result["infrastructure"][json.dumps(k, cls=HoraoEncoder)] = json.dumps( - v, cls=HoraoEncoder - ) - for k, v in obj.constraints.items(): - result["constraints"][json.dumps(k, cls=HoraoEncoder)] = json.dumps( - v, cls=HoraoEncoder - ) - for k, v in obj.claims.items(): - result["claims"][json.dumps(k, cls=HoraoEncoder)] = json.dumps( - v, cls=HoraoEncoder - ) return result if isinstance(obj, Storage): return { @@ -390,7 +384,6 @@ def default(self, obj): "name": obj.name, "start": obj.start, "end": obj.end, - "end_user": obj.end_user, "resources": json.dumps(obj.resources, cls=HoraoEncoder), "maximal_resources": ( json.dumps(obj.maximal_resources, cls=HoraoEncoder) @@ -478,21 +471,20 @@ def object_hook(obj): ), ) if "type" in obj and obj["type"] == "LastWriterWinsMap": - registers = {} - for k, v in obj["registers"].items(): - registers[json.loads(k, cls=HoraoDecoder)] = json.loads( - v, cls=HoraoDecoder - ) return LastWriterWinsMap( names=( json.loads(obj["names"], cls=HoraoDecoder) if obj["names"] else None ), + registers=( + json.loads(obj["registers"], cls=HoraoDecoder) + if obj["registers"] + else None + ), listeners=( json.loads(obj["listeners"], cls=HoraoDecoder) if obj["listeners"] else None ), - registers=registers, ) if "type" in obj and obj["type"] == "HardwareList": return HardwareList(items=json.loads(obj["hardware"], cls=HoraoDecoder)) @@ -708,21 +700,28 @@ def object_hook(obj): dcn.links_from_graph(from_dict_of_dicts(json.loads(obj["graph"]))) return dcn if "type" in obj and obj["type"] == "LogicalInfrastructure": + tenants = json.loads(obj["tenants"], cls=HoraoDecoder) + data_centers = json.loads(obj["data_centers"], cls=HoraoDecoder) infrastructure = {} - for k, v in obj["infrastructure"].items(): - infrastructure[json.loads(k, cls=HoraoDecoder)] = json.loads( - v, cls=HoraoDecoder + for k, v in json.loads(obj["infrastructure"], cls=HoraoDecoder): + data_centre = next( + iter([dc for dc in data_centers if dc.name == k]), None ) + if not data_centre: + raise JSONDecodeError(f"DataCenter {k} not found", obj, 0) + infrastructure[data_centre] = v constraints = {} - for k, v in obj["constraints"].items(): - constraints[json.loads(k, cls=HoraoDecoder)] = json.loads( - v, cls=HoraoDecoder - ) + for k, v in json.loads(obj["constraints"], cls=HoraoDecoder): + tenant = next(iter([t for t in tenants if t.name == k]), None) + if not tenant: + raise JSONDecodeError(f"Tenant {k} not found", obj, 0) + constraints[tenant] = v claims = {} - for k, v in obj["claims"].items(): - claims[json.loads(k, cls=HoraoDecoder)] = json.loads( - v, cls=HoraoDecoder - ) + for k, v in json.loads(obj["claims"], cls=HoraoDecoder): + tenant = next(iter([t for t in tenants if t.name == k]), None) + if not tenant: + raise JSONDecodeError(f"Tenant {k} not found", obj, 0) + claims[tenant] = v return LogicalInfrastructure( infrastructure=infrastructure, constraints=constraints, @@ -759,7 +758,6 @@ def object_hook(obj): name=obj["name"], start=obj["start"], end=obj["end"], - end_user=obj["end_user"], resources=json.loads(obj["resources"], cls=HoraoDecoder), maximal_resources=( json.loads(obj["maximal_resources"], cls=HoraoDecoder) diff --git a/tests/test_api.py b/tests/test_api.py index 14b99b8..75ff2f9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,6 +6,7 @@ from starlette.testclient import TestClient from horao import init +from horao.auth import BasicAuthBackend from horao.auth.basic import basic_auth from horao.logical.infrastructure import LogicalInfrastructure from horao.persistance import HoraoEncoder @@ -14,7 +15,7 @@ def test_ping_service_unauthorized(): os.environ["TELEMETRY"] = "OFF" - ia = init() + ia = init(BasicAuthBackend()) with TestClient(ia) as client: lg = client.get("/ping") assert 403 == lg.status_code @@ -22,7 +23,7 @@ def test_ping_service_unauthorized(): def test_ping_service_authorized(): os.environ["TELEMETRY"] = "OFF" - ia = init() + ia = init(BasicAuthBackend()) with TestClient(ia) as client: lg = client.get( "/ping", headers={"Authorization": basic_auth("netadm", "secret")} @@ -31,6 +32,7 @@ def test_ping_service_authorized(): def test_synchronize_simple_structure(): + os.environ["DEBUG"] = "True" os.environ["TELEMETRY"] = "OFF" os.environ["PEER_STRICT"] = "False" os.environ["PEERS"] = "1,2" @@ -44,10 +46,7 @@ def test_synchronize_simple_structure(): assert 403 == lg.status_code lg = client.post( "/synchronize", - headers={"Authorization": f"Token {token}"}, - json={ - "LogicalInfrastructure": json.dumps(infrastructure, cls=HoraoEncoder) - }, + headers={"Authorization": f"Bearer {token}"}, + json=json.dumps(infrastructure, cls=HoraoEncoder), ) - # todo still need to fix - # assert 200 == lg.status_code + assert 200 == lg.status_code diff --git a/tests/test_persistance.py b/tests/test_persistance.py index ea918b4..f91fead 100644 --- a/tests/test_persistance.py +++ b/tests/test_persistance.py @@ -1,6 +1,6 @@ import pytest -from horao.conceptual.crdt import LastWriterWinsMap +from horao.conceptual.crdt import LastWriterWinsMap, LastWriterWinsRegister from horao.conceptual.support import LogicalClock from horao.logical.infrastructure import LogicalInfrastructure from horao.persistance.store import Store @@ -14,8 +14,31 @@ async def test_storing_loading_logical_clock(): clock = LogicalClock() store = Store(None) await store.save("clock", clock) - loaded_clock = store.load("clock") - assert clock == await loaded_clock + loaded_clock = await store.load("clock") + assert clock == loaded_clock + + +@pytest.mark.asyncio +async def test_storing_loading_observed_removed_set(): + observed_removed_set = set() + observed_removed_set.add("foo") + observed_removed_set.add("bar") + store = Store(None) + await store.save("observed_removed_set", observed_removed_set) + loaded_observed_removed_set = await store.load("observed_removed_set") + assert observed_removed_set == loaded_observed_removed_set + + +@pytest.mark.asyncio +async def test_storing_loading_last_writer_wins_register(): + lww_register = LastWriterWinsRegister("test", "foo") + update = lww_register.write("foobar", 1) + lww_register.update(update) + + store = Store(None) + await store.save("lww_register", lww_register) + loaded_lww_register = await store.load("lww_register") + assert lww_register == loaded_lww_register @pytest.mark.asyncio @@ -26,8 +49,8 @@ async def test_storing_loading_last_writer_wins_map(): lww_map.set(name, value, 1) store = Store(None) await store.save("lww_map", lww_map) - loaded_lww_map = store.load("lww_map") - assert lww_map == await loaded_lww_map + loaded_lww_map = await store.load("lww_map") + assert lww_map == loaded_lww_map @pytest.mark.asyncio