From 40fb41b38234c98570f4c0d6f52d8146aacbcb2c Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Tue, 2 Apr 2024 17:02:41 +0200 Subject: [PATCH] Issue #139 add backend-aware collection allow-list config option --- CHANGELOG.md | 4 ++ src/openeo_aggregator/about.py | 2 +- src/openeo_aggregator/backend.py | 61 +++++++++++++++++- src/openeo_aggregator/config.py | 15 +++++ src/openeo_aggregator/utils.py | 14 +++++ tests/test_backend.py | 32 ++++++++++ tests/test_utils.py | 20 ++++++ tests/test_views.py | 102 ++++++++++++++++++++++++++++++- 8 files changed, 245 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 248d4eb..644c74c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is roughly based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [0.30.0] + +- Add backend-aware collection allow-list option ([#139](https://github.com/Open-EO/openeo-aggregator/issues/139)) + ## [0.29.0] - Add config option to inject job options before sending a job to upstream back-end ([#135](https://github.com/Open-EO/openeo-aggregator/issues/135)) diff --git a/src/openeo_aggregator/about.py b/src/openeo_aggregator/about.py index 1a7a861..b605f82 100644 --- a/src/openeo_aggregator/about.py +++ b/src/openeo_aggregator/about.py @@ -2,7 +2,7 @@ import sys from typing import Optional -__version__ = "0.29.0a1" +__version__ = "0.30.0a1" def log_version_info(logger: Optional[logging.Logger] = None): diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index ca01f21..7bef2aa 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import contextlib +import dataclasses import datetime import functools import logging @@ -106,6 +109,7 @@ dict_merge, is_whitelisted, normalize_issuer_url, + string_or_regex_match, subdict, ) @@ -139,6 +143,54 @@ def __jsonserde_load__(cls, data: dict): return cls(data=data) +@dataclasses.dataclass(frozen=True) +class CollectionAllowItem: + """ + Item in the collection allow list. + At least contains a collection id (string or regex pattern) + and optionally a list of allowed backends. + """ + + collection_id: Union[str, re.Pattern] + allowed_backends: Optional[List[str]] = None + # TODO: support deny list too? + + @staticmethod + def parse(item: Union[str, re.Pattern, dict]) -> CollectionAllowItem: + """Parse given item data""" + if isinstance(item, (str, re.Pattern)): + return CollectionAllowItem(collection_id=item) + elif isinstance(item, dict): + return CollectionAllowItem(**item) + else: + raise TypeError(f"Invalid item type {type(item)}") + + def match(self, collection_id: str, backend_id: str) -> bool: + """Check if given collection/backend pair matches this item""" + collection_ok = string_or_regex_match(pattern=self.collection_id, value=collection_id) + backend_ok = self.allowed_backends is None or backend_id in self.allowed_backends + return collection_ok and backend_ok + + +class CollectionAllowList: + """Allow list for collections, where filtering is based on collection id and (optionally) backend id.""" + + def __init__(self, items: List[Union[str, re.Pattern, dict]]): + """ + :param items: list of allow list items, where each item can be: + - string (collection id) + - regex pattern for collection id + - dict with: + - required key "collection_id" (string or regex) + - optional "allowed_backends": list of backends to consider for this collection + """ + self.items: List[CollectionAllowItem] = [CollectionAllowItem.parse(item) for item in items] + + def is_allowed(self, collection_id: str, backend_id: str) -> bool: + """Check if given collection is allowed""" + return any(item.match(collection_id=collection_id, backend_id=backend_id) for item in self.items) + + class AggregatorCollectionCatalog(AbstractCollectionCatalog): def __init__(self, backends: MultiBackendConnection): self.backends = backends @@ -161,7 +213,10 @@ def _get_all_metadata(self) -> Tuple[List[dict], _InternalCollectionMetadata]: """ # Group collection metadata by hierarchically: collection id -> backend id -> metadata grouped = defaultdict(dict) - collection_whitelist: Optional[List[Union[str, re.Pattern]]] = get_backend_config().collection_whitelist + # TODO: remove deprecated collection_whitelist usage + collection_allow_list = get_backend_config().collection_allow_list or get_backend_config().collection_whitelist + if collection_allow_list: + collection_allow_list = CollectionAllowList(collection_allow_list) with TimingLogger(title="Collect collection metadata from all backends", logger=_log): for con in self.backends: @@ -175,8 +230,8 @@ def _get_all_metadata(self) -> Tuple[List[dict], _InternalCollectionMetadata]: for collection_metadata in backend_collections: if "id" in collection_metadata: collection_id = collection_metadata["id"] - if collection_whitelist: - if is_whitelisted(collection_id, whitelist=collection_whitelist, on_empty=True): + if collection_allow_list: + if collection_allow_list.is_allowed(collection_id=collection_id, backend_id=con.id): _log.debug(f"Preserving whitelisted {collection_id=} from {con.id=}") else: _log.debug(f"Skipping non-whitelisted {collection_id=} from {con.id=}") diff --git a/src/openeo_aggregator/config.py b/src/openeo_aggregator/config.py index 946dd79..0800368 100644 --- a/src/openeo_aggregator/config.py +++ b/src/openeo_aggregator/config.py @@ -59,8 +59,23 @@ class AggregatorBackendConfig(OpenEoBackendConfig): connections_cache_ttl: float = 5 * 60.0 # List of collection ids to cover with the aggregator (when None: support union of all upstream collections) + # TODO: remove this deprecated field collection_whitelist: Optional[List[Union[str, re.Pattern]]] = None + # Allow list for collection ids to cover with the aggregator. + # By default (value `None`): support union of all upstream collections. + # To enable a real allow list, use a list of items as illustrated: + # [ + # # Regular string: match exactly + # "COPERNICUS_30", + # # Regex pattern object: match collection id with regex (`fullmatch` mode) + # re.compile(r"CGLS_.*"), + # # Dict: match collection id (again as string or with regex pattern) + # # and additionally only consider specific backends by id (per `aggregator_backends` config) + # {"collection_id": "SENTINEL2_L2A", "allowed_backends": ["b2"]}, + # ] + collection_allow_list: Optional[List[Union[str, re.Pattern, dict]]] = None + zookeeper_prefix: str = "/openeo-aggregator/" # See `memoizer_from_config` for details. diff --git a/src/openeo_aggregator/utils.py b/src/openeo_aggregator/utils.py index 071489e..ebb31e5 100644 --- a/src/openeo_aggregator/utils.py +++ b/src/openeo_aggregator/utils.py @@ -314,3 +314,17 @@ def __getattr__(self, name): if name in self.to_track: self.stats[name] = self.stats.get(name, 0) + 1 return getattr(self.target, name) + + +def string_or_regex_match(pattern: Union[str, re.Pattern], value: str) -> bool: + """ + Check if given value matches given pattern. + If pattern is given as string, it must be an exact match. + If pattern is given as regex, it must match the full value. + """ + if isinstance(pattern, str): + return pattern == value + elif isinstance(pattern, re.Pattern): + return bool(pattern.fullmatch(value)) + else: + raise TypeError(f"Invalid pattern {pattern}") diff --git a/tests/test_backend.py b/tests/test_backend.py index c5ebe76..2952c90 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,5 +1,6 @@ import datetime as dt import logging +import re import pytest from openeo.rest import OpenEoApiError, OpenEoApiPlainError, OpenEoRestError @@ -22,6 +23,7 @@ AggregatorCollectionCatalog, AggregatorProcessing, AggregatorSecondaryServices, + CollectionAllowList, JobIdMapping, _InternalCollectionMetadata, ) @@ -866,6 +868,36 @@ def test_list_backends_for_collections(self): ] +class TestCollectionAllowList: + def test_basic(self): + allow_list = CollectionAllowList( + [ + "foo", + re.compile("ba+r"), + ] + ) + assert allow_list.is_allowed("foo", "b1") is True + assert allow_list.is_allowed("bar", "b1") is True + assert allow_list.is_allowed("baaaaar", "b1") is True + assert allow_list.is_allowed("br", "b1") is False + + def test_allowed_backends(self): + allow_list = CollectionAllowList( + [ + "foo", + {"collection_id": "S2", "allowed_backends": ["b1"]}, + ] + ) + assert allow_list.is_allowed("foo", "b1") is True + assert allow_list.is_allowed("foo", "b2") is True + assert allow_list.is_allowed("S2", "b1") is True + assert allow_list.is_allowed("S2", "b2") is False + + def test_allowed_backends_field_typo(self): + with pytest.raises(TypeError, match="unexpected keyword argument 'backends'"): + _ = CollectionAllowList([{"collection_id": "S2", "backends": ["b1"]}]) + + @pytest.mark.usefixtures("flask_app") # Automatically enter flask app context for `url_for` to work class TestAggregatorCollectionCatalog: def test_get_all_metadata_simple(self, catalog, backend1, backend2, requests_mock): diff --git a/tests/test_utils.py b/tests/test_utils.py index cc158b7..f997bc9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,6 +13,7 @@ drop_dict_keys, is_whitelisted, normalize_issuer_url, + string_or_regex_match, strip_join, subdict, timestamp_to_rfc3339, @@ -343,3 +344,22 @@ def meh(self, x): assert foo.meh(6) == 12 assert foo.stats == {"bar": 1} + + +def test_string_or_regex_match_str(): + assert string_or_regex_match("foo", "foo") is True + assert string_or_regex_match("foo", "bar") is False + + +def test_string_or_regex_match_regex(): + assert string_or_regex_match(re.compile("(foo|bar)"), "foo") is True + assert string_or_regex_match(re.compile("(foo|ba+r)"), "baaar") is True + assert string_or_regex_match(re.compile("(foo|bar)"), "meh") is False + assert string_or_regex_match(re.compile("(foo|bar)"), "foobar") is False + assert string_or_regex_match(re.compile("(foo|bar).*"), "foozuu") is True + assert string_or_regex_match(re.compile(".*(foo|bar)"), "meebar") is True + + +def test_string_or_regex_match_invalid(): + with pytest.raises(TypeError, match=re.escape("Invalid pattern [1, 2, 3]")): + string_or_regex_match([1, 2, 3], "foo") diff --git a/tests/test_views.py b/tests/test_views.py index 3fe0cae..f4eb4a5 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -292,7 +292,9 @@ def test_collections_links(self, api100, requests_mock, backend1, backend2): ([re.compile(r".*2")], {"S2"}), ], ) - def test_collections_whitelist(self, api100, requests_mock, backend1, backend2, collection_whitelist, expected): + def test_collections_whitelist_legacy( + self, api100, requests_mock, backend1, backend2, collection_whitelist, expected + ): requests_mock.get(backend1 + "/collections", json={"collections": [{"id": "S1"}, {"id": "S2"}, {"id": "S3"}]}) for cid in ["S1", "S2", "S3"]: requests_mock.get(backend1 + f"/collections/{cid}", json={"id": cid, "title": f"b1 {cid}"}) @@ -319,6 +321,104 @@ def test_collections_whitelist(self, api100, requests_mock, backend1, backend2, res = api100.get("/collections/S999") res.assert_error(404, "CollectionNotFound") + @pytest.mark.parametrize( + ["collection_allow_list", "expected"], + [ + (None, {"S1", "S2", "S3", "S4"}), + ([], {"S1", "S2", "S3", "S4"}), + (["S2"], {"S2"}), + (["S4"], {"S4"}), + (["S2", "S3"], {"S2", "S3"}), + (["S2", "S999"], {"S2"}), + (["S999"], set()), + ([re.compile(r"S[23]")], {"S2", "S3"}), + ([re.compile(r"S")], set()), + ([re.compile(r"S.*")], {"S1", "S2", "S3", "S4"}), + ([re.compile(r"S2.*")], {"S2"}), + ([re.compile(r".*2")], {"S2"}), + ], + ) + def test_collections_allow_list(self, api100, requests_mock, backend1, backend2, collection_allow_list, expected): + requests_mock.get(backend1 + "/collections", json={"collections": [{"id": "S1"}, {"id": "S2"}, {"id": "S3"}]}) + for cid in ["S1", "S2", "S3"]: + requests_mock.get(backend1 + f"/collections/{cid}", json={"id": cid, "title": f"b1 {cid}"}) + requests_mock.get(backend2 + "/collections", json={"collections": [{"id": "S3"}, {"id": "S4"}]}) + for cid in ["S3", "S4"]: + requests_mock.get(backend2 + f"/collections/{cid}", json={"id": cid, "title": f"b2 {cid}"}) + + with config_overrides(collection_allow_list=collection_allow_list): + res = api100.get("/collections").assert_status_code(200).json + assert set(c["id"] for c in res["collections"]) == expected + + res = api100.get("/collections/S2") + if "S2" in expected: + assert res.assert_status_code(200).json == DictSubSet({"id": "S2", "title": "b1 S2"}) + else: + res.assert_error(404, "CollectionNotFound") + + res = api100.get("/collections/S3") + if "S3" in expected: + assert res.assert_status_code(200).json == DictSubSet({"id": "S3", "title": "b1 S3"}) + else: + res.assert_error(404, "CollectionNotFound") + + res = api100.get("/collections/S999") + res.assert_error(404, "CollectionNotFound") + + @pytest.mark.parametrize( + ["collection_allow_list", "expected"], + [ + ( + ["S2"], + {"S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})}}, + ), + ( + ["S2", {"collection_id": "S3"}], + { + "S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})}, + "S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})}, + }, + ), + ( + ["S2", {"collection_id": "S3", "allowed_backends": ["b2"]}], + { + "S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})}, + "S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b2"]})}, + }, + ), + ( + [ + {"collection_id": "S3", "allowed_backends": ["b999"]}, + {"collection_id": "S3", "allowed_backends": ["b2"]}, + ], + { + "S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b2"]})}, + }, + ), + ], + ) + def test_collections_allow_list_allowed_backend( + self, api100, requests_mock, backend1, backend2, collection_allow_list, expected + ): + for backend, cids in { + backend1: ["S1", "S2", "S3"], + backend2: ["S2", "S3", "S4"], + }.items(): + requests_mock.get(backend + "/collections", json={"collections": [{"id": cid} for cid in cids]}) + for cid in cids: + requests_mock.get(backend + f"/collections/{cid}", json={"id": cid}) + + with config_overrides(collection_allow_list=collection_allow_list): + res = api100.get("/collections").assert_status_code(200).json + assert set(c["id"] for c in res["collections"]) == set(expected.keys()) + + for cid in ["S1", "S2", "S3", "S4", "S999"]: + res = api100.get(f"/collections/{cid}") + if cid in expected: + assert res.assert_status_code(200).json == DictSubSet(expected[cid]) + else: + res.assert_error(404, "CollectionNotFound") + class TestAuthentication: def test_credentials_oidc_default(self, api100, backend1, backend2):