From 982d1f8b5b2b7414e31780d3410265f787715ac9 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Mon, 4 Mar 2024 15:00:11 +0100 Subject: [PATCH] Issue #135 add config option to inject job_options --- CHANGELOG.md | 4 ++++ src/openeo_aggregator/about.py | 2 +- src/openeo_aggregator/backend.py | 5 ++++ src/openeo_aggregator/config.py | 12 +++++++++- tests/test_views.py | 39 +++++++++++++++++++++++++++++++- 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2a19b1..4eca311 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is roughly based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [0.29.0] + +- Add config option to inject job options before sending a job to upstream back-end ([#135](https://github.com/Open-EO/openeo-aggregator/issues/135)) + ## [0.28.0] - Remove (now unused) `AggregatorConfig` class definition ([#112](https://github.com/Open-EO/openeo-aggregator/issues/112)) diff --git a/src/openeo_aggregator/about.py b/src/openeo_aggregator/about.py index fc06c27..1a7a861 100644 --- a/src/openeo_aggregator/about.py +++ b/src/openeo_aggregator/about.py @@ -2,7 +2,7 @@ import sys from typing import Optional -__version__ = "0.28.0a1" +__version__ = "0.29.0a1" def log_version_info(logger: Optional[logging.Logger] = None): diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index 241d1b8..ca01f21 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -717,11 +717,16 @@ def _create_job_standard( job_options=job_options, ) process_graph = self.processing.preprocess_process_graph(process_graph, backend_id=backend_id) + if job_options: additional = {k: v for k, v in job_options.items() if not k.startswith("_agg_")} else: additional = None + if get_backend_config().job_options_update: + # Allow fine-tuning job options through config + additional = get_backend_config().job_options_update(job_options=additional, backend_id=backend_id) + con = self.backends.get_connection(backend_id) with con.authenticated_from_request(request=flask.request, user=User(user_id=user_id)), con.override( default_timeout=CONNECTION_TIMEOUT_JOB_START diff --git a/src/openeo_aggregator/config.py b/src/openeo_aggregator/config.py index 40ba025..946dd79 100644 --- a/src/openeo_aggregator/config.py +++ b/src/openeo_aggregator/config.py @@ -1,7 +1,7 @@ import logging import os import re -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Protocol, Union import attrs from openeo_driver.config import OpenEoBackendConfig @@ -30,6 +30,14 @@ class ConfigException(ValueError): pass +class JobOptionsUpdater(Protocol): + """API for `job_options_update` config (callable)""" + + def __call__(self, job_options: dict, backend_id: str) -> dict: + """Return updated job options dict""" + ... + + @attrs.frozen(kw_only=True) class AggregatorBackendConfig(OpenEoBackendConfig): @@ -60,6 +68,8 @@ class AggregatorBackendConfig(OpenEoBackendConfig): zk_memoizer_tracking: bool = smart_bool(os.environ.get("OPENEO_AGGREGATOR_ZK_MEMOIZER_TRACKING")) + job_options_update: Optional[JobOptionsUpdater] = None + # Internal singleton _config_getter = ConfigGetter(expected_class=AggregatorBackendConfig) diff --git a/tests/test_views.py b/tests/test_views.py index 99eeecd..3fe0cae 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1583,6 +1583,43 @@ def post_jobs(request: requests.Request, context): assert res.headers["Location"] == "http://oeoa.test/openeo/1.0.0/jobs/b1-th3j0b" assert res.headers["OpenEO-Identifier"] == "b1-th3j0b" + @pytest.mark.parametrize( + ["job_options_update", "expected_job_options"], + [ + (None, {"side": "salad"}), + ( + lambda job_options, backend_id: {**job_options, **{"beverage": f"fizzy{backend_id}"}}, + {"side": "salad", "beverage": "fizzyb1"}, + ), + ], + ) + def test_create_job_options_update(self, api100, requests_mock, backend1, job_options_update, expected_job_options): + requests_mock.get(backend1 + "/collections", json={"collections": [{"id": "S2"}]}) + + def post_jobs(request: requests.Request, context): + assert request.json() == { + "process": {"process_graph": pg}, + "job_options": expected_job_options, + } + context.headers["Location"] = backend1 + "/jobs/th3j0b" + context.headers["OpenEO-Identifier"] = "th3j0b" + context.status_code = 201 + + requests_mock.post(backend1 + "/jobs", text=post_jobs) + + pg = {"lc": {"process_id": "load_collection", "arguments": {"id": "S2"}, "result": True}} + api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN) + with config_overrides(job_options_update=job_options_update): + res = api100.post( + "/jobs", + json={ + "process": {"process_graph": pg}, + "job_options": {"side": "salad"}, + }, + ).assert_status_code(201) + assert res.headers["Location"] == "http://oeoa.test/openeo/1.0.0/jobs/b1-th3j0b" + assert res.headers["OpenEO-Identifier"] == "b1-th3j0b" + @pytest.mark.parametrize( "body", [ @@ -2104,7 +2141,7 @@ def service_metadata_wmts_foo(self): enabled=True, configuration={"version": "0.5.8"}, attributes={}, - title="Test WMTS service" + title="Test WMTS service", # not setting "created": This is used to test creating a service. )