From 329469b3de72b3d13aa460f47245031b4be4a6d1 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Wed, 6 Sep 2023 10:17:25 +0200 Subject: [PATCH] Issue #115 CrossBackendSplitter: add "streamed" split to allow injecting batch job ids on the fly --- .../partitionedjobs/crossbackend.py | 118 +++++++++---- tests/partitionedjobs/test_crossbackend.py | 165 +++++++++++++++++- 2 files changed, 247 insertions(+), 36 deletions(-) diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 36d1c169..0a8a83b5 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -5,7 +5,7 @@ import logging import time from contextlib import nullcontext -from typing import Callable, Dict, List, Sequence +from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple import openeo from openeo import BatchJob @@ -20,6 +20,42 @@ _LOAD_RESULT_PLACEHOLDER = "_placeholder:" +# Some type annotation aliases to make things more self-documenting +SubGraphId = str + + +class GetReplacementCallable(Protocol): + """ + Type annotation for callback functions that produce a node replacement + for a node that is split off from the main process graph + + Also see `_default_get_replacement` + """ + + def __call__(self, node_id: str, node: dict, subgraph_id: SubGraphId) -> dict: + """ + :param node_id: original id of the node in the process graph (e.g. `loadcollection2`) + :param node: original node in the process graph (e.g. `{"process_id": "load_collection", "arguments": {...}}`) + :param subgraph_id: id of the corresponding dependency subgraph + (to be handled as opaque id, but possibly something like `backend1:loadcollection2`) + + :return: new process graph nodes. Should contain at least a node keyed under `node_id` + """ + ... + + +def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict: + """ + Default `get_replacement` function to replace a node that has been split off. + """ + return { + node_id: { + # TODO: use `load_stac` iso `load_result` + "process_id": "load_result", + "arguments": {"id": f"{_LOAD_RESULT_PLACEHOLDER}{subgraph_id}"}, + } + } + class CrossBackendSplitter(AbstractJobSplitter): """ @@ -42,10 +78,25 @@ def __init__( self.backend_for_collection = backend_for_collection self._always_split = always_split - def split( - self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None - ) -> PartitionedJob: - process_graph = process["process_graph"] + def split_streaming( + self, + process_graph: FlatPG, + get_replacement: GetReplacementCallable = _default_get_replacement, + ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]: + """ + Split given process graph in sub-process graphs and return these as an iterator + in an order so that a subgraph comes after all subgraphs it depends on + (e.g. main "primary" graph comes last). + + The iterator approach allows working with a dynamic `get_replacement` implementation + that adapting to on previously produced subgraphs + (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately). 
+ + :return: tuple containing: + - subgraph id + - SubJob + - dependencies as list of subgraph ids + """ # Extract necessary back-ends from `load_collection` usage backend_per_collection: Dict[str, str] = { @@ -57,55 +108,60 @@ def split( backend_usage = collections.Counter(backend_per_collection.values()) _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}") + # TODO: more options to determine primary backend? primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None secondary_backends = {b for b in backend_usage if b != primary_backend} _log.info(f"Backend split: {primary_backend=} {secondary_backends=}") primary_id = "main" - primary_pg = SubJob(process_graph={}, backend_id=primary_backend) + primary_pg = {} primary_has_load_collection = False - - subjobs: Dict[str, SubJob] = {primary_id: primary_pg} - dependencies: Dict[str, List[str]] = {primary_id: []} + primary_dependencies = [] for node_id, node in process_graph.items(): if node["process_id"] == "load_collection": bid = backend_per_collection[node["arguments"]["id"]] - if bid == primary_backend and not ( - self._always_split and primary_has_load_collection - ): + if bid == primary_backend and (not self._always_split or not primary_has_load_collection): # Add to primary pg - primary_pg.process_graph[node_id] = node + primary_pg[node_id] = node primary_has_load_collection = True else: # New secondary pg - pg = { + sub_id = f"{bid}:{node_id}" + sub_pg = { node_id: node, "sr1": { # TODO: other/better choices for save_result format (e.g. based on backend support)? - # TODO: particular format options? "process_id": "save_result", "arguments": { "data": {"from_node": node_id}, + # TODO: particular format options? # "format": "NetCDF", "format": "GTiff", }, "result": True, }, } - dependency_id = f"{bid}:{node_id}" - subjobs[dependency_id] = SubJob(process_graph=pg, backend_id=bid) - dependencies[primary_id].append(dependency_id) - # Link to primary pg with load_result - primary_pg.process_graph[node_id] = { - # TODO: encapsulate this placeholder process/id better? - "process_id": "load_result", - "arguments": { - "id": f"{_LOAD_RESULT_PLACEHOLDER}{dependency_id}" - }, - } + + yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), []) + + # Link secondary pg into primary pg + primary_pg.update(get_replacement(node_id=node_id, node=node, subgraph_id=sub_id)) + primary_dependencies.append(sub_id) else: - primary_pg.process_graph[node_id] = node + primary_pg[node_id] = node + + yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies) + + def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: + """Split given process graph into a `PartitionedJob`""" + + subjobs: Dict[SubGraphId, SubJob] = {} + dependencies: Dict[SubGraphId, List[SubGraphId]] = {} + for sub_id, subjob, sub_dependencies in self.split_streaming(process_graph=process["process_graph"]): + subjobs[sub_id] = subjob + if sub_dependencies: + dependencies[sub_id] = sub_dependencies return PartitionedJob( process=process, @@ -116,9 +172,7 @@ def split( ) -def resolve_dependencies( - process_graph: FlatPG, batch_jobs: Dict[str, BatchJob] -) -> FlatPG: +def _resolve_dependencies(process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]) -> FlatPG: """ Replace placeholders in given process graph based on given subjob_id to batch_job_id mapping. 
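
For clarity, a rough sketch (editor's addition, not the actual implementation from this patch) of the placeholder substitution that `_resolve_dependencies` is documented to perform: `load_result` nodes whose `id` starts with the `_LOAD_RESULT_PLACEHOLDER` prefix defined earlier in this file get that placeholder swapped for the id of the batch job created for the corresponding subgraph. The helper name and the deep-copy approach are illustrative assumptions.

import copy
from typing import Dict

from openeo import BatchJob

_LOAD_RESULT_PLACEHOLDER = "_placeholder:"


def _resolve_dependencies_sketch(process_graph: dict, batch_jobs: Dict[str, BatchJob]) -> dict:
    # Work on a copy so the original subjob process graph stays untouched.
    resolved = copy.deepcopy(process_graph)
    for node in resolved.values():
        job_id_arg = node.get("arguments", {}).get("id")
        if isinstance(job_id_arg, str) and job_id_arg.startswith(_LOAD_RESULT_PLACEHOLDER):
            subgraph_id = job_id_arg[len(_LOAD_RESULT_PLACEHOLDER):]
            # Swap the placeholder for the id of the batch job of that dependency subgraph.
            node["arguments"]["id"] = batch_jobs[subgraph_id].job_id
    return resolved
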
@@ -235,9 +289,7 @@ def run_partitioned_job( # Handle job (start, poll status, ...) if states[subjob_id] == SUBJOB_STATES.READY: try: - process_graph = resolve_dependencies( - subjob.process_graph, batch_jobs=batch_jobs - ) + process_graph = _resolve_dependencies(subjob.process_graph, batch_jobs=batch_jobs) _log.info( f"Starting new batch job for subjob {subjob_id!r} on backend {subjob.backend_id!r}" diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 8d1e2c82..9c26740e 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -1,5 +1,6 @@ import dataclasses import re +import types from typing import Dict, List, Optional from unittest import mock @@ -13,12 +14,13 @@ from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob from openeo_aggregator.partitionedjobs.crossbackend import ( CrossBackendSplitter, + SubGraphId, run_partitioned_job, ) class TestCrossBackendSplitter: - def test_simple(self): + def test_split_simple(self): process_graph = { "add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True} } @@ -26,9 +28,16 @@ def test_simple(self): res = splitter.split({"process_graph": process_graph}) assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)} - assert res.dependencies == {"main": []} + assert res.dependencies == {} - def test_basic(self): + def test_split_streaming_simple(self): + process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} + splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo") + res = splitter.split_streaming(process_graph) + assert isinstance(res, types.GeneratorType) + assert list(res) == [("main", SubJob(process_graph, backend_id=None), [])] + + def test_split_basic(self): process_graph = { "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}}, "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}}, @@ -93,6 +102,156 @@ def test_basic(self): } assert res.dependencies == {"main": ["B2:lc2"]} + def test_split_streaming_basic(self): + process_graph = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}}, + "mc1": { + "process_id": "merge_cubes", + "arguments": { + "cube1": {"from_node": "lc1"}, + "cube2": {"from_node": "lc2"}, + }, + }, + "sr1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, + "result": True, + }, + } + splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + result = splitter.split_streaming(process_graph) + assert isinstance(result, types.GeneratorType) + + assert list(result) == [ + ( + "B2:lc2", + SubJob( + process_graph={ + "lc2": { + "process_id": "load_collection", + "arguments": {"id": "B2_FAPAR"}, + }, + "sr1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, + "result": True, + }, + }, + backend_id="B2", + ), + [], + ), + ( + "main", + SubJob( + process_graph={ + "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}}, + "lc2": {"process_id": "load_result", "arguments": {"id": "_placeholder:B2:lc2"}}, + "mc1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + }, + "sr1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, + "result": True, + 
}, + }, + backend_id="B1", + ), + ["B2:lc2"], + ), + ] + + def test_split_streaming_get_replacement(self): + process_graph = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}}, + "lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}}, + "merge": { + "process_id": "merge", + "arguments": { + "cube1": {"from_node": "lc1"}, + "cube2": {"from_node": "lc2"}, + "cube3": {"from_node": "lc3"}, + }, + "result": True, + }, + } + splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + + batch_jobs = {} + + def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict: + return { + node_id: { + "process_id": "load_batch_job", + "arguments": {"batch_job": batch_jobs[subgraph_id]}, + } + } + + substream = splitter.split_streaming(process_graph, get_replacement=get_replacement) + + result = [] + for subgraph_id, subjob, dependencies in substream: + batch_jobs[subgraph_id] = f"job-{111 * (len(batch_jobs) + 1)}" + result.append((subgraph_id, subjob, dependencies)) + + assert list(result) == [ + ( + "B2:lc2", + SubJob( + process_graph={ + "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}}, + "sr1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, + "result": True, + }, + }, + backend_id="B2", + ), + [], + ), + ( + "B3:lc3", + SubJob( + process_graph={ + "lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}}, + "sr1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "lc3"}, "format": "GTiff"}, + "result": True, + }, + }, + backend_id="B3", + ), + [], + ), + ( + "main", + SubJob( + process_graph={ + "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}}, + "lc2": {"process_id": "load_batch_job", "arguments": {"batch_job": "job-111"}}, + "lc3": {"process_id": "load_batch_job", "arguments": {"batch_job": "job-222"}}, + "merge": { + "process_id": "merge", + "arguments": { + "cube1": {"from_node": "lc1"}, + "cube2": {"from_node": "lc2"}, + "cube3": {"from_node": "lc3"}, + }, + "result": True, + }, + }, + backend_id="B1", + ), + ["B2:lc2", "B3:lc3"], + ), + ] + @dataclasses.dataclass class _FakeJob:
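
A minimal usage sketch (editor's addition, not part of the patch) of the new `split_streaming` API, mirroring the "inject batch job ids on the fly" idea from the commit message: each split-off subgraph gets a batch job created as soon as it is yielded, and a custom `get_replacement` callback injects that job's id into the subgraphs that depend on it. The backend URL, the `load_result`-based replacement and the example process graph are illustrative assumptions, not code from this repository.

import openeo

from openeo_aggregator.partitionedjobs.crossbackend import (
    CrossBackendSplitter,
    SubGraphId,
)

# Illustrative flat process graph (same shape as the graphs used in the tests above).
process_graph = {
    "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
    "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
    "merge": {
        "process_id": "merge_cubes",
        "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
        "result": True,
    },
}

# Hypothetical backend URL; in a real cross-backend setup one would pick a connection
# per `subjob.backend_id` instead of a single connection for everything.
connection = openeo.connect("https://openeo.example").authenticate_oidc()
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])

batch_jobs = {}  # subgraph id -> openeo BatchJob created for that subgraph


def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
    # Inject the real batch job id of the already-created dependency subgraph,
    # instead of the `_placeholder:` id used by the default replacement.
    return {
        node_id: {
            "process_id": "load_result",
            "arguments": {"id": batch_jobs[subgraph_id].job_id},
        }
    }


for subgraph_id, subjob, dependencies in splitter.split_streaming(
    process_graph=process_graph, get_replacement=get_replacement
):
    # Creating the job here makes its id available to later subgraphs (including "main"),
    # because `split_streaming` yields each dependency before the subgraphs that use it.
    # In a real flow the dependency jobs would also be started and awaited
    # (as `run_partitioned_job` does) before the primary job runs.
    batch_jobs[subgraph_id] = connection.create_job(subjob.process_graph)
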