diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py
index acd058e3..254e9d9c 100644
--- a/src/openeo_aggregator/backend.py
+++ b/src/openeo_aggregator/backend.py
@@ -5,8 +5,20 @@
 import pathlib
 import re
 import time
+import typing
 from collections import defaultdict
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import flask
 import openeo_driver.util.view_helpers
@@ -198,7 +210,7 @@ def evaluate(backend_id, pg):
 
     def get_backend_candidates_for_collections(self, collections: Iterable[str]) -> List[str]:
         """
-        Get best backend id providing all given collections
+        Get backend ids providing all given collections
         :param collections: list/set of collection ids
         :return:
         """
@@ -324,15 +336,18 @@ def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> Pr
             # TODO: only check for mismatch in major version?
             _log.warning(f"API mismatch: requested {api_version} != upstream {self.backends.api_version}")
 
-        combined_processes = self._memoizer.get_or_call(
-            key=("all", str(api_version)),
-            callback=self._get_merged_process_metadata,
-        )
+        combined_processes = self.get_merged_process_metadata()
         process_registry = ProcessRegistry()
         for pid, spec in combined_processes.items():
             process_registry.add_spec(spec=spec)
         return process_registry
 
+    def get_merged_process_metadata(self) -> Dict[str, dict]:
+        return self._memoizer.get_or_call(
+            key=("all", str(self.backends.api_version)),
+            callback=self._get_merged_process_metadata,
+        )
+
     def _get_merged_process_metadata(self) -> dict:
         processes_per_backend = {}
         for con in self.backends:
@@ -347,6 +362,30 @@ def _get_merged_process_metadata(self) -> dict:
         )
         return combined_processes
 
+    def _get_backend_candidates_for_processes(
+        self, processes: typing.Collection[str]
+    ) -> Union[Set[str], None]:
+        """
+        Get backend ids providing all given processes
+        :param processes: collection of process ids
+        :return:
+        """
+        processes = set(processes)
+        process_metadata = self.get_merged_process_metadata()
+        candidates: Union[Set[str], None] = None
+        for pid in processes:
+            if pid in process_metadata:
+                backends = process_metadata[pid][STAC_PROPERTY_FEDERATION_BACKENDS]
+                if candidates is None:
+                    candidates = set(backends)
+                else:
+                    candidates = candidates.intersection(backends)
+            else:
+                _log.warning(
+                    f"Skipping unknown process {pid!r} in `_get_backend_candidates_for_processes`"
+                )
+        return candidates
+
     def get_backend_for_process_graph(
         self, process_graph: dict, api_version: str, job_options: Optional[dict] = None
     ) -> str:
@@ -367,12 +406,13 @@ def get_backend_for_process_graph(
             )
             return bid
 
-        # TODO: also check used processes?
         collections = set()
         collection_backend_constraints = []
+        processes = set()
         try:
             for pg_node in process_graph.values():
                 process_id = pg_node["process_id"]
+                processes.add(process_id)
                 arguments = pg_node["arguments"]
                 if process_id == "load_collection":
                     collections.add(arguments["id"])
@@ -414,13 +454,21 @@ def get_backend_for_process_graph(
             collection_candidates = self._catalog.get_backend_candidates_for_collections(collections=collections)
             backend_candidates = [b for b in backend_candidates if b in collection_candidates]
 
-            if collection_backend_constraints:
-                conditions = self._catalog.generate_backend_constraint_callables(
-                    process_graphs=collection_backend_constraints
-                )
-                backend_candidates = [b for b in backend_candidates if all(c(b) for c in conditions)]
+            if collection_backend_constraints:
+                conditions = self._catalog.generate_backend_constraint_callables(
+                    process_graphs=collection_backend_constraints
+                )
+                backend_candidates = [
+                    b for b in backend_candidates if all(c(b) for c in conditions)
+                ]
+
+            process_candidates = self._get_backend_candidates_for_processes(processes)
+            if process_candidates is not None:
+                backend_candidates = [
+                    b for b in backend_candidates if b in process_candidates
+                ]
 
-        if len(backend_candidates) > 1:
+        if len(backend_candidates) > 1:
             # TODO #42 Check `/validation` instead of naively picking first one?
             _log.warning(
                 f"Multiple back-end candidates {backend_candidates} for collections {collections}."
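The heart of this change is the candidate intersection in `_get_backend_candidates_for_processes`: each known process narrows the set of usable back-ends, while unknown processes are logged and skipped rather than eliminating everything. Below is a minimal standalone sketch of that logic; the `MERGED_PROCESS_METADATA` dict and the function name are illustrative stand-ins, not the aggregator's actual API (the real mapping comes from the `federation:backends` property in the merged process metadata).

from typing import Collection, Optional, Set

# Hypothetical stand-in for the aggregator's merged process metadata:
# process id -> back-end ids advertising it.
MERGED_PROCESS_METADATA = {
    "load_collection": ["b1", "b2"],
    "blargh": ["b1", "b2"],
    "wibble": ["b1"],
    "snorfle": ["b2"],
}


def backend_candidates_for_processes(processes: Collection[str]) -> Optional[Set[str]]:
    """Intersect back-end sets of all known processes; None means 'no constraint'."""
    candidates: Optional[Set[str]] = None
    for pid in set(processes):
        backends = MERGED_PROCESS_METADATA.get(pid)
        if backends is None:
            continue  # unknown process: logged and skipped in the real implementation
        candidates = set(backends) if candidates is None else candidates & set(backends)
    return candidates


assert backend_candidates_for_processes(["load_collection", "wibble"]) == {"b1"}
assert backend_candidates_for_processes(["load_collection", "snorfle"]) == {"b2"}
assert backend_candidates_for_processes(["load_collection", "blargh"]) == {"b1", "b2"}
assert backend_candidates_for_processes(["frobnicate"]) is None  # nothing known: no constraint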
diff --git a/src/openeo_aggregator/testing.py b/src/openeo_aggregator/testing.py
index 0ba826a6..4e8ff95b 100644
--- a/src/openeo_aggregator/testing.py
+++ b/src/openeo_aggregator/testing.py
@@ -203,3 +203,40 @@ def build_capabilities(
     )
 
     return capabilities
+
+
+class MetadataBuilder:
+    """Helper for building openEO/STAC-style metadata dictionaries"""
+
+    def collection(self, id="S2", *, license="proprietary") -> dict:
+        """Build collection metadata"""
+        return {
+            "id": id,
+            "license": license,
+            "stac_version": "1.0.0",
+            "description": id,
+            "extent": {
+                "spatial": {"bbox": [[2, 50, 5, 55]]},
+                "temporal": {"interval": [["2017-01-01T00:00:00Z", None]]},
+            },
+            "links": [
+                {
+                    "rel": "license",
+                    "href": "https://oeoa.test/licence",
+                }
+            ],
+        }
+
+    def collections(self, *args) -> dict:
+        """Build `GET /collections` metadata"""
+        collections = []
+        for arg in args:
+            if isinstance(arg, str):
+                collection = self.collection(id=arg)
+            elif isinstance(arg, dict):
+                collection = self.collection(**arg)
+            else:
+                raise ValueError(arg)
+            collections.append(collection)
+
+        return {"collections": collections, "links": []}
diff --git a/tests/conftest.py b/tests/conftest.py
index 9401f26e..cf5d38a6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,7 +12,11 @@
     MultiBackendConnection,
 )
 from openeo_aggregator.config import AggregatorConfig
-from openeo_aggregator.testing import DummyKazooClient, build_capabilities
+from openeo_aggregator.testing import (
+    DummyKazooClient,
+    MetadataBuilder,
+    build_capabilities,
+)
 
 
 @pytest.fixture
@@ -179,3 +183,8 @@ def catalog(multi_backend_connection, config) -> AggregatorCollectionCatalog:
 
         backends=multi_backend_connection, config=config
     )
+
+
+@pytest.fixture
+def bldr() -> MetadataBuilder:
+    return MetadataBuilder()
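For context on how tests consume the new helper: `MetadataBuilder.collections()` accepts plain collection ids as well as dicts of keyword arguments, so a test only spells out the fields it cares about. A small usage sketch against the `MetadataBuilder` added above (the `CC-BY-4.0` license value is just an example):

from openeo_aggregator.testing import MetadataBuilder

bldr = MetadataBuilder()

# A plain string becomes a collection id with default fields filled in.
assert bldr.collections("S2")["collections"][0]["id"] == "S2"

# A dict is forwarded as keyword arguments to `collection()`.
md = bldr.collections("S2", {"id": "S3", "license": "CC-BY-4.0"})
assert md["collections"][1]["license"] == "CC-BY-4.0"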
diff --git a/tests/test_views.py b/tests/test_views.py
index dd9e8acd..9cb114d7 100644
--- a/tests/test_views.py
+++ b/tests/test_views.py
@@ -986,6 +986,85 @@ def b1_post_result(request: requests.Request, context):
         assert res.json == 111
         assert (b1_mock.call_count, b2_mock.call_count) == (1, 0)
 
+
+    @pytest.mark.parametrize(
+        ["process_id", "call_counts", "expected_warnings"],
+        [
+            (
+                "blargh",
+                (1, 0),
+                [
+                    RegexMatcher(
+                        "Multiple back-end candidates.*Naively picking first one"
+                    )
+                ],
+            ),
+            ("wibble", (1, 0), []),
+            ("snorfle", (0, 1), []),
+            (
+                "frobnicate",
+                (1, 0),
+                [
+                    RegexMatcher("Skipping unknown process 'frobnicate'"),
+                    RegexMatcher(
+                        "Multiple back-end candidates.*Naively picking first one"
+                    ),
+                ],
+            ),
+        ],
+    )
+    def test_result_backend_by_process(
+        self,
+        api100,
+        requests_mock,
+        backend1,
+        backend2,
+        process_id,
+        call_counts,
+        caplog,
+        bldr,
+        expected_warnings,
+    ):
+        requests_mock.get(backend1 + "/collections", json=bldr.collections("S2"))
+        common_processes = [{"id": "load_collection"}, {"id": "blargh"}]
+        requests_mock.get(
+            backend1 + "/processes",
+            json={"processes": common_processes + [{"id": "wibble"}]},
+        )
+        requests_mock.get(backend2 + "/collections", json=bldr.collections("S2"))
+        requests_mock.get(
+            backend2 + "/processes",
+            json={"processes": common_processes + [{"id": "snorfle"}]},
+        )
+
+        def post_result(request: requests.Request, context):
+            assert (
+                request.headers["Authorization"]
+                == TEST_USER_AUTH_HEADER["Authorization"]
+            )
+            assert request.json()["process"]["process_graph"] == pg
+            context.headers["Content-Type"] = "application/json"
+            return 123
+
+        b1_mock = requests_mock.post(backend1 + "/result", json=post_result)
+        b2_mock = requests_mock.post(backend2 + "/result", json=post_result)
+        api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
+        pg = {
+            "lc": {
+                "process_id": "load_collection",
+                "arguments": {"id": "S2"},
+            },
+            "process": {
+                "process_id": process_id,
+                "arguments": {"data": {"from_node": "lc"}, "factor": 5},
+                "result": True,
+            },
+        }
+        request = {"process": {"process_graph": pg}}
+        res = api100.post("/result", json=request).assert_status_code(200)
+        assert res.json == 123
+        assert (b1_mock.call_count, b2_mock.call_count) == call_counts
+        assert caplog.messages == expected_warnings
 
 
 class TestBatchJobs:
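The final `assert caplog.messages == expected_warnings` works because each expected entry is a matcher whose equality check runs a regex against the captured log message. A simplified, self-contained analogue of that pattern is sketched below; `RegexMatcherSketch` is a hypothetical stand-in, not the project's actual `RegexMatcher` implementation.

import re


class RegexMatcherSketch:
    """Equality-by-regex helper, analogous to the RegexMatcher used above."""

    def __init__(self, pattern: str):
        self._regex = re.compile(pattern)

    def __eq__(self, other):
        return isinstance(other, str) and bool(self._regex.search(other))

    def __repr__(self):
        return f"RegexMatcherSketch({self._regex.pattern!r})"


# List equality compares element-wise; `str == matcher` falls back to the
# matcher's __eq__, so plain log messages can be compared against patterns.
assert ["Skipping unknown process 'frobnicate' in `_get_backend_candidates_for_processes`"] == [
    RegexMatcherSketch("Skipping unknown process")
]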