diff --git a/setup.py b/setup.py index a3327408..15e2206b 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,8 @@ exec(fp.read()) tests_require = [ - 'pytest>=6.2.0', - 'requests-mock>=1.9.0', + "pytest>=6.2.0", + "requests-mock>=1.9.0", ] setup( @@ -52,6 +52,6 @@ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Development Status :: 3 - Alpha", - "Operating System :: OS Independent" - ] + "Operating System :: OS Independent", + ], ) diff --git a/src/openeo_aggregator/app.py b/src/openeo_aggregator/app.py index 8e9d648f..0cf0a098 100644 --- a/src/openeo_aggregator/app.py +++ b/src/openeo_aggregator/app.py @@ -37,7 +37,6 @@ def create_app(auto_logging_setup: bool = True, flask_error_handling: bool = Tru log_version_info(logger=_log) - backends = MultiBackendConnection.from_config() _log.info("Creating AggregatorBackendImplementation") @@ -59,9 +58,7 @@ def create_app(auto_logging_setup: bool = True, flask_error_handling: bool = Tru @app.route("/_info", methods=["GET"]) def agg_backends(): - info = { - "backends": [{"id": con.id, "root_url": con.root_url} for con in backends] - } + info = {"backends": [{"id": con.id, "root_url": con.root_url} for con in backends]} return flask.jsonify(info) _log.info(f"Built {app=!r}") diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index 3b575900..0dcd132d 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -140,7 +140,6 @@ def __jsonserde_load__(cls, data: dict): class AggregatorCollectionCatalog(AbstractCollectionCatalog): - def __init__(self, backends: MultiBackendConnection): self.backends = backends self._memoizer = memoizer_from_config(namespace="CollectionCatalog") @@ -195,14 +194,12 @@ def _get_all_metadata(self) -> Tuple[List[dict], _InternalCollectionMetadata]: if len(by_backend) == 1: # Simple case: collection is only available on single backend. _log.debug(f"Accept single backend collection {cid} as is") - (bid, metadata), = by_backend.items() + ((bid, metadata),) = by_backend.items() single_backend_collection_post_processing(metadata, backend_id=bid) else: _log.info(f"Merging {cid!r} collection metadata from backends {by_backend.keys()}") try: - metadata = merge_collection_metadata( - by_backend, full_metadata=False - ) + metadata = merge_collection_metadata(by_backend, full_metadata=False) except Exception as e: _log.error(f"Failed to merge collection metadata for {cid!r}", exc_info=True) continue @@ -295,9 +292,7 @@ def _get_collection_metadata(self, collection_id: str) -> dict: single_backend_collection_post_processing(metadata, backend_id=bid) else: _log.info(f"Merging metadata for collection {collection_id}.") - metadata = merge_collection_metadata( - by_backend=by_backend, full_metadata=True - ) + metadata = merge_collection_metadata(by_backend=by_backend, full_metadata=True) return normalize_collection_metadata(metadata) def load_collection(self, collection_id: str, load_params: LoadParameters, env: EvalEnv) -> DriverDataCube: @@ -340,9 +335,9 @@ def parse_aggregator_job_id(cls, backends: MultiBackendConnection, aggregator_jo class AggregatorProcessing(Processing): def __init__( - self, - backends: MultiBackendConnection, - catalog: AggregatorCollectionCatalog, + self, + backends: MultiBackendConnection, + catalog: AggregatorCollectionCatalog, ): self.backends = backends # TODO Cache per backend results instead of output? 
@@ -350,17 +345,10 @@ def __init__( self.backends.on_connections_change.add(self._memoizer.invalidate) self._catalog = catalog - def get_process_registry( - self, api_version: Union[str, ComparableVersion] - ) -> ProcessRegistry: - if ( - api_version < self.backends.api_version_minimum - or api_version > self.backends.api_version_maximum - ): + def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> ProcessRegistry: + if api_version < self.backends.api_version_minimum or api_version > self.backends.api_version_maximum: # TODO: more relaxed version check? How useful is this check anyway? - _log.warning( - f"API mismatch: requested {api_version} outside of {self.backends.get_api_versions()}" - ) + _log.warning(f"API mismatch: requested {api_version} outside of {self.backends.get_api_versions()}") combined_processes = self.get_merged_process_metadata() process_registry = ProcessRegistry() @@ -388,9 +376,7 @@ def _get_merged_process_metadata(self) -> dict: ) return combined_processes - def _get_backend_candidates_for_processes( - self, processes: typing.Collection[str] - ) -> Union[List[str], None]: + def _get_backend_candidates_for_processes(self, processes: typing.Collection[str]) -> Union[List[str], None]: """ Get backend ids providing all given processes :param processes: collection process ids @@ -407,9 +393,7 @@ def _get_backend_candidates_for_processes( else: candidates = candidates.intersection(backends) else: - _log.warning( - f"Skipping unknown process {pid!r} in `_get_backend_candidates_for_processes`" - ) + _log.warning(f"Skipping unknown process {pid!r} in `_get_backend_candidates_for_processes`") return candidates def get_backend_for_process_graph( @@ -464,9 +448,7 @@ def get_backend_for_process_graph( backends=self.backends, aggregator_job_id=arguments["id"], ) - backend_candidates = [ - b for b in backend_candidates if b == job_backend_id - ] + backend_candidates = [b for b in backend_candidates if b == job_backend_id] elif process_id == "load_ml_model": model_backend_id = self._process_load_ml_model(arguments)[0] if model_backend_id: @@ -484,26 +466,22 @@ def get_backend_for_process_graph( conditions = self._catalog.generate_backend_constraint_callables( process_graphs=collection_backend_constraints ) - backend_candidates = [ - b for b in backend_candidates if all(c(b) for c in conditions) - ] + backend_candidates = [b for b in backend_candidates if all(c(b) for c in conditions)] if processes: process_candidates = self._get_backend_candidates_for_processes(processes) if process_candidates: - backend_candidates = [ - b for b in backend_candidates if b in process_candidates - ] + backend_candidates = [b for b in backend_candidates if b in process_candidates] else: # TODO: make this an exception like we do with collections? (BackendLookupFailureException) _log.warning(f"No process based backend candidates ({processes=})") if len(backend_candidates) > 1: - # TODO #42 Check `/validation` instead of naively picking first one? - _log.warning( - f"Multiple back-end candidates {backend_candidates} for collections {collections}." - f" Naively picking first one." - ) + # TODO #42 Check `/validation` instead of naively picking first one? + _log.warning( + f"Multiple back-end candidates {backend_candidates} for collections {collections}." + f" Naively picking first one." 
+ ) if not backend_candidates: raise BackendLookupFailureException(message="No backend matching all constraints") @@ -524,8 +502,10 @@ def evaluate(self, process_graph: dict, env: EvalEnv = None): with con.authenticated_from_request(flask.request), timing_logger: try: backend_response = con.post( - path="/result", json=request_pg, - stream=True, timeout=CONNECTION_TIMEOUT_RESULT, + path="/result", + json=request_pg, + stream=True, + timeout=CONNECTION_TIMEOUT_RESULT, expected_status=200, ) except Exception as e: @@ -551,13 +531,9 @@ def preprocess(node: Any) -> Any: backends=self.backends, aggregator_job_id=result_id, ) - assert ( - job_backend_id == backend_id - ), f"{job_backend_id} != {backend_id}" + assert job_backend_id == backend_id, f"{job_backend_id} != {backend_id}" # Create new load_result node dict with updated job id - return dict_merge( - node, arguments=dict_merge(arguments, id=job_id) - ) + return dict_merge(node, arguments=dict_merge(arguments, id=job_id)) if process_id == "load_ml_model": model_id = self._process_load_ml_model(arguments, expected_backend=backend_id)[1] if model_id: @@ -570,15 +546,14 @@ def preprocess(node: Any) -> Any: return preprocess(process_graph) def _process_load_ml_model( - self, arguments: dict, expected_backend: Optional[str] = None + self, arguments: dict, expected_backend: Optional[str] = None ) -> Tuple[Union[str, None], str]: """Handle load_ml_model: detect/strip backend_id from model_id if it is a job_id""" model_id = arguments.get("id") if model_id and not model_id.startswith("http"): # TODO: load_ml_model's `id` could also be file path (see https://github.com/Open-EO/openeo-processes/issues/384) job_id, job_backend_id = JobIdMapping.parse_aggregator_job_id( - backends=self.backends, - aggregator_job_id=model_id + backends=self.backends, aggregator_job_id=model_id ) if expected_backend and job_backend_id != expected_backend: raise BackendLookupFailureException(f"{job_backend_id} != {expected_backend}") @@ -632,9 +607,7 @@ def validate(self, process_graph: dict, env: EvalEnv = None) -> List[dict]: return errors - class AggregatorBatchJobs(BatchJobs): - def __init__( self, *, @@ -712,12 +685,12 @@ def create_job( ) else: return self._create_partitioned_job( - user_id=user_id, - process=process, - api_version=api_version, - metadata=metadata, - job_options=job_options, - ) + user_id=user_id, + process=process, + api_version=api_version, + metadata=metadata, + job_options=job_options, + ) else: return self._create_job_standard( user_id=user_id, @@ -741,24 +714,23 @@ def _create_job_standard( api_version=api_version, job_options=job_options, ) - process_graph = self.processing.preprocess_process_graph( - process_graph, backend_id=backend_id - ) + process_graph = self.processing.preprocess_process_graph(process_graph, backend_id=backend_id) if job_options: - additional = { - k: v for k, v in job_options.items() if not k.startswith("_agg_") - } + additional = {k: v for k, v in job_options.items() if not k.startswith("_agg_")} else: additional = None con = self.backends.get_connection(backend_id) - with con.authenticated_from_request(request=flask.request, user=User(user_id=user_id)), \ - con.override(default_timeout=CONNECTION_TIMEOUT_JOB_START): + with con.authenticated_from_request(request=flask.request, user=User(user_id=user_id)), con.override( + default_timeout=CONNECTION_TIMEOUT_JOB_START + ): try: job = con.create_job( process_graph=process_graph, - title=metadata.get("title"), description=metadata.get("description"), - 
plan=metadata.get("plan"), budget=metadata.get("budget"), + title=metadata.get("title"), + description=metadata.get("description"), + plan=metadata.get("plan"), + budget=metadata.get("budget"), additional=additional, ) except OpenEoApiError as e: @@ -849,12 +821,10 @@ def backend_for_collection(collection_id) -> str: ) def _get_connection_and_backend_job_id( - self, - aggregator_job_id: str + self, aggregator_job_id: str ) -> Tuple[Union[BackendConnection, PartitionedJobConnection], str]: backend_job_id, backend_id = JobIdMapping.parse_aggregator_job_id( - backends=self.backends, - aggregator_job_id=aggregator_job_id + backends=self.backends, aggregator_job_id=aggregator_job_id ) if backend_id == JobIdMapping.AGG and self.partitioned_job_tracker: @@ -878,54 +848,53 @@ def _translate_job_errors(self, job_id): def get_job_info(self, job_id: str, user_id: str) -> BatchJobMetadata: con, backend_job_id = self._get_connection_and_backend_job_id(aggregator_job_id=job_id) user = User(user_id=user_id) - with con.authenticated_from_request(request=flask.request, user=user), \ - self._translate_job_errors(job_id=job_id): + with con.authenticated_from_request(request=flask.request, user=user), self._translate_job_errors( + job_id=job_id + ): metadata = con.job(backend_job_id).describe_job() metadata["id"] = job_id return BatchJobMetadata.from_api_dict(metadata) def start_job(self, job_id: str, user: User): con, backend_job_id = self._get_connection_and_backend_job_id(aggregator_job_id=job_id) - with con.authenticated_from_request(request=flask.request, user=user), \ - con.override(default_timeout=CONNECTION_TIMEOUT_JOB_START), \ - self._translate_job_errors(job_id=job_id): + with con.authenticated_from_request(request=flask.request, user=user), con.override( + default_timeout=CONNECTION_TIMEOUT_JOB_START + ), self._translate_job_errors(job_id=job_id): con.job(backend_job_id).start_job() def cancel_job(self, job_id: str, user_id: str): con, backend_job_id = self._get_connection_and_backend_job_id(aggregator_job_id=job_id) - with con.authenticated_from_request(request=flask.request, user=User(user_id)), \ - self._translate_job_errors(job_id=job_id): + with con.authenticated_from_request(request=flask.request, user=User(user_id)), self._translate_job_errors( + job_id=job_id + ): con.job(backend_job_id).stop_job() def delete_job(self, job_id: str, user_id: str): con, backend_job_id = self._get_connection_and_backend_job_id(aggregator_job_id=job_id) - with con.authenticated_from_request(request=flask.request, user=User(user_id)), \ - self._translate_job_errors(job_id=job_id): + with con.authenticated_from_request(request=flask.request, user=User(user_id)), self._translate_job_errors( + job_id=job_id + ): con.job(backend_job_id).delete_job() def get_result_assets(self, job_id: str, user_id: str) -> Dict[str, dict]: con, backend_job_id = self._get_connection_and_backend_job_id(aggregator_job_id=job_id) - with con.authenticated_from_request(request=flask.request, user=User(user_id)), \ - self._translate_job_errors(job_id=job_id): + with con.authenticated_from_request(request=flask.request, user=User(user_id)), self._translate_job_errors( + job_id=job_id + ): results = con.job(backend_job_id).get_results() assets = results.get_assets() return {a.name: {**a.metadata, **{BatchJobs.ASSET_PUBLIC_HREF: a.href}} for a in assets} def get_result_metadata(self, job_id: str, user_id: str) -> BatchJobResultMetadata: - con, backend_job_id = self._get_connection_and_backend_job_id( - aggregator_job_id=job_id - ) - with 
con.authenticated_from_request( - request=flask.request, user=User(user_id) - ), self._translate_job_errors(job_id=job_id): + con, backend_job_id = self._get_connection_and_backend_job_id(aggregator_job_id=job_id) + with con.authenticated_from_request(request=flask.request, user=User(user_id)), self._translate_job_errors( + job_id=job_id + ): results = con.job(backend_job_id).get_results() metadata = results.get_metadata() assets = results.get_assets() - assets = { - a.name: {**a.metadata, **{BatchJobs.ASSET_PUBLIC_HREF: a.href}} - for a in assets - } + assets = {a.name: {**a.metadata, **{BatchJobs.ASSET_PUBLIC_HREF: a.href}} for a in assets} # TODO: better white/black list for links? links = [k for k in metadata.get("links", []) if k.get("rel") != "self"] return BatchJobResultMetadata( @@ -942,10 +911,11 @@ def get_log_entries( ) -> Iterable[dict]: con, backend_job_id = self._get_connection_and_backend_job_id(aggregator_job_id=job_id) # Use parenthesized context managers, see #127 - with con.authenticated_from_request(request=flask.request, user=User(user_id)), \ - self._translate_job_errors(job_id=job_id), \ - con.override(default_timeout=CONNECTION_TIMEOUT_JOB_LOGS), \ - TimingLogger(title=f"Get log entries for {job_id}", logger=_log.debug): + with con.authenticated_from_request(request=flask.request, user=User(user_id)), self._translate_job_errors( + job_id=job_id + ), con.override(default_timeout=CONNECTION_TIMEOUT_JOB_LOGS), TimingLogger( + title=f"Get log entries for {job_id}", logger=_log.debug + ): return con.job(backend_job_id).logs(offset=offset, level=level) @@ -958,7 +928,9 @@ def get_aggregator_service_id(backend_service_id: str, backend_id: str) -> str: return f"{backend_id}-{backend_service_id}" @classmethod - def parse_aggregator_service_id(cls, backends: MultiBackendConnection, aggregator_service_id: str) -> Tuple[str, str]: + def parse_aggregator_service_id( + cls, backends: MultiBackendConnection, aggregator_service_id: str + ) -> Tuple[str, str]: """Given aggregator service id: extract backend service id and backend id""" for prefix in [f"{con.id}-" for con in backends]: if aggregator_service_id.startswith(prefix): @@ -974,9 +946,9 @@ class AggregatorSecondaryServices(SecondaryServices): """ def __init__( - self, - backends: MultiBackendConnection, - processing: AggregatorProcessing, + self, + backends: MultiBackendConnection, + processing: AggregatorProcessing, ): super(AggregatorSecondaryServices, self).__init__() @@ -986,17 +958,13 @@ def __init__( self._processing = processing - def _get_connection_and_backend_service_id( - self, - aggregator_service_id: str - ) -> Tuple[BackendConnection, str]: + def _get_connection_and_backend_service_id(self, aggregator_service_id: str) -> Tuple[BackendConnection, str]: """Get connection to the backend and the corresponding service ID in that backend. raises: ServiceNotFoundException when service_id does not exist in any of the backends. 
""" backend_service_id, backend_id = ServiceIdMapping.parse_aggregator_service_id( - backends=self._backends, - aggregator_service_id=aggregator_service_id + backends=self._backends, aggregator_service_id=aggregator_service_id ) con = self._backends.get_connection(backend_id) @@ -1017,9 +985,7 @@ def service_types(self) -> dict: return {name: data["service_type"] for name, data, in service_types.items()} def _get_service_types_cached(self): - return self._memoizer.get_or_call( - key="service_types", callback=self._get_service_types - ) + return self._memoizer.get_or_call(key="service_types", callback=self._get_service_types) def _find_backend_id_for_service_type(self, service_type: str) -> str: """Returns the ID of the backend that provides the service_type.""" @@ -1089,9 +1055,7 @@ def _get_service_types(self) -> Dict: # so we can cache that information. # Some backends don not have the GET /service_types endpoint. supporting_backend_ids = [ - con.id - for con in self._backends - if con.capabilities().supports_endpoint("/service_types") + con.id for con in self._backends if con.capabilities().supports_endpoint("/service_types") ] service_types = {} @@ -1112,8 +1076,8 @@ def _get_service_types(self) -> Dict: else: conflicting_backend = service_types[name]["backend_id"] _log.warning( - f'Conflicting secondary service types: "{name}" is present in more than one backend, ' + - f'already found in backend: {conflicting_backend}' + f'Conflicting secondary service types: "{name}" is present in more than one backend, ' + + f"already found in backend: {conflicting_backend}" ) return { "supporting_backend_ids": supporting_backend_ids, @@ -1126,9 +1090,7 @@ def list_services(self, user_id: str) -> List[ServiceMetadata]: for backend_id in self.get_supporting_backend_ids(): con = self._backends.get_connection(backend_id) - with con.authenticated_from_request( - request=flask.request, user=User(user_id) - ): + with con.authenticated_from_request(request=flask.request, user=User(user_id)): try: data = con.get("/services").json() for service_data in data["services"]: @@ -1159,15 +1121,19 @@ def service_info(self, user_id: str, service_id: str) -> ServiceMetadata: raise ServiceNotFoundException(service_id=service_id) from e raise except Exception as e: - _log.debug(f"Failed to get service with ID={backend_service_id} from backend with ID={con.id}: {e!r}", exc_info=True) + _log.debug( + f"Failed to get service with ID={backend_service_id} from backend with ID={con.id}: {e!r}", + exc_info=True, + ) raise else: # Adapt the service ID so it points to the aggregator, with the backend ID included. 
service_json["id"] = ServiceIdMapping.get_aggregator_service_id(service_json["id"], con.id) return ServiceMetadata.from_dict(service_json) - def _create_service(self, user_id: str, process_graph: dict, service_type: str, api_version: str, - configuration: dict) -> str: + def _create_service( + self, user_id: str, process_graph: dict, service_type: str, api_version: str, configuration: dict + ) -> str: """ https://openeo.org/documentation/1.0/developers/api/reference.html#operation/create-service """ @@ -1228,7 +1194,9 @@ def update_service(self, user_id: str, service_id: str, process_graph: dict) -> except OpenEoApiPlainError as e: if e.http_status_code == 404: # Expected error - _log.debug(f"No service with ID={backend_service_id!r} in backend with ID={con.id!r}: {e!r}", exc_info=True) + _log.debug( + f"No service with ID={backend_service_id!r} in backend with ID={con.id!r}: {e!r}", exc_info=True + ) raise ServiceNotFoundException(service_id=service_id) from e raise except Exception as e: @@ -1314,7 +1282,8 @@ def __init__(self, backends: MultiBackendConnection): # Shorter HTTP cache TTL to adapt quicker to changed back-end configurations self.cache_control = openeo_driver.util.view_helpers.cache_control( - max_age=datetime.timedelta(minutes=15), public=True, + max_age=datetime.timedelta(minutes=15), + public=True, ) def oidc_providers(self) -> List[OidcProvider]: @@ -1352,42 +1321,48 @@ def user_access_validation(self, user: User, request: flask.Request) -> User: if self._auth_entitlement_check: int_data = user.internal_auth_data issuer_whitelist = [ - normalize_issuer_url(u) - for u in self._auth_entitlement_check.get("oidc_issuer_whitelist", []) + normalize_issuer_url(u) for u in self._auth_entitlement_check.get("oidc_issuer_whitelist", []) ] # TODO: all this is openEO platform EGI VO specific. Should/Can this be generalized/encapsulated better? if not ( - int_data["authentication_method"] == "OIDC" - and normalize_issuer_url(int_data["oidc_issuer"]) in issuer_whitelist + int_data["authentication_method"] == "OIDC" + and normalize_issuer_url(int_data["oidc_issuer"]) in issuer_whitelist ): user_message = "An EGI account is required for using openEO Platform." - _log.warning(f"user_access_validation failure: %r %r", user_message, { - "internal_auth_data": subdict(int_data, keys=["authentication_method", "oidc_issuer"]), - "issuer_whitelist": issuer_whitelist, - }) + _log.warning( + f"user_access_validation failure: %r %r", + user_message, + { + "internal_auth_data": subdict(int_data, keys=["authentication_method", "oidc_issuer"]), + "issuer_whitelist": issuer_whitelist, + }, + ) raise PermissionsInsufficientException(user_message) enrollment_error_user_message = "Proper enrollment in openEO Platform virtual organization is required." try: eduperson_entitlements = user.info["oidc_userinfo"]["eduperson_entitlement"] except KeyError as e: - _log.warning(f"user_access_validation failure: %r %r", enrollment_error_user_message, { - "exception": repr(e), - # Note: just log userinfo keys to avoid leaking sensitive user data. - "userinfo keys": (user.info.keys(), user.info.get('oidc_userinfo', {}).keys()) - }) + _log.warning( + f"user_access_validation failure: %r %r", + enrollment_error_user_message, + { + "exception": repr(e), + # Note: just log userinfo keys to avoid leaking sensitive user data. 
+ "userinfo keys": (user.info.keys(), user.info.get("oidc_userinfo", {}).keys()), + }, + ) raise PermissionsInsufficientException(enrollment_error_user_message) - roles = openeo_aggregator.egi.OPENEO_PLATFORM_USER_ROLES.extract_roles( - eduperson_entitlements - ) + roles = openeo_aggregator.egi.OPENEO_PLATFORM_USER_ROLES.extract_roles(eduperson_entitlements) if roles: user.add_roles(r.id for r in roles) else: - _log.warning(f"user_access_validation failure: %r %r", enrollment_error_user_message, { - "user_id": user.user_id, - "eduperson_entitlements": eduperson_entitlements - }) + _log.warning( + f"user_access_validation failure: %r %r", + enrollment_error_user_message, + {"user_id": user.user_id, "eduperson_entitlements": eduperson_entitlements}, + ) raise PermissionsInsufficientException(enrollment_error_user_message) return user @@ -1415,10 +1390,12 @@ def health_check(self, options: Optional[dict] = None) -> Union[str, dict, flask backend_status[con.id]["error_time"] = self._clock() - start_time overall_status_code = 500 - response = flask.jsonify({ - "status_code": overall_status_code, - "backend_status": backend_status, - }) + response = flask.jsonify( + { + "status_code": overall_status_code, + "backend_status": backend_status, + } + ) response.status_code = overall_status_code return response diff --git a/src/openeo_aggregator/caching.py b/src/openeo_aggregator/caching.py index 68ccc1a2..72d693a2 100644 --- a/src/openeo_aggregator/caching.py +++ b/src/openeo_aggregator/caching.py @@ -88,8 +88,7 @@ def get_or_call(self, key, callback, ttl=None, log_on_miss=False): else: if log_on_miss: with TimingLogger( - title=f"Cache miss {self.name!r} key {key!r}, calling {callback.__qualname__!r}", - logger=_log.debug + title=f"Cache miss {self.name!r} key {key!r}, calling {callback.__qualname__!r}", logger=_log.debug ): res = callback() else: @@ -117,6 +116,7 @@ class Memoizer(metaclass=abc.ABCMeta): Concrete classes should just implement `get_or_call` and `invalidate`. """ + log_on_miss = True def __init__(self, namespace: str = DEFAULT_NAMESPACE): @@ -226,11 +226,8 @@ def _default(self, o: Any) -> dict: """Implementation of `default` parameter of `json.dump` and related""" if o.__class__ in self._custom_types: # TODO: also add signing with a secret? 
- return {"_jsonserde": { - "type": self._type_id(o.__class__), - "data": o.__jsonserde_prepare__() - }} - raise TypeError(f'Object of type {o.__class__.__name__} is not JSON serializable') + return {"_jsonserde": {"type": self._type_id(o.__class__), "data": o.__jsonserde_prepare__()}} + raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") def _object_hook(self, d: dict) -> Any: """Implementation of `object_hook` parameter of `json.load` and related""" @@ -243,7 +240,7 @@ def serialize(self, data: dict) -> bytes: data = json.dumps( obj=data, indent=None, - separators=(',', ':'), + separators=(",", ":"), default=self._default if self._custom_types else None, ).encode("utf8") if len(data) > self._gzip_threshold: @@ -252,14 +249,11 @@ def serialize(self, data: dict) -> bytes: return data def deserialize(self, data: bytes) -> dict: - if data[:1] == b'\x78': + if data[:1] == b"\x78": # First byte of zlib data is practically almost always x78 _log.debug(f"JsonSerDe.deserialize: detected zlib compressed data") data = zlib.decompress(data) - return json.loads( - s=data.decode("utf8"), - object_hook=self._object_hook if self._decode_map else None - ) + return json.loads(s=data.decode("utf8"), object_hook=self._object_hook if self._decode_map else None) # Global JSON SerDe instance @@ -358,18 +352,19 @@ class ZkMemoizer(Memoizer): count = zk_cache.get_or_call("count", callback=calculate_count) """ + DEFAULT_TTL = 5 * 60 DEFAULT_ZK_TIMEOUT = 5 _serde = json_serde def __init__( - self, - client: KazooClient, - path_prefix: str, - namespace: str = DEFAULT_NAMESPACE, - default_ttl: Optional[float] = None, - zk_timeout: Optional[float] = None, + self, + client: KazooClient, + path_prefix: str, + namespace: str = DEFAULT_NAMESPACE, + default_ttl: Optional[float] = None, + zk_timeout: Optional[float] = None, ): super().__init__(namespace=namespace) self._client = client @@ -523,10 +518,13 @@ def get_memoizer(memoizer_type: str, memoizer_conf: dict) -> Memoizer: zk_timeout=memoizer_conf.get("zk_timeout"), ) elif memoizer_type == "chained": - return ChainedMemoizer([ - get_memoizer(memoizer_type=part["type"], memoizer_conf=part["config"]) - for part in memoizer_conf["parts"] - ], namespace=namespace) + return ChainedMemoizer( + [ + get_memoizer(memoizer_type=part["type"], memoizer_conf=part["config"]) + for part in memoizer_conf["parts"] + ], + namespace=namespace, + ) else: raise ValueError(memoizer_type) diff --git a/src/openeo_aggregator/config.py b/src/openeo_aggregator/config.py index df0b1d5d..54bbdc38 100644 --- a/src/openeo_aggregator/config.py +++ b/src/openeo_aggregator/config.py @@ -37,8 +37,6 @@ class AggregatorConfig(dict): config_source = dict_item() - - @attrs.frozen(kw_only=True) class AggregatorBackendConfig(OpenEoBackendConfig): diff --git a/src/openeo_aggregator/connection.py b/src/openeo_aggregator/connection.py index f2e1d6a2..1a1b8541 100644 --- a/src/openeo_aggregator/connection.py +++ b/src/openeo_aggregator/connection.py @@ -50,6 +50,7 @@ BackendId = str + class LockedAuthException(InternalException): def __init__(self): super().__init__(message="Setting auth while locked.") @@ -211,11 +212,11 @@ class MultiBackendConnection: """ Collection of multiple connections to different backends """ + # TODO: API version management: just do single/fixed-version federation, or also handle version discovery? # TODO: keep track of (recent) backend failures, e.g. 
to automatically blacklist a backend # TODO: synchronized backend connection caching/flushing across gunicorn workers, for better consistency? - _TIMEOUT = 5 def __init__( @@ -277,8 +278,7 @@ def get_connections(self) -> List[BackendConnection]: for con in self._connections_cache.connections: con.invalidate() self._connections_cache = _ConnectionsCache( - expiry=now + self._connections_cache_ttl, - connections=list(self._get_connections(skip_failures=True)) + expiry=now + self._connections_cache_ttl, connections=list(self._get_connections(skip_failures=True)) ) new_bids = [c.id for c in self._connections_cache.connections] _log.debug( @@ -328,9 +328,7 @@ def _get_api_versions(self) -> List[str]: def get_api_versions(self) -> Set[ComparableVersion]: """Get set of API versions reported by backends""" - versions = self._memoizer.get_or_call( - key="api_versions", callback=self._get_api_versions - ) + versions = self._memoizer.get_or_call(key="api_versions", callback=self._get_api_versions) versions = set(ComparableVersion(v) for v in versions) return versions @@ -344,9 +342,7 @@ def api_version_maximum(self) -> ComparableVersion: """Get the highest API version of all back-ends""" return max(self.get_api_versions()) - def map( - self, callback: Callable[[BackendConnection], Any] - ) -> Iterator[Tuple[str, Any]]: + def map(self, callback: Callable[[BackendConnection], Any]) -> Iterator[Tuple[str, Any]]: """ Query each backend connection with given callable and return results as iterator @@ -438,10 +434,8 @@ def do_request( return ParallelResponse(successes=successes, failures=failures) - def streaming_flask_response( - backend_response: requests.Response, - chunk_size: int = STREAM_CHUNK_SIZE_DEFAULT + backend_response: requests.Response, chunk_size: int = STREAM_CHUNK_SIZE_DEFAULT ) -> flask.Response: """ Convert a `requests.Response` coming from a backend @@ -450,10 +444,7 @@ def streaming_flask_response( :param backend_response: `requests.Response` object (possibly created with "stream" option enabled) :param chunk_size: chunk size to use for streaming """ - headers = [ - (k, v) for (k, v) in backend_response.headers.items() - if k.lower() in ["content-type"] - ] + headers = [(k, v) for (k, v) in backend_response.headers.items() if k.lower() in ["content-type"]] return flask.Response( # Streaming response through `iter_content` generator (https://flask.palletsprojects.com/en/2.0.x/patterns/streaming/) response=backend_response.iter_content(chunk_size=chunk_size), diff --git a/src/openeo_aggregator/egi.py b/src/openeo_aggregator/egi.py index f2e82d8c..a372f589 100644 --- a/src/openeo_aggregator/egi.py +++ b/src/openeo_aggregator/egi.py @@ -23,12 +23,10 @@ \#(?P[a-z0-9._-]+) $ """, - flags=re.VERBOSE | re.IGNORECASE + flags=re.VERBOSE | re.IGNORECASE, ) -Entitlement = namedtuple( - "Entitlement", ["namespace", "vo", "group", "role", "authority"] -) +Entitlement = namedtuple("Entitlement", ["namespace", "vo", "group", "role", "authority"]) @functools.lru_cache(maxsize=100) @@ -56,8 +54,7 @@ class UserRole: def __init__(self, title: str): self._title = title self._id = "".join( - w.title() if w.islower() else w - for w in self._title.replace("-", " ").replace("_", " ").split() + w.title() if w.islower() else w for w in self._title.replace("-", " ").replace("_", " ").split() ) self._normalized = self.normalize_role(self._title) @@ -81,19 +78,13 @@ def entitlement_match(self, entitlement: str): ) - - class OpeneoPlatformUserRoles: def __init__(self, roles: List[UserRole]): self.roles = roles def 
extract_roles(self, entitlements: List[str]) -> List[UserRole]: """Extract user roles based on list of eduperson_entitlement values""" - return [ - role - for role in self.roles - if any(role.entitlement_match(e) for e in entitlements) - ] + return [role for role in self.roles if any(role.entitlement_match(e) for e in entitlements)] # Standardized roles in openEO Platform EGI Virtual Organisation diff --git a/src/openeo_aggregator/errors.py b/src/openeo_aggregator/errors.py index 3feb2249..59c9cb5e 100644 --- a/src/openeo_aggregator/errors.py +++ b/src/openeo_aggregator/errors.py @@ -3,6 +3,6 @@ class BackendLookupFailureException(OpenEOApiException): status_code = 400 - code = 'BackendLookupFailure' - message = 'Failed to determine back-end to use.' + code = "BackendLookupFailure" + message = "Failed to determine back-end to use." _description = None diff --git a/src/openeo_aggregator/metadata/merging.py b/src/openeo_aggregator/metadata/merging.py index e0de4ef9..aa4e62fb 100644 --- a/src/openeo_aggregator/metadata/merging.py +++ b/src/openeo_aggregator/metadata/merging.py @@ -63,8 +63,16 @@ def merge_collection_metadata( if full_metadata: for backend_id, collection in by_backend.items(): - for required_field in ["stac_version", "id", "description", "license", "extent", "links", "cube:dimensions", - "summaries"]: + for required_field in [ + "stac_version", + "id", + "description", + "license", + "extent", + "links", + "cube:dimensions", + "summaries", + ]: if required_field not in collection: report( f"Missing {required_field} in collection metadata.", @@ -93,7 +101,7 @@ def merge_collection_metadata( # - `crs` is required by OGC API: https://docs.opengeospatial.org/is/18-058/18-058.html#_crs_identifier_list # - `sci:doi` and related are defined at https://github.com/stac-extensions/scientific for field in getter.available_keys(["stac_extensions", "keywords", "providers", "sci:publications"]): - result[field] = getter.concat(field, skip_duplicates = True) + result[field] = getter.concat(field, skip_duplicates=True) for field in getter.available_keys(["deprecated"]): result[field] = all(getter.get(field)) for field in getter.available_keys(["crs", "sci:citation", "sci:doi"]): @@ -104,9 +112,7 @@ def merge_collection_metadata( for backend_id, collection in by_backend.items(): try: if "summaries" in collection: - summaries_by_backend[backend_id] = StacSummaries.from_dict( - collection.get("summaries") - ) + summaries_by_backend[backend_id] = StacSummaries.from_dict(collection.get("summaries")) except Exception as e: report("Failed to parse summaries", collection_id=cid, backend_id=backend_id, exception=e) result["summaries"] = StacSummaries.merge_all( @@ -168,11 +174,12 @@ def merge_collection_metadata( t_extent = list(cube_dim_getter.select(t_dim).get("extent")) try: # TODO: Is multidict getter with id required? 
- t_starts = [e[0] for e in t_extent if e[0] and e[0] != 'None'] - t_ends = [e[1] for e in t_extent if e[1] and e[1] != 'None'] + t_starts = [e[0] for e in t_extent if e[0] and e[0] != "None"] + t_ends = [e[1] for e in t_extent if e[1] and e[1] != "None"] result["cube:dimensions"][t_dim]["extent"] = [ min(rfc3339.normalize(t) for t in t_starts) if t_starts else None, - max(rfc3339.normalize(t) for t in t_ends) if t_ends else None] + max(rfc3339.normalize(t) for t in t_ends) if t_ends else None, + ] except Exception as e: report( f"Failed to merge cube:dimensions.{t_dim}.extent: {e!r}, actual: {t_extent}", @@ -218,9 +225,7 @@ def merge_collection_metadata( def single_backend_collection_post_processing(metadata: dict, *, backend_id: str): """In-place post-processing of a single backend collection""" - if not deep_get( - metadata, "summaries", STAC_PROPERTY_FEDERATION_BACKENDS, default=None - ): + if not deep_get(metadata, "summaries", STAC_PROPERTY_FEDERATION_BACKENDS, default=None): metadata.setdefault("summaries", {}) metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS] = [backend_id] else: @@ -235,9 +240,7 @@ def set_if_non_empty(d: dict, key: str, value: Any): d[key] = value -def json_diff( - a: Any, b: Any, a_name: str = "", b_name: str = "", context: int = 3 -) -> List[str]: +def json_diff(a: Any, b: Any, a_name: str = "", b_name: str = "", context: int = 3) -> List[str]: """ Generate unified diff of JSON serialization of given objects :return: List of diff lines @@ -273,9 +276,7 @@ class ProcessMetadataMerger: def __init__(self, report: Callable = DEFAULT_REPORTER.report): self.report = report - def merge_processes_metadata( - self, processes_per_backend: Dict[str, Dict[str, dict]] - ) -> Dict[str, dict]: + def merge_processes_metadata(self, processes_per_backend: Dict[str, Dict[str, dict]]) -> Dict[str, dict]: """ Merge process metadata listings from multiple back-ends into a single process listing. 
@@ -294,9 +295,7 @@ def merge_processes_metadata( try: merged[process_id] = self.merge_process_metadata(by_backend) except Exception as e: - self.report( - f"Failed to merge process metadata: {e!r}", process_id=process_id - ) + self.report(f"Failed to merge process metadata: {e!r}", process_id=process_id) return merged def merge_process_metadata(self, by_backend: Dict[str, dict]) -> dict: @@ -322,21 +321,13 @@ def merge_process_metadata(self, by_backend: Dict[str, dict]) -> dict: } set_if_non_empty(merged, "summary", getter.first("summary", default=None)) - merged["parameters"] = self._merge_process_parameters( - by_backend=by_backend, process_id=process_id - ) + merged["parameters"] = self._merge_process_parameters(by_backend=by_backend, process_id=process_id) # Return schema - merged["returns"] = self._merge_process_returns( - by_backend=by_backend, process_id=process_id - ) + merged["returns"] = self._merge_process_returns(by_backend=by_backend, process_id=process_id) - set_if_non_empty( - merged, "exceptions", self._merge_process_exceptions(by_backend=by_backend) - ) - set_if_non_empty( - merged, "categories", self._merge_process_categories(by_backend=by_backend) - ) + set_if_non_empty(merged, "exceptions", self._merge_process_exceptions(by_backend=by_backend)) + set_if_non_empty(merged, "categories", self._merge_process_categories(by_backend=by_backend)) merged["deprecated"] = any(getter.get("deprecated")) merged["experimental"] = any(getter.get("experimental")) @@ -345,9 +336,7 @@ def merge_process_metadata(self, by_backend: Dict[str, dict]) -> dict: return merged - def _get_parameters_by_name( - self, parameters: List[dict], backend_id: str, process_id: str - ) -> Dict[str, dict]: + def _get_parameters_by_name(self, parameters: List[dict], backend_id: str, process_id: str) -> Dict[str, dict]: """Build dictionary of parameter metadata, keyed on name.""" names = {} try: @@ -373,9 +362,7 @@ def _get_parameters_by_name( return names - def _merge_process_parameters( - self, by_backend: Dict[str, dict], process_id: str - ) -> List[dict]: + def _merge_process_parameters(self, by_backend: Dict[str, dict], process_id: str) -> List[dict]: # Pick first non-empty parameter listing as merge result # TODO: real merge instead of taking first? 
merged = [] @@ -421,13 +408,9 @@ def _merge_process_parameters( normalizer = ProcessParameterNormalizer( strip_description=True, add_optionals=True, - report=functools.partial( - self.report, backend_id=backend_id, process_id=process_id - ), - ) - merged_param = normalizer.normalize_parameter( - merged_params_by_name[name] + report=functools.partial(self.report, backend_id=backend_id, process_id=process_id), ) + merged_param = normalizer.normalize_parameter(merged_params_by_name[name]) other_param = normalizer.normalize_parameter(params_by_name[name]) for field in merged_param.keys(): merged_value = merged_param[field] @@ -449,9 +432,7 @@ def _merge_process_parameters( return merged - def _merge_process_returns( - self, by_backend: Dict[str, dict], process_id: str - ) -> dict: + def _merge_process_returns(self, by_backend: Dict[str, dict], process_id: str) -> dict: """ Merge `returns` metadata :param by_backend: {backend_id: process_metadata} @@ -471,9 +452,7 @@ def _merge_process_returns( process_id=process_id, merged=merged, value=other_returns, - diff=json_diff( - merged, other_returns, a_name="merged", b_name=backend_id - ), + diff=json_diff(merged, other_returns, a_name="merged", b_name=backend_id), ) return merged @@ -534,9 +513,7 @@ def normalize_parameter(self, param: dict) -> dict: """Normalize a parameter metadata dict""" for required in ["name", "schema", "description"]: if required not in param: - self.report( - f"Missing required field {required!r} in parameter metadata {param!r}" - ) + self.report(f"Missing required field {required!r} in parameter metadata {param!r}") normalized = { "name": param.get("name", "n/a"), "schema": param.get("schema", {}), @@ -578,12 +555,7 @@ def normalize_recursively(self, x: Any) -> Any: and x.get("subtype") == "process-graph" and isinstance(x.get("parameters"), list) ): - return { - k: self.normalize_parameters(parameters=v) - if k == "parameters" - else v - for k, v in x.items() - } + return {k: self.normalize_parameters(parameters=v) if k == "parameters" else v for k, v in x.items()} else: return {k: self.normalize_recursively(v) for k, v in x.items()} elif isinstance(x, list): diff --git a/src/openeo_aggregator/metadata/models/cube_dimension.py b/src/openeo_aggregator/metadata/models/cube_dimension.py index 34f9b5b1..67bf8728 100644 --- a/src/openeo_aggregator/metadata/models/cube_dimension.py +++ b/src/openeo_aggregator/metadata/models/cube_dimension.py @@ -62,13 +62,15 @@ def _set_extent_from_dict(self, d, is_required, is_open, is_spatial, identifier) Spatial extents can only contain numbers. Non spatial extents can only contain strings. """ - invalid_error = ValueError("Could not parse CubeDimension object, extent for {} is invalid. " - "actual: {}".format(identifier, d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, extent for {} is invalid. " "actual: {}".format(identifier, d) + ) # Include identifier extent = d.get("extent", UNSET) if is_required and (extent is UNSET or extent is None): - raise ValueError("Could not parse CubeDimension object, extent for {} is required. " - "actual: {}".format(identifier, d)) + raise ValueError( + "Could not parse CubeDimension object, extent for {} is required. 
" "actual: {}".format(identifier, d) + ) if extent is not UNSET: if not isinstance(extent, list): raise invalid_error @@ -85,8 +87,9 @@ def _set_extent_from_dict(self, d, is_required, is_open, is_spatial, identifier) self.extent = extent def _set_values_from_dict(self, d, types, identifier): - invalid_error = ValueError("Could not parse CubeDimension object, values for {} is invalid. " - "actual: {}".format(identifier, d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, values for {} is invalid. " "actual: {}".format(identifier, d) + ) values = d.get("values", UNSET) if values is not UNSET: if not isinstance(values, list): @@ -97,8 +100,9 @@ def _set_values_from_dict(self, d, types, identifier): self.values = values def _set_step_from_dict(self, d, types, identifier): - invalid_error = ValueError("Could not parse CubeDimension object, step for {} is invalid. " - "actual: {}".format(identifier, d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, step for {} is invalid. " "actual: {}".format(identifier, d) + ) step = d.get("step", UNSET) if step is not UNSET: if not isinstance(step, tuple(types)): @@ -106,8 +110,10 @@ def _set_step_from_dict(self, d, types, identifier): self.step = step def _set_reference_system_from_dict(self, d, types, identifier): - invalid_error = ValueError("Could not parse CubeDimension object, reference_system for {} is invalid. " - "actual: {}".format(identifier, d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, reference_system for {} is invalid. " + "actual: {}".format(identifier, d) + ) reference_system = d.get("reference_system", UNSET) if reference_system is not UNSET: if not isinstance(reference_system, tuple(types)): @@ -115,8 +121,9 @@ def _set_reference_system_from_dict(self, d, types, identifier): self.reference_system = reference_system def _set_description_from_dict(self, d, identifier): - invalid_error = ValueError("Could not parse CubeDimension object, description for {} is invalid. " - "actual: {}".format(identifier, d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, description for {} is invalid. " "actual: {}".format(identifier, d) + ) _description = d.get("description", UNSET) if _description is not UNSET: if not isinstance(_description, str): @@ -124,12 +131,14 @@ def _set_description_from_dict(self, d, identifier): self.description = _description def _set_type_from_dict(self, d, expected_type, identifier): - invalid_error = ValueError("Could not parse CubeDimension object, type for {} is invalid. " - "actual: {}".format(identifier, d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, type for {} is invalid. " "actual: {}".format(identifier, d) + ) _type = d.get("type", UNSET) if _type is UNSET: - raise ValueError("Could not parse CubeDimension object, type for {} is required. " - "actual: {}".format(identifier, d)) + raise ValueError( + "Could not parse CubeDimension object, type for {} is required. 
" "actual: {}".format(identifier, d) + ) if not isinstance(_type, str): raise invalid_error if expected_type == DimensionType.OTHER.value: @@ -139,18 +148,23 @@ def _set_type_from_dict(self, d, expected_type, identifier): self.type = DimensionType(_type) return if _type != expected_type: - raise ValueError("Could not parse CubeDimension object, expected type for {} is {}, " - "actual: {}".format(identifier, expected_type, d)) + raise ValueError( + "Could not parse CubeDimension object, expected type for {} is {}, " + "actual: {}".format(identifier, expected_type, d) + ) self.type = DimensionType(_type) def _set_horizontal_spatial_dimension_from_dict(self, src_dict: Dict[str, Any]): d = src_dict.copy() - invalid_error = ValueError("Could not parse CubeDimension object, horizontal dimension is invalid. " - "actual: {}".format(d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, horizontal dimension is invalid. " "actual: {}".format(d) + ) for required_field in ["type", "extent", "axis"]: if required_field not in d: - raise ValueError("Could not parse CubeDimension object, required field {} for horizontal dimension " - "is missing. actual: {}".format(required_field, d)) + raise ValueError( + "Could not parse CubeDimension object, required field {} for horizontal dimension " + "is missing. actual: {}".format(required_field, d) + ) _axis = d.get("axis", UNSET) if _axis not in ["x", "y"]: raise invalid_error @@ -160,18 +174,35 @@ def _set_horizontal_spatial_dimension_from_dict(self, src_dict: Dict[str, Any]): self._set_type_from_dict(d, DimensionType.SPATIAL.value, identifier) self._set_description_from_dict(d, identifier) self._set_extent_from_dict(d, is_required=True, is_open=False, is_spatial=True, identifier=identifier) - self._set_values_from_dict(d, types=(float, int,), identifier=identifier) - self._set_step_from_dict(d, types=(float, int,), identifier=identifier) + self._set_values_from_dict( + d, + types=( + float, + int, + ), + identifier=identifier, + ) + self._set_step_from_dict( + d, + types=( + float, + int, + ), + identifier=identifier, + ) self._set_reference_system_from_dict(d, types=(str, float, int, dict), identifier=identifier) def _set_vertical_spatial_dimension_from_dict(self, src_dict: Dict[str, Any]): d = src_dict.copy() - invalid_error = ValueError("Could not parse CubeDimension object, vertical dimension is invalid. " - "actual: {}".format(d)) + invalid_error = ValueError( + "Could not parse CubeDimension object, vertical dimension is invalid. " "actual: {}".format(d) + ) for required_field in ["type", "axis"]: if required_field not in d: - raise ValueError("Could not parse CubeDimension object, required field {} for vertical dimension " - "is missing. actual: {}".format(required_field, d)) + raise ValueError( + "Could not parse CubeDimension object, required field {} for vertical dimension " + "is missing. 
actual: {}".format(required_field, d) + ) _axis = d.get("axis", UNSET) if _axis not in ["z"]: raise invalid_error @@ -180,7 +211,7 @@ def _set_vertical_spatial_dimension_from_dict(self, src_dict: Dict[str, Any]): identifier = "vertical dimension" self._set_type_from_dict(d, DimensionType.SPATIAL.value, identifier) self._set_description_from_dict(d, identifier) - self._set_extent_from_dict(d, is_required =False, is_open =True, is_spatial =True, identifier=identifier) + self._set_extent_from_dict(d, is_required=False, is_open=True, is_spatial=True, identifier=identifier) self._set_values_from_dict(d, types=(float, int, str), identifier=identifier) self._set_step_from_dict(d, types=(float, int, type(None)), identifier=identifier) self._set_reference_system_from_dict(d, types=(str, float, int, dict), identifier=identifier) @@ -189,12 +220,14 @@ def _set_temporal_dimension_from_dict(self, src_dict: Dict[str, Any]): d = src_dict.copy() for required_field in ["type", "extent"]: if required_field not in d: - raise ValueError("Could not parse CubeDimension object, required field {} for temporal dimension " - "is missing. actual: {}".format(required_field, d)) + raise ValueError( + "Could not parse CubeDimension object, required field {} for temporal dimension " + "is missing. actual: {}".format(required_field, d) + ) identifier = "temporal dimension" self._set_type_from_dict(d, DimensionType.TEMPORAL.value, identifier) self._set_description_from_dict(d, identifier) - self._set_extent_from_dict(d, is_required =True, is_open =True, is_spatial =False, identifier=identifier) + self._set_extent_from_dict(d, is_required=True, is_open=True, is_spatial=False, identifier=identifier) self._set_values_from_dict(d, types=(str,), identifier=identifier) self._set_step_from_dict(d, types=(str, type(None)), identifier=identifier) @@ -202,12 +235,14 @@ def _set_other_dimension_from_dict(self, src_dict: Dict[str, Any]): d = src_dict.copy() for required_field in ["type"]: if required_field not in d: - raise ValueError("Could not parse CubeDimension object, required field {} for other dimension " - "is missing. actual: {}".format(required_field, d)) + raise ValueError( + "Could not parse CubeDimension object, required field {} for other dimension " + "is missing. 
actual: {}".format(required_field, d) + ) identifier = "other dimension" self._set_type_from_dict(d, DimensionType.OTHER.value, identifier) self._set_description_from_dict(d, identifier) - self._set_extent_from_dict(d, is_required =False, is_open =True, is_spatial =True, identifier=identifier) + self._set_extent_from_dict(d, is_required=False, is_open=True, is_spatial=True, identifier=identifier) self._set_values_from_dict(d, types=(float, int, str), identifier=identifier) self._set_step_from_dict(d, types=(float, int, type(None)), identifier=identifier) self._set_reference_system_from_dict(d, types=(str), identifier=identifier) diff --git a/src/openeo_aggregator/metadata/models/cube_dimensions.py b/src/openeo_aggregator/metadata/models/cube_dimensions.py index 61b865c1..86ab0345 100644 --- a/src/openeo_aggregator/metadata/models/cube_dimensions.py +++ b/src/openeo_aggregator/metadata/models/cube_dimensions.py @@ -49,8 +49,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: try: additional_property = CubeDimension.from_dict(prop_dict) except Exception as e: - raise TypeError("Error parsing '%s' of CubeDimensions: %s" % - (prop_name, e)) + raise TypeError("Error parsing '%s' of CubeDimensions: %s" % (prop_name, e)) additional_properties[prop_name] = additional_property collection_stac_collection_cube_dimensions.dimensions = additional_properties diff --git a/src/openeo_aggregator/metadata/models/stac_eo.py b/src/openeo_aggregator/metadata/models/stac_eo.py index 0731e22b..6287f30b 100644 --- a/src/openeo_aggregator/metadata/models/stac_eo.py +++ b/src/openeo_aggregator/metadata/models/stac_eo.py @@ -92,19 +92,11 @@ class EoBand: @staticmethod def from_dict(obj: Any) -> "EoBand": assert isinstance(obj, dict) - center_wavelength = from_union( - [from_float, from_none], obj.get("center_wavelength") - ) - common_name = from_union( - [CommonNameOfTheBand, from_none], obj.get("common_name") - ) - full_width_half_max = from_union( - [from_float, from_none], obj.get("full_width_half_max") - ) + center_wavelength = from_union([from_float, from_none], obj.get("center_wavelength")) + common_name = from_union([CommonNameOfTheBand, from_none], obj.get("common_name")) + full_width_half_max = from_union([from_float, from_none], obj.get("full_width_half_max")) name = from_union([from_str, from_none], obj.get("name")) - solar_illumination = from_union( - [from_float, from_none], obj.get("solar_illumination") - ) + solar_illumination = from_union([from_float, from_none], obj.get("solar_illumination")) return EoBand( center_wavelength, common_name, @@ -116,19 +108,13 @@ def from_dict(obj: Any) -> "EoBand": def to_dict(self) -> dict: return dict_no_none( { - "center_wavelength": from_union( - [to_float, from_none], self.center_wavelength - ), + "center_wavelength": from_union([to_float, from_none], self.center_wavelength), "common_name": from_union( [lambda x: to_enum(CommonNameOfTheBand, x), from_none], self.common_name, ), - "full_width_half_max": from_union( - [to_float, from_none], self.full_width_half_max - ), + "full_width_half_max": from_union([to_float, from_none], self.full_width_half_max), "name": from_union([from_str, from_none], self.name), - "solar_illumination": from_union( - [to_float, from_none], self.solar_illumination - ), + "solar_illumination": from_union([to_float, from_none], self.solar_illumination), } ) diff --git a/src/openeo_aggregator/metadata/models/stac_summaries.py b/src/openeo_aggregator/metadata/models/stac_summaries.py index 4113d81d..aecc0ad2 100644 --- 
a/src/openeo_aggregator/metadata/models/stac_summaries.py +++ b/src/openeo_aggregator/metadata/models/stac_summaries.py @@ -81,13 +81,11 @@ def _parse_additional_property(data: object) -> Union[List[Any], Statistics, Non except: # noqa: E722 pass if not isinstance(data, dict): - raise TypeError("Expected dict for '%s' of StacSummaries, actual %s" % - (prop_name, type(data))) + raise TypeError("Expected dict for '%s' of StacSummaries, actual %s" % (prop_name, type(data))) try: componentsschemas_stac_summaries_collection_properties_type_1 = Statistics.from_dict(data) except ValueError as e: - raise TypeError("Error parsing '%s' of StacSummaries: %s" % - (prop_name, e)) + raise TypeError("Error parsing '%s' of StacSummaries: %s" % (prop_name, e)) return componentsschemas_stac_summaries_collection_properties_type_1 additional_property = _parse_additional_property(prop_dict) @@ -127,9 +125,7 @@ def merge_all( :return: Merged summaries """ - by_backend = { - k: v.additional_properties for k, v in summaries_by_backend.items() - } + by_backend = {k: v.additional_properties for k, v in summaries_by_backend.items()} # Calculate the unique summary names. unique_summary_names: Set[str] = functools.reduce( lambda a, b: a.union(b), (d.keys() for d in by_backend.values()), set() @@ -155,12 +151,7 @@ def merge_all( for collection_summaries in by_backend.values(): try: if summary_name in collection_summaries: - eo_bands_lists.append( - [ - EoBand.from_dict(b) - for b in collection_summaries.get(summary_name) - ] - ) + eo_bands_lists.append([EoBand.from_dict(b) for b in collection_summaries.get(summary_name)]) except Exception as e: report( f"Failed to parse summary {summary_name!r}: {e!r}", @@ -174,11 +165,7 @@ def merge_all( f"Empty prefix for {summary_name!r}, falling back to first back-end's {summary_name!r}", collection_id=collection_id, ) - eo_bands = next( - d.get(summary_name) - for d in by_backend.values() - if summary_name in d - ) + eo_bands = next(d.get(summary_name) for d in by_backend.values() if summary_name in d) merged_addition_properties[summary_name] = eo_bands else: report(f"Unhandled merging of summary {summary_name!r}", collection_id=collection_id) diff --git a/src/openeo_aggregator/metadata/utils.py b/src/openeo_aggregator/metadata/utils.py index 7cf0f70e..2b3a1d66 100644 --- a/src/openeo_aggregator/metadata/utils.py +++ b/src/openeo_aggregator/metadata/utils.py @@ -6,6 +6,8 @@ class Unset: def __bool__(self) -> bool: return False + + UNSET: Unset = Unset() diff --git a/src/openeo_aggregator/metadata/validator.py b/src/openeo_aggregator/metadata/validator.py index e8dc04d5..76537799 100644 --- a/src/openeo_aggregator/metadata/validator.py +++ b/src/openeo_aggregator/metadata/validator.py @@ -34,18 +34,13 @@ def main(): target_group = parser.add_argument_group( "Target", "Which checks to run (if none specified, all checks will be run)." 
) - target_group.add_argument( - "-c", "--collections", action="store_true", help="Check collection metadata" - ) - target_group.add_argument( - "-p", "--processes", action="store_true", help="Check process metadata" - ) + target_group.add_argument("-c", "--collections", action="store_true", help="Check collection metadata") + target_group.add_argument("-p", "--processes", action="store_true", help="Check process metadata") args = parser.parse_args() logging.basicConfig(level=logging.DEBUG if args.verbose else logging.WARNING) _log.info(f"{args=}") - # Determine backend ids/urls aggregator_backends = get_backend_config().aggregator_backends backend_ids = args.backends or list(aggregator_backends.keys()) @@ -68,9 +63,7 @@ def main(): compare_get_collections(backends_for_collection_id) # Compare /collections/{collection_id} for collection_id in sorted(set(backends_for_collection_id.keys())): - compare_get_collection_by_id( - backend_urls=backend_urls, collection_id=collection_id - ) + compare_get_collection_by_id(backend_urls=backend_urls, collection_id=collection_id) if check_processes: # Compare /processes compare_get_processes(backend_urls=backend_urls) @@ -84,10 +77,8 @@ def compare_get_collections(backends_for_collection_id): getter = MultiDictGetter(by_backend.values()) cid = getter.single_value_for("id") if cid != collection_id: - reporter.report("Collection id in metadata does not match id in url", collection_id = cid) - merge_collection_metadata( - by_backend, full_metadata=False, report=reporter.report - ) + reporter.report("Collection id in metadata does not match id in url", collection_id=cid) + merge_collection_metadata(by_backend, full_metadata=False, report=reporter.report) def compare_get_collection_by_id(backend_urls, collection_id): @@ -103,7 +94,7 @@ def compare_get_collection_by_id(backend_urls, collection_id): if "id" in collection: by_backend[url] = collection else: - reporter.report("Backend returned invalid collection", backend_id = url, collection_id = collection_id) + reporter.report("Backend returned invalid collection", backend_id=url, collection_id=collection_id) merge_collection_metadata(by_backend, full_metadata=True, report=reporter.report) @@ -116,9 +107,7 @@ def compare_get_processes(backend_urls): processes = r.json().get("processes", []) processes_per_backend[url] = {p["id"]: p for p in processes} reporter = MarkDownReporter() - ProcessMetadataMerger(report=reporter.report).merge_processes_metadata( - processes_per_backend - ) + ProcessMetadataMerger(report=reporter.report).merge_processes_metadata(processes_per_backend) def get_all_collection_ids(backend_urls) -> Dict[str, Dict[str, Dict]]: diff --git a/src/openeo_aggregator/partitionedjobs/__init__.py b/src/openeo_aggregator/partitionedjobs/__init__.py index 7c099506..fbabb0ce 100644 --- a/src/openeo_aggregator/partitionedjobs/__init__.py +++ b/src/openeo_aggregator/partitionedjobs/__init__.py @@ -11,6 +11,7 @@ class PartitionedJobFailure(OpenEOApiException): class SubJob(NamedTuple): """A part of a partitioned job, target at a particular, single back-end.""" + # Process graph of the subjob (derived in some way from original parent process graph) process_graph: FlatPG # Id of target backend (or None if there is no dedicated backend) @@ -19,6 +20,7 @@ class SubJob(NamedTuple): class PartitionedJob(NamedTuple): """A large or multi-back-end job that is split in several sub jobs""" + # Original process graph process: PGWithMetadata metadata: dict @@ -28,9 +30,7 @@ class PartitionedJob(NamedTuple): 
dependencies: Dict[str, Sequence[str]] = {} @staticmethod - def to_subjobs_dict( - subjobs: Union[Sequence[SubJob], Dict[Any, SubJob]] - ) -> Dict[str, SubJob]: + def to_subjobs_dict(subjobs: Union[Sequence[SubJob], Dict[Any, SubJob]]) -> Dict[str, SubJob]: """Helper to convert a collection of SubJobs to a dictionary""" # TODO: hide this logic in a setter or __init__ (e.g. when outgrowing the constraints of typing.NamedTuple) if isinstance(subjobs, Sequence): diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 1a33de0b..9c973004 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -67,9 +67,7 @@ class CrossBackendSplitter(AbstractJobSplitter): """ - def __init__( - self, backend_for_collection: Callable[[str], str], always_split: bool = False - ): + def __init__(self, backend_for_collection: Callable[[str], str], always_split: bool = False): """ :param backend_for_collection: callable that determines backend id for given collection id :param always_split: split all load_collections, also when on same backend @@ -183,14 +181,10 @@ def _resolve_dependencies(process_graph: FlatPG, batch_jobs: Dict[str, BatchJob] """ result = dict() for node_id, node in process_graph.items(): - if node["process_id"] == "load_result" and node["arguments"]["id"].startswith( - _LOAD_RESULT_PLACEHOLDER - ): + if node["process_id"] == "load_result" and node["arguments"]["id"].startswith(_LOAD_RESULT_PLACEHOLDER): dep_id = node["arguments"]["id"].partition(_LOAD_RESULT_PLACEHOLDER)[-1] batch_job = batch_jobs[dep_id] - _log.info( - f"resolve_dependencies: replace placeholder {dep_id!r} with concrete {batch_job.job_id!r}" - ) + _log.info(f"resolve_dependencies: replace placeholder {dep_id!r} with concrete {batch_job.job_id!r}") try: # Try to get "canonical" result URL (signed URL) links = batch_job.get_results().get_metadata()["links"] @@ -236,9 +230,7 @@ def _loop(): yield i -def run_partitioned_job( - pjob: PartitionedJob, connection: openeo.Connection, fail_fast: bool = True -) -> dict: +def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fail_fast: bool = True) -> dict: """ Run partitioned job (probably with dependencies between subjobs) with an active polling loop for tracking and scheduling the subjobs @@ -275,16 +267,11 @@ def run_partitioned_job( if states[subjob_id] == SUBJOB_STATES.WAITING: dep_states = set(states[dep] for dep in dependencies.get(subjob_id, [])) _log.info(f"Dependency states for {subjob_id=!r}: {dep_states}") - if ( - SUBJOB_STATES.ERROR in dep_states - or SUBJOB_STATES.CANCELED in dep_states - ): + if SUBJOB_STATES.ERROR in dep_states or SUBJOB_STATES.CANCELED in dep_states: _log.info(f"Dependency failure: canceling {subjob_id=!r}") states[subjob_id] = SUBJOB_STATES.CANCELED elif all(s == SUBJOB_STATES.FINISHED for s in dep_states): - _log.info( - f"No unfulfilled dependencies: ready to start {subjob_id=!r}" - ) + _log.info(f"No unfulfilled dependencies: ready to start {subjob_id=!r}") states[subjob_id] = SUBJOB_STATES.READY # Handle job (start, poll status, ...) 
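# A minimal, hypothetical sketch of the scheduling decision made in the polling loop
# above (illustrative only; plain strings stand in for the SUBJOB_STATES constants of
# openeo_aggregator.partitionedjobs.crossbackend, and this helper is not part of the diff):
def next_state_for_subjob(current: str, dep_states: set) -> str:
    if current == "waiting":
        # Any failed or canceled dependency cancels the waiting subjob.
        if "error" in dep_states or "canceled" in dep_states:
            return "canceled"
        # All dependencies finished (or there are none): ready to be created and started.
        if all(s == "finished" for s in dep_states):
            return "ready"
    # Otherwise leave the state unchanged; ready/running subjobs are handled by the
    # create/start/poll logic in the hunks that follow.
    return current

# Example: one dependency errored, so the dependent subjob gets canceled.
assert next_state_for_subjob("waiting", {"error", "finished"}) == "canceled"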
@@ -292,9 +279,7 @@ def run_partitioned_job( try: process_graph = _resolve_dependencies(subjob.process_graph, batch_jobs=batch_jobs) - _log.info( - f"Starting new batch job for subjob {subjob_id!r} on backend {subjob.backend_id!r}" - ) + _log.info(f"Starting new batch job for subjob {subjob_id!r} on backend {subjob.backend_id!r}") # Create batch_job = connection.create_job( process_graph=process_graph, @@ -307,9 +292,7 @@ def run_partitioned_job( # Start batch_job.start_job() states[subjob_id] = SUBJOB_STATES.RUNNING - _log.info( - f"Started batch job {batch_job.job_id!r} for subjob {subjob_id!r}" - ) + _log.info(f"Started batch job {batch_job.job_id!r} for subjob {subjob_id!r}") except Exception as e: if fail_fast: raise diff --git a/src/openeo_aggregator/partitionedjobs/splitting.py b/src/openeo_aggregator/partitionedjobs/splitting.py index 10a2f3b4..d8f2bdca 100644 --- a/src/openeo_aggregator/partitionedjobs/splitting.py +++ b/src/openeo_aggregator/partitionedjobs/splitting.py @@ -48,9 +48,7 @@ class AbstractJobSplitter(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def split( - self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None - ) -> PartitionedJob: + def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: # TODO: how to express combination/aggregation of multiple subjob results as a final result? ... @@ -65,9 +63,7 @@ class FlimsySplitter(AbstractJobSplitter): def __init__(self, processing: "AggregatorProcessing"): self.processing = processing - def split( - self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None - ) -> PartitionedJob: + def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: process_graph = process["process_graph"] backend_id = self.processing.get_backend_for_process_graph(process_graph=process_graph, api_version="TODO") process_graph = self.processing.preprocess_process_graph(process_graph, backend_id=backend_id) @@ -85,6 +81,7 @@ class TileGrid(typing.NamedTuple): """ Specification of a tile grid, parsed from a string, e.g. 'wgs84-1degree', 'utm-100km', 'utm-20km', 'utm-10km'. """ + crs_type: str size: int unit: str @@ -113,9 +110,7 @@ def get_tiles(self, bbox: BoundingBox, max_tiles=MAX_TILES) -> List[BoundingBox] raise JobSplittingFailure(f"Unsupported tile grid {self.crs_type}") # Bounding box (in tiling CRS) to cover with tiles. 
- to_cover = BoundingBox.from_dict( - reproject_bounding_box(bbox.as_dict(), from_crs=bbox.crs, to_crs=tiling_crs) - ) + to_cover = BoundingBox.from_dict(reproject_bounding_box(bbox.as_dict(), from_crs=bbox.crs, to_crs=tiling_crs)) # Get ranges of tile indices xmin, xmax = [int(math.floor((x - x_offset) / tile_size)) for x in [to_cover.west, to_cover.east]] ymin, ymax = [int(math.floor(y / tile_size)) for y in [to_cover.south, to_cover.north]] @@ -127,13 +122,15 @@ def get_tiles(self, bbox: BoundingBox, max_tiles=MAX_TILES) -> List[BoundingBox] tiles = [] for x in range(xmin, xmax + 1): for y in range(ymin, ymax + 1): - tiles.append(BoundingBox( - west=x * tile_size + x_offset, - south=y * tile_size, - east=(x + 1) * tile_size + x_offset, - north=(y + 1) * tile_size, - crs=tiling_crs, - )) + tiles.append( + BoundingBox( + west=x * tile_size + x_offset, + south=y * tile_size, + east=(x + 1) * tile_size + x_offset, + north=(y + 1) * tile_size, + crs=tiling_crs, + ) + ) return tiles @@ -155,14 +152,9 @@ class TileGridSplitter(AbstractJobSplitter): METADATA_KEY = "_tiling_geometry" def __init__(self, processing: "AggregatorProcessing"): - self.backend_implementation = OpenEoBackendImplementation( - catalog=processing._catalog, - processing=processing - ) + self.backend_implementation = OpenEoBackendImplementation(catalog=processing._catalog, processing=processing) - def split( - self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None - ) -> PartitionedJob: + def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: # TODO: refactor process graph preprocessing and backend_id getting in reusable AbstractJobSplitter method? processing: AggregatorProcessing = self.backend_implementation.processing process_graph = process["process_graph"] @@ -176,17 +168,14 @@ def split( tiles = tile_grid.get_tiles(bbox=global_spatial_extent, max_tiles=job_options.get("max_tiles", MAX_TILES)) inject = self._filter_bbox_injector(process_graph=process_graph) - subjobs = [ - SubJob(process_graph=inject(tile), backend_id=backend_id) - for tile in tiles - ] + subjobs = [SubJob(process_graph=inject(tile), backend_id=backend_id) for tile in tiles] # Store tiling geometry in metadata if metadata is None: metadata = {} metadata[self.METADATA_KEY] = { "global_spatial_extent": global_spatial_extent.as_dict(), - "tiles": [t.as_dict() for t in tiles] + "tiles": [t.as_dict() for t in tiles], } return PartitionedJob( @@ -209,12 +198,17 @@ def _extract_global_spatial_extent(self, process: PGWithMetadata) -> BoundingBox catalog=self.backend_implementation.catalog, processing=ConcreteProcessing(), ) - convert_node(result_node, env=EvalEnv({ - ENV_DRY_RUN_TRACER: dry_run_tracer, - "backend_implementation": backend_implementation, - "version": "1.0.0", # TODO - "user": None, # TODO - })) + convert_node( + result_node, + env=EvalEnv( + { + ENV_DRY_RUN_TRACER: dry_run_tracer, + "backend_implementation": backend_implementation, + "version": "1.0.0", # TODO + "user": None, # TODO + } + ), + ) source_constraints = dry_run_tracer.get_source_constraints() # get global spatial extent spatial_extents = [c["spatial_extent"] for _, c in source_constraints if "spatial_extent" in c] @@ -223,9 +217,7 @@ def _extract_global_spatial_extent(self, process: PGWithMetadata) -> BoundingBox global_extent = BoundingBox.from_dict(spatial_extent_union(*spatial_extents)) return global_extent - def _filter_bbox_injector( - self, process_graph: FlatPG - ) -> typing.Callable[[BoundingBox], 
dict]: + def _filter_bbox_injector(self, process_graph: FlatPG) -> typing.Callable[[BoundingBox], dict]: """ Build function that takes a bounding box and injects a filter_bbox node just before result the `save_result` node of a "template" process graph. @@ -234,7 +226,7 @@ def _filter_bbox_injector( result_ids = [k for k, v in process_graph.items() if v.get("result")] if len(result_ids) != 1: raise JobSplittingFailure(f"Expected process graph with 1 result node but got {len(result_ids)}") - result_id, = result_ids + (result_id,) = result_ids if process_graph[result_id]["process_id"] != "save_result": raise JobSplittingFailure(f"Expected a save_result node but got {process_graph[result_id]}") previous_node_id = process_graph[result_id]["arguments"]["data"]["from_node"] @@ -250,7 +242,7 @@ def inject(bbox: BoundingBox) -> dict: "arguments": { "data": {"from_node": previous_node_id}, "extent": bbox.as_dict(), - } + }, } new[result_id]["arguments"]["data"] = {"from_node": inject_id} return new @@ -268,7 +260,7 @@ def tiling_geometry_to_geojson(geometry: dict, format: str) -> dict: def reproject(bbox: BoundingBox) -> shapely.geometry.Polygon: polygon = shapely.ops.transform( pyproj.Transformer.from_crs(crs_from=bbox.crs, crs_to="epsg:4326", always_xy=True).transform, - bbox.as_polygon() + bbox.as_polygon(), ) return polygon diff --git a/src/openeo_aggregator/partitionedjobs/tracking.py b/src/openeo_aggregator/partitionedjobs/tracking.py index 60fdca88..a34f36e4 100644 --- a/src/openeo_aggregator/partitionedjobs/tracking.py +++ b/src/openeo_aggregator/partitionedjobs/tracking.py @@ -192,8 +192,11 @@ def create_sjobs(self, user_id: str, pjob_id: str, flask_request: flask.Request) _log.info(f"To create: {pjob_id!r}:{sjob_id!r} (status {sjob_status})") if sjob_status == STATUS_INSERTED: new_status = self._create_sjob( - user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, - pjob_metadata=pjob_metadata, sjob_metadata=sjob_metadata, + user_id=user_id, + pjob_id=pjob_id, + sjob_id=sjob_id, + pjob_metadata=pjob_metadata, + sjob_metadata=sjob_metadata, flask_request=flask_request, ) create_stats[new_status] += 1 @@ -206,15 +209,20 @@ def create_sjobs(self, user_id: str, pjob_id: str, flask_request: flask.Request) ) def _create_sjob( - self, user_id: str, pjob_id: str, sjob_id: str, - pjob_metadata: dict, sjob_metadata: dict, - flask_request: flask.Request, + self, + user_id: str, + pjob_id: str, + sjob_id: str, + pjob_metadata: dict, + sjob_metadata: dict, + flask_request: flask.Request, ) -> str: try: con = self._backends.get_connection(sjob_metadata["backend_id"]) # TODO: different way to authenticate request? 
#29 - with con.authenticated_from_request(request=flask_request), \ - con.override(default_timeout=CONNECTION_TIMEOUT_JOB_START): + with con.authenticated_from_request(request=flask_request), con.override( + default_timeout=CONNECTION_TIMEOUT_JOB_START + ): with TimingLogger(title=f"Create {pjob_id}:{sjob_id} on backend {con.id}", logger=_log.info) as timer: job = con.create_job( process_graph=sjob_metadata["process_graph"], @@ -227,16 +235,18 @@ def _create_sjob( _log.info(f"Created {pjob_id}:{sjob_id} on backend {con.id} as batch job {job.job_id}") self._db.set_backend_job_id(user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, job_id=job.job_id) self._db.set_sjob_status( - user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=STATUS_CREATED, - message=f"Created in {timer.elapsed}" + user_id=user_id, + pjob_id=pjob_id, + sjob_id=sjob_id, + status=STATUS_CREATED, + message=f"Created in {timer.elapsed}", ) return STATUS_CREATED except Exception as e: # TODO: detect recoverable issue and allow for retry? _log.error(f"Creation of {pjob_id}:{sjob_id} failed", exc_info=True) self._db.set_sjob_status( - user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=STATUS_ERROR, - message=f"Create failed: {e}" + user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=STATUS_ERROR, message=f"Create failed: {e}" ) return STATUS_ERROR @@ -252,7 +262,9 @@ def start_sjobs(self, user_id: str, pjob_id: str, flask_request: flask.Request): _log.info(f"To Start: {pjob_id!r}:{sjob_id!r} (status {sjob_status})") if sjob_status == STATUS_CREATED: new_status = self._start_sjob( - user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, + user_id=user_id, + pjob_id=pjob_id, + sjob_id=sjob_id, sjob_metadata=sjob_metadata, flask_request=flask_request, ) @@ -266,28 +278,31 @@ def start_sjobs(self, user_id: str, pjob_id: str, flask_request: flask.Request): ) def _start_sjob( - self, user_id: str, pjob_id: str, sjob_id: str, sjob_metadata: dict, flask_request: flask.Request + self, user_id: str, pjob_id: str, sjob_id: str, sjob_metadata: dict, flask_request: flask.Request ) -> str: try: job_id = self._db.get_backend_job_id(user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id) con = self._backends.get_connection(sjob_metadata["backend_id"]) # TODO: different way to authenticate request? #29 - with con.authenticated_from_request(request=flask_request), \ - con.override(default_timeout=CONNECTION_TIMEOUT_JOB_START): + with con.authenticated_from_request(request=flask_request), con.override( + default_timeout=CONNECTION_TIMEOUT_JOB_START + ): with TimingLogger(title=f"Start subjob {sjob_id} on backend {con.id}", logger=_log.info) as timer: job = con.job(job_id) job.start_job() self._db.set_sjob_status( - user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=STATUS_RUNNING, - message=f"Started in {timer.elapsed}" + user_id=user_id, + pjob_id=pjob_id, + sjob_id=sjob_id, + status=STATUS_RUNNING, + message=f"Started in {timer.elapsed}", ) return STATUS_RUNNING except Exception as e: # TODO: detect recoverable issue and allow for retry? 
_log.error(f"Start of {pjob_id}:{sjob_id} failed", exc_info=True) self._db.set_sjob_status( - user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=STATUS_ERROR, - message=f"Failed to start: {e}" + user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=STATUS_ERROR, message=f"Failed to start: {e}" ) return STATUS_ERROR @@ -472,9 +487,13 @@ def get_logs( log["id"] = f"{sjob_id}-{log.id}" all_logs.extend(logs) except Exception as e: - all_logs.append(LogEntry( - id=f"{sjob_id}-0", level="error", message=f"Failed to get logs of {pjob_id}:{sjob_id}: {e!r}" - )) + all_logs.append( + LogEntry( + id=f"{sjob_id}-0", + level="error", + message=f"Failed to get logs of {pjob_id}:{sjob_id}: {e!r}", + ) + ) return all_logs @@ -494,7 +513,7 @@ class JobAdapter: Adapter for interfaces: `openeo.rest.RestJob`, `openeo.rest.JobResult` """ - def __init__(self, pjob_id: str, connection: 'PartitionedJobConnection'): + def __init__(self, pjob_id: str, connection: "PartitionedJobConnection"): self.pjob_id = pjob_id self.connection = connection @@ -503,7 +522,7 @@ def describe_job(self) -> dict: return self.connection.partitioned_job_tracker.describe_job( user_id=self.connection._user.user_id, pjob_id=self.pjob_id, - flask_request=self.connection._flask_request + flask_request=self.connection._flask_request, ) def start_job(self): @@ -511,7 +530,7 @@ def start_job(self): return self.connection.partitioned_job_tracker.start_sjobs( user_id=self.connection._user.user_id, pjob_id=self.pjob_id, - flask_request=self.connection._flask_request + flask_request=self.connection._flask_request, ) # TODO: also support job cancel and delete. #39 @@ -525,7 +544,7 @@ def get_assets(self) -> List[ResultAsset]: return self.connection.partitioned_job_tracker.get_assets( user_id=self.connection._user.user_id, pjob_id=self.pjob_id, - flask_request=self.connection._flask_request + flask_request=self.connection._flask_request, ) def get_metadata(self) -> dict: diff --git a/src/openeo_aggregator/partitionedjobs/zookeeper.py b/src/openeo_aggregator/partitionedjobs/zookeeper.py index 7a0b73cc..b6a3d15b 100644 --- a/src/openeo_aggregator/partitionedjobs/zookeeper.py +++ b/src/openeo_aggregator/partitionedjobs/zookeeper.py @@ -190,8 +190,7 @@ def set_backend_job_id(self, user_id: str, pjob_id: str, sjob_id: str, job_id: s """ with self._connect(): self._client.create( - path=self._path(user_id, pjob_id, "sjobs", sjob_id, "job_id"), - value=self.serialize(job_id=job_id) + path=self._path(user_id, pjob_id, "sjobs", sjob_id, "job_id"), value=self.serialize(job_id=job_id) ) def get_backend_job_id(self, user_id: str, pjob_id: str, sjob_id: str) -> str: @@ -223,7 +222,7 @@ def set_pjob_status( with self._connect(): kwargs = dict( path=self._path(user_id, pjob_id, "status"), - value=self.serialize(status=status, message=message, timestamp=Clock.time(), progress=progress) + value=self.serialize(status=status, message=message, timestamp=Clock.time(), progress=progress), ) if create: self._client.create(**kwargs) diff --git a/src/openeo_aggregator/testing.py b/src/openeo_aggregator/testing.py index d4670b17..2d780c79 100644 --- a/src/openeo_aggregator/testing.py +++ b/src/openeo_aggregator/testing.py @@ -104,10 +104,10 @@ class ApproxStr: """Pytest helper in style of `pytest.approx`, but for string checking, based on prefix, body and or suffix""" def __init__( - self, - prefix: Optional[str] = None, - body: Optional[str] = None, - suffix: Optional[str] = None, + self, + prefix: Optional[str] = None, + body: Optional[str] = None, + suffix: 
Optional[str] = None, ): # TODO: option to do case-insensitive comparison? self.prefix = prefix @@ -115,10 +115,12 @@ def __init__( self.suffix = suffix def __eq__(self, other): - return isinstance(other, str) and \ - (self.prefix is None or other.startswith(self.prefix)) and \ - (self.body is None or self.body in other) and \ - (self.suffix is None or other.endswith(self.suffix)) + return ( + isinstance(other, str) + and (self.prefix is None or other.startswith(self.prefix)) + and (self.body is None or self.body in other) + and (self.suffix is None or other.endswith(self.suffix)) + ) def __repr__(self): return "...".join([self.prefix or ""] + ([self.body] if self.body else []) + [self.suffix or ""]) @@ -153,9 +155,9 @@ def __eq__(self, other): def clock_mock( - start: Union[None, int, float, str, datetime.datetime] = None, - step: float = 0, - offset: Optional[float] = None, + start: Union[None, int, float, str, datetime.datetime] = None, + step: float = 0, + offset: Optional[float] = None, ): """ Mock the `time()` calls in `Clock` with a given start date/time and increment. @@ -180,8 +182,6 @@ def clock_mock( return mock.patch.object(Clock, "_time", new=itertools.count(start, step=step).__next__) - - class MetadataBuilder: """Helper for building openEO/STAC-style metadata dictionaries""" @@ -279,8 +279,7 @@ def process( "id": id, "description": id, "parameters": parameters or [], - "returns": returns - or {"schema": {"type": "object", "subtype": "raster-cube"}}, + "returns": returns or {"schema": {"type": "object", "subtype": "raster-cube"}}, } def processes(self, *args) -> dict: diff --git a/src/openeo_aggregator/utils.py b/src/openeo_aggregator/utils.py index 88717059..071489e3 100644 --- a/src/openeo_aggregator/utils.py +++ b/src/openeo_aggregator/utils.py @@ -52,7 +52,7 @@ def get(self, key: str) -> Iterator: if key in d: yield d[key] - def single_value_for(self, key:str)-> Any: + def single_value_for(self, key: str) -> Any: """Get value for given key and ensure that it is same everywhere""" values = set(self.get(key=key)) if len(values) != 1: @@ -80,15 +80,13 @@ def concat(self, key: str, skip_duplicates=False) -> list: continue result.append(item) else: - _log.warning( - f"MultiDictGetter.concat with {key=}: skipping unexpected type {items}" - ) + _log.warning(f"MultiDictGetter.concat with {key=}: skipping unexpected type {items}") return result def first(self, key, default=None): return next(self.get(key), default) - def select(self, key: str) -> 'MultiDictGetter': + def select(self, key: str) -> "MultiDictGetter": """Create new getter, one step deeper in the dictionary hierarchy.""" return MultiDictGetter(d for d in self.get(key=key) if isinstance(d, dict)) @@ -118,9 +116,7 @@ def dict_merge(*args, **kwargs) -> dict: def drop_dict_keys(data: Any, keys: List[Any]) -> Any: """Recursively drop given keys from (nested) dictionaries""" if isinstance(data, dict): - return { - k: drop_dict_keys(v, keys=keys) for k, v in data.items() if k not in keys - } + return {k: drop_dict_keys(v, keys=keys) for k, v in data.items() if k not in keys} elif isinstance(data, (list, tuple)): return type(data)(drop_dict_keys(v, keys=keys) for v in data) else: @@ -176,7 +172,8 @@ def utcnow(cls) -> datetime.datetime: class BoundingBox(NamedTuple): - """Simple NamedTuple container for a bounding box """ + """Simple NamedTuple container for a bounding box""" + # TODO: move this to openeo_driver west: float south: float @@ -188,11 +185,7 @@ class BoundingBox(NamedTuple): @classmethod def from_dict(cls, d: 
dict) -> "BoundingBox": - return cls(**{ - k: d[k] - for k in cls._fields - if k not in cls._field_defaults or k in d - }) + return cls(**{k: d[k] for k in cls._fields if k not in cls._field_defaults or k in d}) def as_dict(self) -> dict: return self._asdict() @@ -234,10 +227,7 @@ def common_prefix(lists: Iterable[Iterable[Any]]) -> List[Any]: except StopIteration: prefix = [] for other in list_iterator: - prefix = [ - t[0] - for t in itertools.takewhile(lambda t: t[0] == t[1], zip(prefix, other)) - ] + prefix = [t[0] for t in itertools.takewhile(lambda t: t[0] == t[1], zip(prefix, other))] return prefix @@ -271,9 +261,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if exc_type: self._successive_failures += 1 if self._successive_failures > self._limit: - _log.error( - f"Failure tolerance exceeded ({self._successive_failures} > {self._limit}) with {exc_val!r}" - ) + _log.error(f"Failure tolerance exceeded ({self._successive_failures} > {self._limit}) with {exc_val!r}") # Enough already! return False else: