diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py
index 138bbbb6..788743d6 100644
--- a/src/openeo_aggregator/backend.py
+++ b/src/openeo_aggregator/backend.py
@@ -30,8 +30,17 @@
 from openeo_aggregator.metadata.reporter import LoggerReporter
 from openeo_aggregator.partitionedjobs import PartitionedJob
 from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter
-from openeo_aggregator.partitionedjobs.tracking import PartitionedJobConnection, PartitionedJobTracker
-from openeo_aggregator.utils import subdict, dict_merge, normalize_issuer_url
+from openeo_aggregator.partitionedjobs.tracking import (
+    PartitionedJobConnection,
+    PartitionedJobTracker,
+)
+from openeo_aggregator.utils import (
+    subdict,
+    dict_merge,
+    normalize_issuer_url,
+    FlatPG,
+    PGWithMetadata,
+)
 from openeo_driver.ProcessGraphDeserializer import SimpleProcessing
 from openeo_driver.backend import OpenEoBackendImplementation, AbstractCollectionCatalog, LoadParameters, Processing, \
     OidcProvider, BatchJobs, BatchJobMetadata, SecondaryServices, ServiceMetadata
@@ -395,7 +404,7 @@ def evaluate(self, process_graph: dict, env: EvalEnv = None):
 
         return streaming_flask_response(backend_response, chunk_size=self._stream_chunk_size)
 
-    def preprocess_process_graph(self, process_graph: dict, backend_id: str) -> dict:
+    def preprocess_process_graph(self, process_graph: FlatPG, backend_id: str) -> dict:
         def preprocess(node: Any) -> Any:
             if isinstance(node, dict):
                 if "process_id" in node and "arguments" in node:
@@ -538,7 +547,12 @@ def _create_job_standard(
         )
 
     def _create_partitioned_job(
-        self, user_id: str, process: dict, api_version: str, metadata: dict, job_options: dict = None
+        self,
+        user_id: str,
+        process: PGWithMetadata,
+        api_version: str,
+        metadata: dict,
+        job_options: dict = None,
     ) -> BatchJobMetadata:
         """
         Advanced/handled batch job creation:
diff --git a/src/openeo_aggregator/partitionedjobs/__init__.py b/src/openeo_aggregator/partitionedjobs/__init__.py
index 3d347c0c..607d196e 100644
--- a/src/openeo_aggregator/partitionedjobs/__init__.py
+++ b/src/openeo_aggregator/partitionedjobs/__init__.py
@@ -1,5 +1,6 @@
-from typing import NamedTuple, List
+from typing import NamedTuple, List, Dict
 
+from openeo_aggregator.utils import FlatPG, PGWithMetadata
 from openeo_driver.errors import OpenEOApiException
 
 
@@ -10,7 +11,7 @@ class PartitionedJobFailure(OpenEOApiException):
 class SubJob(NamedTuple):
     """A part of a partitioned job, target at a particular, single back-end."""
     # Process graph of the subjob (derived in some way from original parent process graph)
-    process_graph: dict
+    process_graph: FlatPG
     # Id of target backend
     backend_id: str
 
@@ -18,7 +19,7 @@
 class PartitionedJob(NamedTuple):
     """A large or multi-back-end job that is split in several sub jobs"""
     # Original process graph
-    process: dict
+    process: PGWithMetadata
     metadata: dict
     job_options: dict
     # List of sub-jobs
diff --git a/src/openeo_aggregator/partitionedjobs/splitting.py b/src/openeo_aggregator/partitionedjobs/splitting.py
index 9cd28240..bbfe0679 100644
--- a/src/openeo_aggregator/partitionedjobs/splitting.py
+++ b/src/openeo_aggregator/partitionedjobs/splitting.py
@@ -10,7 +10,7 @@
 
 from openeo.internal.process_graph_visitor import ProcessGraphVisitor
 from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob, PartitionedJobFailure
-from openeo_aggregator.utils import BoundingBox
+from openeo_aggregator.utils import BoundingBox, PGWithMetadata, FlatPG
 from openeo_driver.ProcessGraphDeserializer import convert_node, ENV_DRY_RUN_TRACER, ConcreteProcessing
 from openeo_driver.backend import OpenEoBackendImplementation
 from openeo_driver.dry_run import DryRunDataTracer
@@ -38,7 +38,9 @@ class AbstractJobSplitter(metaclass=abc.ABCMeta):
     """
 
     @abc.abstractmethod
-    def split(self, process: dict, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
+    def split(
+        self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None
+    ) -> PartitionedJob:
         # TODO: how to express dependencies? give SubJobs an id for referencing?
         # TODO: how to express combination/aggregation of multiple subjob results as a final result?
         ...
@@ -54,7 +56,9 @@ class FlimsySplitter(AbstractJobSplitter):
     def __init__(self, processing: "AggregatorProcessing"):
         self.processing = processing
 
-    def split(self, process: dict, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
+    def split(
+        self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None
+    ) -> PartitionedJob:
         process_graph = process["process_graph"]
         backend_id = self.processing.get_backend_for_process_graph(process_graph=process_graph, api_version="TODO")
         process_graph = self.processing.preprocess_process_graph(process_graph, backend_id=backend_id)
@@ -142,7 +146,9 @@ def __init__(self, processing: "AggregatorProcessing"):
             processing=processing
         )
 
-    def split(self, process: dict, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
+    def split(
+        self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None
+    ) -> PartitionedJob:
         # TODO: refactor process graph preprocessing and backend_id getting in reusable AbstractJobSplitter method?
         processing: AggregatorProcessing = self.backend_implementation.processing
         process_graph = process["process_graph"]
@@ -171,7 +177,7 @@ def split(self, process: dict, metadata: dict = None, job_options: dict = None)
 
         return PartitionedJob(process=process, metadata=metadata, job_options=job_options, subjobs=subjobs)
 
-    def _extract_global_spatial_extent(self, process: dict) -> BoundingBox:
+    def _extract_global_spatial_extent(self, process: PGWithMetadata) -> BoundingBox:
         """Extract global spatial extent from given process graph"""
         # TODO: drop deepcopy when `dereference_from_node_arguments` doesn't do
         #       in-place manipulation of original process dict anymore
@@ -198,7 +204,9 @@ def _extract_global_spatial_extent(self, process: dict) -> BoundingBox:
         global_extent = BoundingBox.from_dict(spatial_extent_union(*spatial_extents))
         return global_extent
 
-    def _filter_bbox_injector(self, process_graph: dict) -> typing.Callable[[BoundingBox], dict]:
+    def _filter_bbox_injector(
+        self, process_graph: FlatPG
+    ) -> typing.Callable[[BoundingBox], dict]:
         """
         Build function that takes a bounding box and injects a filter_bbox node
         just before result the `save_result` node of a "template" process graph.
diff --git a/src/openeo_aggregator/utils.py b/src/openeo_aggregator/utils.py
index 1ba3b2b7..5bf9046c 100644
--- a/src/openeo_aggregator/utils.py
+++ b/src/openeo_aggregator/utils.py
@@ -17,6 +17,17 @@
 _log = logging.getLogger(__name__)
 
 
+# Type-hinting alias for "process graph with metadata" constructs:
+# containing at least a "process_graph" field with a process graph in "flat-graph" representation
+# TODO move this upstream to openeo-python-driver
+PGWithMetadata = dict
+
+
+# Type-hinting alias for process graphs in "flat-graph" representation.
+# TODO move this upstream to openeo-python-driver
+FlatPG = dict
+
+
 class MultiDictGetter:
     """
     Helper to get (and combine) items (where available) from a collection of dictionaries.
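For context on the two new aliases: both resolve to plain `dict`, so they document intent in signatures without enforcing anything at runtime. A minimal sketch of the shapes they describe, based on the comments introduced in `utils.py` (the node names and collection id below are illustrative, not taken from this changeset):

```python
from openeo_aggregator.utils import FlatPG, PGWithMetadata

# FlatPG: a process graph in openEO "flat-graph" representation,
# mapping node names to node definitions.
flat_pg: FlatPG = {
    "lc": {
        "process_id": "load_collection",
        "arguments": {"id": "SENTINEL2_L2A"},  # illustrative collection id
    },
    "sr": {
        "process_id": "save_result",
        "arguments": {"data": {"from_node": "lc"}, "format": "GTiff"},
        "result": True,  # marks the final node of the graph
    },
}

# PGWithMetadata: a construct with at least a "process_graph" field,
# e.g. the `process` argument of the `split(...)` methods above,
# which access it as process["process_graph"].
process: PGWithMetadata = {
    "process_graph": flat_pg,
    "id": "example_process",  # optional metadata alongside the graph
}
```

Since both aliases are bare `dict` aliases, type checkers treat them interchangeably with plain dicts; the gain is purely readability, e.g. `split(self, process: PGWithMetadata, ...)` versus the previous untyped `process: dict`.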