diff --git a/src/openeo_aggregator/partitionedjobs/__init__.py b/src/openeo_aggregator/partitionedjobs/__init__.py
index 52fdc7b8..85671ba3 100644
--- a/src/openeo_aggregator/partitionedjobs/__init__.py
+++ b/src/openeo_aggregator/partitionedjobs/__init__.py
@@ -1,4 +1,4 @@
-from typing import NamedTuple, List, Dict, Sequence, Union, Any
+from typing import NamedTuple, List, Dict, Sequence, Union, Any, Optional
 
 from openeo_aggregator.utils import FlatPG, PGWithMetadata
 from openeo_driver.errors import OpenEOApiException
@@ -12,8 +12,8 @@ class SubJob(NamedTuple):
     """A part of a partitioned job, target at a particular, single back-end."""
     # Process graph of the subjob (derived in some way from original parent process graph)
     process_graph: FlatPG
-    # Id of target backend
-    backend_id: str
+    # Id of target backend (or None if there is no dedicated backend)
+    backend_id: Optional[str]
 
 
 class PartitionedJob(NamedTuple):
diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py
index 0941a43b..18905bda 100644
--- a/src/openeo_aggregator/partitionedjobs/crossbackend.py
+++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -1,6 +1,5 @@
 import collections
 import logging
-from pprint import pprint
 from typing import Callable, Dict, List
 
 from openeo_aggregator.partitionedjobs import SubJob, PartitionedJob
@@ -35,7 +34,7 @@ def split(
             f"Extracted backend usage from `load_collection` nodes: {backend_usage}"
         )
 
-        primary_backend = backend_usage.most_common(1)[0][0]
+        primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
         secondary_backends = {b for b in backend_usage if b != primary_backend}
         _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
 
@@ -80,45 +79,3 @@ def split(
             subjobs=PartitionedJob.to_subjobs_dict(subjobs),
             dependencies=dependencies,
         )
-
-
-def main():
-    # Simple proof of concept for cross-backend splitting
-    process_graph = {
-        "lc1": {"process_id": "load_collection", "arguments": {"id": "VITO_1"}},
-        "lc2": {"process_id": "load_collection", "arguments": {"id": "SH_1"}},
-        "mc1": {
-            "process_id": "merge_cubes",
-            "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
-        },
-        "sr1": {
-            "process_id": "save_result",
-            "arguments": {"format": "NetCDF"},
-        },
-    }
-    print("Original PG:")
-    pprint(process_graph)
-
-    splitter = CrossBackendSplitter(
-        backend_for_collection=lambda cid: cid.split("_")[0]
-    )
-
-    pjob = splitter.split({"process_graph": process_graph})
-
-    def namedtuples_to_dict(x):
-        """Walk data structure and convert namedtuples to dictionaries"""
-        if hasattr(x, "_asdict"):
-            return namedtuples_to_dict(x._asdict())
-        elif isinstance(x, list):
-            return [namedtuples_to_dict(i) for i in x]
-        elif isinstance(x, dict):
-            return {k: namedtuples_to_dict(v) for k, v in x.items()}
-        else:
-            return x
-
-    print("Cross-backend split:")
-    pprint(namedtuples_to_dict(pjob), width=120)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py
new file mode 100644
index 00000000..f34dabc0
--- /dev/null
+++ b/tests/partitionedjobs/test_crossbackend.py
@@ -0,0 +1,78 @@
+from openeo_aggregator.partitionedjobs import SubJob
+from openeo_aggregator.partitionedjobs.crossbackend import (
+    CrossBackendSplitter,
+)
+
+
+class TestCrossBackendSplitter:
+    def test_simple(self):
+        process_graph = {
+            "add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}
+        }
+        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
+        res = splitter.split({"process_graph": process_graph})
+
+        assert res.subjobs == {"primary": SubJob(process_graph, backend_id=None)}
+        assert res.dependencies == {"primary": []}
+
+    def test_basic(self):
+        process_graph = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
+            "mc1": {
+                "process_id": "merge_cubes",
+                "arguments": {
+                    "cube1": {"from_node": "lc1"},
+                    "cube2": {"from_node": "lc2"},
+                },
+            },
+            "sr1": {
+                "process_id": "save_result",
+                "arguments": {"format": "NetCDF"},
+            },
+        }
+        splitter = CrossBackendSplitter(
+            backend_for_collection=lambda cid: cid.split("_")[0]
+        )
+        res = splitter.split({"process_graph": process_graph})
+
+        assert res.subjobs == {
+            "primary": SubJob(
+                process_graph={
+                    "lc1": {
+                        "process_id": "load_collection",
+                        "arguments": {"id": "B1_NDVI"},
+                    },
+                    "lc2": {
+                        "process_id": "load_result",
+                        "arguments": {"id": "placeholder:B2:lc2"},
+                    },
+                    "mc1": {
+                        "process_id": "merge_cubes",
+                        "arguments": {
+                            "cube1": {"from_node": "lc1"},
+                            "cube2": {"from_node": "lc2"},
+                        },
+                    },
+                    "sr1": {
+                        "process_id": "save_result",
+                        "arguments": {"format": "NetCDF"},
+                    },
+                },
+                backend_id="B1",
+            ),
+            "B2:lc2": SubJob(
+                process_graph={
+                    "lc2": {
+                        "process_id": "load_collection",
+                        "arguments": {"id": "B2_FAPAR"},
+                    },
+                    "sr1": {
+                        "process_id": "save_result",
+                        "arguments": {"format": "NetCDF"},
+                    },
+                },
+                backend_id="B2",
+            ),
+        }
+        assert res.dependencies == {"primary": ["B2:lc2"]}
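
The removed main() proof of concept can still be reproduced as a quick local demo of the splitter. A minimal sketch of equivalent usage, assuming the module layout from the diff above (the namedtuples_to_dict helper is re-created here purely for illustration, mirroring the deleted code):

# Local demo of CrossBackendSplitter, roughly equivalent to the removed main().
# Collection ids encode their backend as a "<backend>_<name>" prefix, as in the PoC.
from pprint import pprint

from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter

process_graph = {
    "lc1": {"process_id": "load_collection", "arguments": {"id": "VITO_1"}},
    "lc2": {"process_id": "load_collection", "arguments": {"id": "SH_1"}},
    "mc1": {
        "process_id": "merge_cubes",
        "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
    },
    "sr1": {"process_id": "save_result", "arguments": {"format": "NetCDF"}},
}

# Map each collection id to a backend id and split the process graph accordingly.
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
pjob = splitter.split({"process_graph": process_graph})


def namedtuples_to_dict(x):
    """Walk a data structure and convert namedtuples to plain dictionaries."""
    if hasattr(x, "_asdict"):
        return namedtuples_to_dict(x._asdict())
    elif isinstance(x, list):
        return [namedtuples_to_dict(i) for i in x]
    elif isinstance(x, dict):
        return {k: namedtuples_to_dict(v) for k, v in x.items()}
    return x


# Inspect the resulting PartitionedJob (primary subjob plus its dependencies).
pprint(namedtuples_to_dict(pjob), width=120)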