@dataclasses.dataclass(frozen=True)
class CollectionAllowItem:
    """
    Item in the collection allow list.

    At least contains a collection id (string or regex pattern)
    and optionally a list of allowed backends.
    """

    # Collection id to allow: exact string, or compiled regex pattern.
    collection_id: Union[str, re.Pattern]
    # Backend ids this item applies to. `None` means: no restriction on backend.
    allowed_backends: Optional[List[str]] = None
    # TODO: support deny list too?

    def __post_init__(self):
        # Validate at construction time, so that a broken config item is reported
        # where it is defined instead of later, when the first collection is matched.
        if not isinstance(self.collection_id, (str, re.Pattern)):
            raise TypeError(f"Invalid collection_id type {type(self.collection_id)}")
        if isinstance(self.allowed_backends, str):
            # With a bare string, `backend_id in allowed_backends` would silently
            # become a substring check (e.g. "b1" in "b12" is true).
            raise TypeError(
                f"Invalid allowed_backends {self.allowed_backends!r}: expected a list of backend ids, not a string"
            )

    @classmethod
    def parse(cls, item: Union[str, re.Pattern, dict, "CollectionAllowItem"]) -> "CollectionAllowItem":
        """
        Parse given item data.

        :param item: collection id (string), regex pattern, dict of constructor arguments
            ("collection_id" and optionally "allowed_backends"),
            or an already parsed :py:class:`CollectionAllowItem` (returned as-is).
        :raises TypeError: on unsupported item type, or dict keys that are not constructor arguments
        """
        if isinstance(item, CollectionAllowItem):
            # Already parsed: keeps parsing idempotent.
            return item
        if isinstance(item, (str, re.Pattern)):
            return cls(collection_id=item)
        if isinstance(item, dict):
            return cls(**item)
        raise TypeError(f"Invalid item type {type(item)}")

    def match(self, collection_id: str, backend_id: str) -> bool:
        """Check if given collection/backend pair matches this item"""
        collection_ok = string_or_regex_match(pattern=self.collection_id, value=collection_id)
        backend_ok = self.allowed_backends is None or backend_id in self.allowed_backends
        return collection_ok and backend_ok


class CollectionAllowList:
    """Allow list for collections, where filtering is based on collection id and (optionally) backend id."""

    def __init__(self, items: List[Union[str, re.Pattern, dict, CollectionAllowItem]]):
        """
        :param items: list of allow list items, where each item can be:
            - string (collection id)
            - regex pattern for collection id
            - dict with:
                - required key "collection_id" (string or regex)
                - optional "allowed_backends": list of backends to consider for this collection
            - already parsed :py:class:`CollectionAllowItem`

            An existing :py:class:`CollectionAllowList` is accepted too
            (it iterates over its parsed items), so wrapping twice is harmless.
        """
        self.items: List[CollectionAllowItem] = [CollectionAllowItem.parse(item) for item in items]

    def __iter__(self):
        """Iterate over the parsed allow list items."""
        return iter(self.items)

    def is_allowed(self, collection_id: str, backend_id: str) -> bool:
        """Check if given collection is allowed (for given backend) by at least one item."""
        return any(item.match(collection_id=collection_id, backend_id=backend_id) for item in self.items)
get_backend_config().collection_whitelist + collection_allow_list = get_backend_config().collection_allow_list + if collection_whitelist: + _log.warning("Using deprecated collection_whitelist configuration, consider using collection_allow_list.") + assert collection_allow_list is None + collection_allow_list = CollectionAllowList(collection_whitelist) + + if collection_allow_list: + collection_allow_list = CollectionAllowList(collection_allow_list) with TimingLogger(title="Collect collection metadata from all backends", logger=_log): for con in self.backends: @@ -175,8 +236,8 @@ def _get_all_metadata(self) -> Tuple[List[dict], _InternalCollectionMetadata]: for collection_metadata in backend_collections: if "id" in collection_metadata: collection_id = collection_metadata["id"] - if collection_whitelist: - if is_whitelisted(collection_id, whitelist=collection_whitelist, on_empty=True): + if collection_allow_list: + if collection_allow_list.is_allowed(collection_id=collection_id, backend_id=con.id): _log.debug(f"Preserving whitelisted {collection_id=} from {con.id=}") else: _log.debug(f"Skipping non-whitelisted {collection_id=} from {con.id=}") diff --git a/src/openeo_aggregator/config.py b/src/openeo_aggregator/config.py index 946dd79..30596f5 100644 --- a/src/openeo_aggregator/config.py +++ b/src/openeo_aggregator/config.py @@ -59,8 +59,16 @@ class AggregatorBackendConfig(OpenEoBackendConfig): connections_cache_ttl: float = 5 * 60.0 # List of collection ids to cover with the aggregator (when None: support union of all upstream collections) + # TODO: remove this deprecated field collection_whitelist: Optional[List[Union[str, re.Pattern]]] = None + # List of collection ids to cover with the aggregator. 
def string_or_regex_match(pattern: Union[str, re.Pattern], value: str) -> bool:
    """
    Tell whether `value` is matched by `pattern`.

    A plain string pattern only matches on equality,
    a compiled regex has to cover the value completely (full match).

    :raises TypeError: when `pattern` is neither a string nor a compiled regex
    """
    if isinstance(pattern, re.Pattern):
        full_match = pattern.fullmatch(value)
        return full_match is not None
    if isinstance(pattern, str):
        return pattern == value
    raise TypeError(f"Invalid pattern type {type(pattern)}")
@pytest.mark.parametrize(
    ["collection_allow_list", "expected"],
    [
        (
            ["S2"],
            {"S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})}},
        ),
        (
            ["S2", {"collection_id": "S3"}],
            {
                "S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})},
                "S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})},
            },
        ),
        (
            ["S2", {"collection_id": "S3", "allowed_backends": ["b2"]}],
            {
                "S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})},
                "S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b2"]})},
            },
        ),
        (
            [
                {"collection_id": "S3", "allowed_backends": ["b999"]},
                {"collection_id": "S3", "allowed_backends": ["b2"]},
            ],
            {
                "S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b2"]})},
            },
        ),
    ],
)
def test_collections_allow_list_allowed_backend(
    self, api100, requests_mock, backend1, backend2, collection_allow_list, expected
):
    """
    Check that "allowed_backends" of an allow list item restricts
    which upstream backends contribute to a collection.

    Both mocked backends serve overlapping collections (S2 and S3),
    so "federation:backends" in the merged metadata shows which backends were kept.
    """
    # (title label, backend root URL, collection ids served by that backend)
    upstream_setup = [
        ("b1", backend1, ["S1", "S2", "S3"]),
        ("b2", backend2, ["S2", "S3", "S4"]),
    ]
    for label, root_url, cids in upstream_setup:
        requests_mock.get(root_url + "/collections", json={"collections": [{"id": cid} for cid in cids]})
        for cid in cids:
            # Label the title per backend, so mocked metadata of different backends is distinguishable.
            requests_mock.get(root_url + f"/collections/{cid}", json={"id": cid, "title": f"{label} {cid}"})

    with config_overrides(collection_allow_list=collection_allow_list):
        res = api100.get("/collections").assert_status_code(200).json
        assert set(c["id"] for c in res["collections"]) == set(expected.keys())

        for cid in ["S1", "S2", "S3", "S4", "S999"]:
            res = api100.get(f"/collections/{cid}")
            if cid in expected:
                assert res.assert_status_code(200).json == DictSubSet(expected[cid])
            else:
                res.assert_error(404, "CollectionNotFound")