Skip to content

Commit

Permalink
Issue #139 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Apr 2, 2024
1 parent 8c586ec commit 3acfa15
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 3 deletions.
65 changes: 63 additions & 2 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import contextlib
import dataclasses
import datetime
import functools
import logging
Expand Down Expand Up @@ -106,6 +109,7 @@
dict_merge,
is_whitelisted,
normalize_issuer_url,
string_or_regex_match,
subdict,
)

Expand Down Expand Up @@ -139,6 +143,54 @@ def __jsonserde_load__(cls, data: dict):
return cls(data=data)


@dataclasses.dataclass(frozen=True)
class CollectionAllowItem:
"""
Item in the collection allow list.
At least contains a collection id (string or regex pattern)
and optionally a list of allowed backends.
"""

collection_id: Union[str, re.Pattern]
allowed_backends: Optional[List[str]] = None
# TODO: support deny list too?

@staticmethod
def parse(item: Union[str, re.Pattern, dict]) -> CollectionAllowItem:
"""Parse given item data"""
if isinstance(item, (str, re.Pattern)):
return CollectionAllowItem(collection_id=item)
elif isinstance(item, dict):
return CollectionAllowItem(**item)
else:
raise TypeError(f"Invalid item type {type(item)}")

def match(self, collection_id: str, backend_id: str) -> bool:
"""Check if given collection/backend pair matches this item"""
collection_ok = string_or_regex_match(pattern=self.collection_id, value=collection_id)
backend_ok = self.allowed_backends is None or backend_id in self.allowed_backends
return collection_ok and backend_ok


class CollectionAllowList:
"""Allow list for collections, where filtering is based on collection id and (optionally) backend id."""

def __init__(self, items: List[Union[str, re.Pattern, dict]]):
"""
:param items: list of allow list items, where each item can be:
- string (collection id)
- regex pattern for collection id
- dict with:
- required key "collection_id" (string or regex)
- optional "allowed_backends": list of backends to consider for this collection
"""
self.items: List[CollectionAllowItem] = [CollectionAllowItem.parse(item) for item in items]

def is_allowed(self, collection_id: str, backend_id: str) -> bool:
"""Check if given collection is allowed"""
return any(item.match(collection_id=collection_id, backend_id=backend_id) for item in self.items)


