Skip to content

Commit

Permalink
Issue #135 add config option to inject job_options
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Mar 4, 2024
1 parent 4cfd784 commit 982d1f8
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

The format is roughly based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [0.29.0]

- Add config option to inject job options before sending a job to upstream back-end ([#135](https://github.com/Open-EO/openeo-aggregator/issues/135))

## [0.28.0]

- Remove (now unused) `AggregatorConfig` class definition ([#112](https://github.com/Open-EO/openeo-aggregator/issues/112))
Expand Down
2 changes: 1 addition & 1 deletion src/openeo_aggregator/about.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from typing import Optional

__version__ = "0.28.0a1"
__version__ = "0.29.0a1"


def log_version_info(logger: Optional[logging.Logger] = None):
Expand Down
5 changes: 5 additions & 0 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,11 +717,16 @@ def _create_job_standard(
job_options=job_options,
)
process_graph = self.processing.preprocess_process_graph(process_graph, backend_id=backend_id)

if job_options:
additional = {k: v for k, v in job_options.items() if not k.startswith("_agg_")}
else:
additional = None

if get_backend_config().job_options_update:
# Allow fine-tuning job options through config
additional = get_backend_config().job_options_update(job_options=additional, backend_id=backend_id)

con = self.backends.get_connection(backend_id)
with con.authenticated_from_request(request=flask.request, user=User(user_id=user_id)), con.override(
default_timeout=CONNECTION_TIMEOUT_JOB_START
Expand Down
12 changes: 11 additions & 1 deletion src/openeo_aggregator/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
import re
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Protocol, Union

import attrs
from openeo_driver.config import OpenEoBackendConfig
Expand Down Expand Up @@ -30,6 +30,14 @@ class ConfigException(ValueError):
pass


class JobOptionsUpdater(Protocol):
"""API for `job_options_update` config (callable)"""

def __call__(self, job_options: dict, backend_id: str) -> dict:
"""Return updated job options dict"""
...


@attrs.frozen(kw_only=True)
class AggregatorBackendConfig(OpenEoBackendConfig):

Expand Down Expand Up @@ -60,6 +68,8 @@ class AggregatorBackendConfig(OpenEoBackendConfig):

zk_memoizer_tracking: bool = smart_bool(os.environ.get("OPENEO_AGGREGATOR_ZK_MEMOIZER_TRACKING"))

job_options_update: Optional[JobOptionsUpdater] = None


# Internal singleton
_config_getter = ConfigGetter(expected_class=AggregatorBackendConfig)
Expand Down
39 changes: 38 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,43 @@ def post_jobs(request: requests.Request, context):
assert res.headers["Location"] == "http://oeoa.test/openeo/1.0.0/jobs/b1-th3j0b"
assert res.headers["OpenEO-Identifier"] == "b1-th3j0b"

@pytest.mark.parametrize(
["job_options_update", "expected_job_options"],
[
(None, {"side": "salad"}),
(
lambda job_options, backend_id: {**job_options, **{"beverage": f"fizzy{backend_id}"}},
{"side": "salad", "beverage": "fizzyb1"},
),
],
)
def test_create_job_options_update(self, api100, requests_mock, backend1, job_options_update, expected_job_options):
requests_mock.get(backend1 + "/collections", json={"collections": [{"id": "S2"}]})

def post_jobs(request: requests.Request, context):
assert request.json() == {
"process": {"process_graph": pg},
"job_options": expected_job_options,
}
context.headers["Location"] = backend1 + "/jobs/th3j0b"
context.headers["OpenEO-Identifier"] = "th3j0b"
context.status_code = 201

requests_mock.post(backend1 + "/jobs", text=post_jobs)

pg = {"lc": {"process_id": "load_collection", "arguments": {"id": "S2"}, "result": True}}
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
with config_overrides(job_options_update=job_options_update):
res = api100.post(
"/jobs",
json={
"process": {"process_graph": pg},
"job_options": {"side": "salad"},
},
).assert_status_code(201)
assert res.headers["Location"] == "http://oeoa.test/openeo/1.0.0/jobs/b1-th3j0b"
assert res.headers["OpenEO-Identifier"] == "b1-th3j0b"

@pytest.mark.parametrize(
"body",
[
Expand Down Expand Up @@ -2104,7 +2141,7 @@ def service_metadata_wmts_foo(self):
enabled=True,
configuration={"version": "0.5.8"},
attributes={},
title="Test WMTS service"
title="Test WMTS service",
# not setting "created": This is used to test creating a service.
)

Expand Down

0 comments on commit 982d1f8

Please sign in to comment.