class AggregatorCollectionCatalog(AbstractCollectionCatalog):
def __init__(self, backends: MultiBackendConnection):
self.backends = backends
Expand All @@ -161,7 +213,16 @@ def _get_all_metadata(self) -> Tuple[List[dict], _InternalCollectionMetadata]:
"""
# Group collection metadata by hierarchically: collection id -> backend id -> metadata
grouped = defaultdict(dict)
# TODO: remove this deprecated collection_whitelist
collection_whitelist: Optional[List[Union[str, re.Pattern]]] = get_backend_config().collection_whitelist
collection_allow_list = get_backend_config().collection_allow_list
if collection_whitelist:
_log.warning("Using deprecated collection_whitelist configuration, consider using collection_allow_list.")
assert collection_allow_list is None
collection_allow_list = CollectionAllowList(collection_whitelist)

if collection_allow_list:
collection_allow_list = CollectionAllowList(collection_allow_list)

with TimingLogger(title="Collect collection metadata from all backends", logger=_log):
for con in self.backends:
Expand All @@ -175,8 +236,8 @@ def _get_all_metadata(self) -> Tuple[List[dict], _InternalCollectionMetadata]:
for collection_metadata in backend_collections:
if "id" in collection_metadata:
collection_id = collection_metadata["id"]
if collection_whitelist:
if is_whitelisted(collection_id, whitelist=collection_whitelist, on_empty=True):
if collection_allow_list:
if collection_allow_list.is_allowed(collection_id=collection_id, backend_id=con.id):
_log.debug(f"Preserving whitelisted {collection_id=} from {con.id=}")
else:
_log.debug(f"Skipping non-whitelisted {collection_id=} from {con.id=}")
Expand Down
8 changes: 8 additions & 0 deletions src/openeo_aggregator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,16 @@ class AggregatorBackendConfig(OpenEoBackendConfig):
connections_cache_ttl: float = 5 * 60.0

# List of collection ids to cover with the aggregator (when None: support union of all upstream collections)
# TODO: remove this deprecated field
collection_whitelist: Optional[List[Union[str, re.Pattern]]] = None

# List of collection ids to cover with the aggregator.
# By default (or value None): support union of all upstream collections
# Each item can be a string (collection id), regex pattern for collection id, or dict with:
# - required key "collection_id" (string or regex)
# - optional "allowed_backends": list of backends to consider for this collection
collection_allow_list: Optional[List[Union[str, re.Pattern, dict]]] = None

zookeeper_prefix: str = "/openeo-aggregator/"

# See `memoizer_from_config` for details.
Expand Down
14 changes: 14 additions & 0 deletions src/openeo_aggregator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,17 @@ def __getattr__(self, name):
if name in self.to_track:
self.stats[name] = self.stats.get(name, 0) + 1
return getattr(self.target, name)


def string_or_regex_match(pattern: Union[str, re.Pattern], value: str) -> bool:
"""
Check if given value matches given pattern.
If pattern is given as string, it must be an exact match.
If pattern is given as regex, it must match the full value.
"""
if isinstance(pattern, str):
return pattern == value
elif isinstance(pattern, re.Pattern):
return bool(pattern.fullmatch(value))
else:
raise TypeError(f"Invalid pattern type {type(pattern)}")
102 changes: 101 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ def test_collections_links(self, api100, requests_mock, backend1, backend2):
([re.compile(r".*2")], {"S2"}),
],
)
def test_collections_whitelist(self, api100, requests_mock, backend1, backend2, collection_whitelist, expected):
def test_collections_whitelist_legacy(
self, api100, requests_mock, backend1, backend2, collection_whitelist, expected
):
requests_mock.get(backend1 + "/collections", json={"collections": [{"id": "S1"}, {"id": "S2"}, {"id": "S3"}]})
for cid in ["S1", "S2", "S3"]:
requests_mock.get(backend1 + f"/collections/{cid}", json={"id": cid, "title": f"b1 {cid}"})
Expand All @@ -319,6 +321,104 @@ def test_collections_whitelist(self, api100, requests_mock, backend1, backend2,
res = api100.get("/collections/S999")
res.assert_error(404, "CollectionNotFound")

@pytest.mark.parametrize(
["collection_allow_list", "expected"],
[
(None, {"S1", "S2", "S3", "S4"}),
([], {"S1", "S2", "S3", "S4"}),
(["S2"], {"S2"}),
(["S4"], {"S4"}),
(["S2", "S3"], {"S2", "S3"}),
(["S2", "S999"], {"S2"}),
(["S999"], set()),
([re.compile(r"S[23]")], {"S2", "S3"}),
([re.compile(r"S")], set()),
([re.compile(r"S.*")], {"S1", "S2", "S3", "S4"}),
([re.compile(r"S2.*")], {"S2"}),
([re.compile(r".*2")], {"S2"}),
],
)
def test_collections_allow_list(self, api100, requests_mock, backend1, backend2, collection_allow_list, expected):
requests_mock.get(backend1 + "/collections", json={"collections": [{"id": "S1"}, {"id": "S2"}, {"id": "S3"}]})
for cid in ["S1", "S2", "S3"]:
requests_mock.get(backend1 + f"/collections/{cid}", json={"id": cid, "title": f"b1 {cid}"})
requests_mock.get(backend2 + "/collections", json={"collections": [{"id": "S3"}, {"id": "S4"}]})
for cid in ["S3", "S4"]:
requests_mock.get(backend2 + f"/collections/{cid}", json={"id": cid, "title": f"b2 {cid}"})

with config_overrides(collection_allow_list=collection_allow_list):
res = api100.get("/collections").assert_status_code(200).json
assert set(c["id"] for c in res["collections"]) == expected

res = api100.get("/collections/S2")
if "S2" in expected:
assert res.assert_status_code(200).json == DictSubSet({"id": "S2", "title": "b1 S2"})
else:
res.assert_error(404, "CollectionNotFound")

res = api100.get("/collections/S3")
if "S3" in expected:
assert res.assert_status_code(200).json == DictSubSet({"id": "S3", "title": "b1 S3"})
else:
res.assert_error(404, "CollectionNotFound")

res = api100.get("/collections/S999")
res.assert_error(404, "CollectionNotFound")

@pytest.mark.parametrize(
["collection_allow_list", "expected"],
[
(
["S2"],
{"S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})}},
),
(
["S2", {"collection_id": "S3"}],
{
"S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})},
"S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})},
},
),
(
["S2", {"collection_id": "S3", "allowed_backends": ["b2"]}],
{
"S2": {"id": "S2", "summaries": DictSubSet({"federation:backends": ["b1", "b2"]})},
"S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b2"]})},
},
),
(
[
{"collection_id": "S3", "allowed_backends": ["b999"]},
{"collection_id": "S3", "allowed_backends": ["b2"]},
],
{
"S3": {"id": "S3", "summaries": DictSubSet({"federation:backends": ["b2"]})},
},
),
],
)
def test_collections_allow_list_allowed_backend(
self, api100, requests_mock, backend1, backend2, collection_allow_list, expected
):
for bid, cids in {
backend1: ["S1", "S2", "S3"],
backend2: ["S2", "S3", "S4"],
}.items():
requests_mock.get(bid + "/collections", json={"collections": [{"id": cid} for cid in cids]})
for cid in cids:
requests_mock.get(bid + f"/collections/{cid}", json={"id": cid, "title": f"b1 {cid}"})

with config_overrides(collection_allow_list=collection_allow_list):
res = api100.get("/collections").assert_status_code(200).json
assert set(c["id"] for c in res["collections"]) == set(expected.keys())

for cid in ["S1", "S2", "S3", "S4", "S999"]:
res = api100.get(f"/collections/{cid}")
if cid in expected:
assert res.assert_status_code(200).json == DictSubSet(expected[cid])
else:
res.assert_error(404, "CollectionNotFound")


class TestAuthentication:
def test_credentials_oidc_default(self, api100, backend1, backend2):
Expand Down

0 comments on commit 3acfa15

Please sign in to comment.