Skip to content

Commit

Permalink
Issue #100 add process based backend candidate selection
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Apr 4, 2023
1 parent a17bc23 commit f761060
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 14 deletions.
74 changes: 61 additions & 13 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,20 @@
import pathlib
import re
import time
import typing
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)

import flask
import openeo_driver.util.view_helpers
Expand Down Expand Up @@ -198,7 +210,7 @@ def evaluate(backend_id, pg):

def get_backend_candidates_for_collections(self, collections: Iterable[str]) -> List[str]:
"""
Get best backend id providing all given collections
Get backend ids providing all given collections
:param collections: list/set of collection ids
:return:
"""
Expand Down Expand Up @@ -324,15 +336,18 @@ def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> Pr
# TODO: only check for mismatch in major version?
_log.warning(f"API mismatch: requested {api_version} != upstream {self.backends.api_version}")

combined_processes = self._memoizer.get_or_call(
key=("all", str(api_version)),
callback=self._get_merged_process_metadata,
)
combined_processes = self.get_merged_process_metadata()
process_registry = ProcessRegistry()
for pid, spec in combined_processes.items():
process_registry.add_spec(spec=spec)
return process_registry

def get_merged_process_metadata(self) -> Dict[str, dict]:
    """
    Get the process metadata merged across all upstream back-ends,
    memoized per upstream API version.
    """
    cache_key = ("all", str(self.backends.api_version))
    return self._memoizer.get_or_call(
        key=cache_key,
        callback=self._get_merged_process_metadata,
    )

def _get_merged_process_metadata(self) -> dict:
processes_per_backend = {}
for con in self.backends:
Expand All @@ -347,6 +362,30 @@ def _get_merged_process_metadata(self) -> dict:
)
return combined_processes

def _get_backend_candidates_for_processes(
    self, processes: typing.Collection[str]
) -> Union[Set[str], None]:
    """
    Get backend ids providing all given processes.

    :param processes: collection of process ids
    :return: set of backend ids that support all of the *known* given
        processes, or ``None`` when no process constrained the selection
        (e.g. empty input, or only unknown processes).
    """
    processes = set(processes)
    process_metadata = self.get_merged_process_metadata()
    # `None` means "unconstrained", which is distinct from an empty set
    # (no backend supports the requested combination).
    candidates: Union[Set[str], None] = None
    for pid in processes:
        if pid in process_metadata:
            backends = process_metadata[pid][STAC_PROPERTY_FEDERATION_BACKENDS]
            if candidates is None:
                candidates = set(backends)
            else:
                candidates = candidates.intersection(backends)
        else:
            # Unknown processes don't constrain candidate selection;
            # just warn (validation elsewhere should catch invalid ids).
            _log.warning(
                f"Skipping unknown process {pid!r} in `_get_backend_candidates_for_processes`"
            )
    return candidates

def get_backend_for_process_graph(
self, process_graph: dict, api_version: str, job_options: Optional[dict] = None
) -> str:
Expand All @@ -367,12 +406,13 @@ def get_backend_for_process_graph(
)
return bid

# TODO: also check used processes?
collections = set()
collection_backend_constraints = []
processes = set()
try:
for pg_node in process_graph.values():
process_id = pg_node["process_id"]
processes.add(process_id)
arguments = pg_node["arguments"]
if process_id == "load_collection":
collections.add(arguments["id"])
Expand Down Expand Up @@ -414,13 +454,21 @@ def get_backend_for_process_graph(
collection_candidates = self._catalog.get_backend_candidates_for_collections(collections=collections)
backend_candidates = [b for b in backend_candidates if b in collection_candidates]

if collection_backend_constraints:
conditions = self._catalog.generate_backend_constraint_callables(
process_graphs=collection_backend_constraints
)
backend_candidates = [b for b in backend_candidates if all(c(b) for c in conditions)]
if collection_backend_constraints:
conditions = self._catalog.generate_backend_constraint_callables(
process_graphs=collection_backend_constraints
)
backend_candidates = [
b for b in backend_candidates if all(c(b) for c in conditions)
]

if processes:
process_candidates = self._get_backend_candidates_for_processes(processes)
backend_candidates = [
b for b in backend_candidates if b in process_candidates
]

if len(backend_candidates) > 1:
if len(backend_candidates) > 1:
# TODO #42 Check `/validation` instead of naively picking first one?
_log.warning(
f"Multiple back-end candidates {backend_candidates} for collections {collections}."
Expand Down
37 changes: 37 additions & 0 deletions src/openeo_aggregator/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,40 @@ def build_capabilities(
)

return capabilities


class MetadataBuilder:
    """Helper for building openEO/STAC-style metadata dictionaries"""

    def collection(self, id="S2", *, license="proprietary") -> dict:
        """Build metadata dict for a single collection."""
        spatial_extent = {"bbox": [[2, 50, 5, 55]]}
        temporal_extent = {"interval": [["2017-01-01T00:00:00Z", None]]}
        license_link = {
            "rel": "license",
            "href": "https://oeoa.test/licence",
        }
        return {
            "id": id,
            "license": license,
            "stac_version": "1.0.0",
            "description": id,
            "extent": {"spatial": spatial_extent, "temporal": temporal_extent},
            "links": [license_link],
        }

    def collections(self, *args) -> dict:
        """Build `GET /collections` metadata"""

        def build_one(arg) -> dict:
            # Accept a bare collection id (str) or kwargs for `collection()` (dict).
            if isinstance(arg, str):
                return self.collection(id=arg)
            if isinstance(arg, dict):
                return self.collection(**arg)
            raise ValueError(arg)

        return {"collections": [build_one(a) for a in args], "links": []}
11 changes: 10 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
MultiBackendConnection,
)
from openeo_aggregator.config import AggregatorConfig
from openeo_aggregator.testing import DummyKazooClient, build_capabilities
from openeo_aggregator.testing import (
DummyKazooClient,
MetadataBuilder,
build_capabilities,
)


@pytest.fixture
Expand Down Expand Up @@ -179,3 +183,8 @@ def catalog(multi_backend_connection, config) -> AggregatorCollectionCatalog:
backends=multi_backend_connection,
config=config
)


@pytest.fixture
def bldr() -> MetadataBuilder:
    """Fixture providing a fresh metadata builder per test."""
    builder = MetadataBuilder()
    return builder
79 changes: 79 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,85 @@ def b1_post_result(request: requests.Request, context):
assert res.json == 111
assert (b1_mock.call_count, b2_mock.call_count) == (1, 0)

# Parametrized over the process id used in the process graph:
# `call_counts` is the expected number of POST /result calls to
# (backend1, backend2), `expected_warnings` the expected log messages.
@pytest.mark.parametrize(
    ["process_id", "call_counts", "expected_warnings"],
    [
        (
            # "blargh" is advertised by both backends: multiple candidates,
            # so the aggregator warns and naively picks the first one.
            "blargh",
            (1, 0),
            [
                RegexMatcher(
                    "Multiple back-end candidates.*Naively picking first one"
                )
            ],
        ),
        # "wibble" is only advertised by backend1.
        ("wibble", (1, 0), []),
        # "snorfle" is only advertised by backend2.
        ("snorfle", (0, 1), []),
        (
            # "frobnicate" is unknown to both backends: it does not constrain
            # candidate selection, so both remain candidates (warning + pick first).
            "frobnicate",
            (1, 0),
            [
                RegexMatcher("Skipping unknown process 'frobnicate'"),
                RegexMatcher(
                    "Multiple back-end candidates.*Naively picking first one"
                ),
            ],
        ),
    ],
)
def test_result_backend_by_process(
    self,
    api100,
    requests_mock,
    backend1,
    backend2,
    process_id,
    call_counts,
    caplog,
    bldr,
    expected_warnings,
):
    """Check that synchronous processing is routed to a backend supporting the used processes."""
    # Both backends expose the same "S2" collection, ...
    requests_mock.get(backend1 + "/collections", json=bldr.collections("S2"))
    common_processes = [{"id": "load_collection"}, {"id": "blargh"}]
    # ... but differ in the extra process they support ("wibble" vs "snorfle").
    requests_mock.get(
        backend1 + "/processes",
        json={"processes": common_processes + [{"id": "wibble"}]},
    )
    requests_mock.get(backend2 + "/collections", json=bldr.collections("S2"))
    requests_mock.get(
        backend2 + "/processes",
        json={"processes": common_processes + [{"id": "snorfle"}]},
    )

    def post_result(request: requests.Request, context):
        # Verify the aggregator forwarded auth and the unchanged process graph.
        # NOTE: closes over `pg`, assigned below (before any request fires).
        assert (
            request.headers["Authorization"]
            == TEST_USER_AUTH_HEADER["Authorization"]
        )
        assert request.json()["process"]["process_graph"] == pg
        context.headers["Content-Type"] = "application/json"
        return 123

    # Mock the sync processing endpoint on both backends so we can count calls.
    b1_mock = requests_mock.post(backend1 + "/result", json=post_result)
    b2_mock = requests_mock.post(backend2 + "/result", json=post_result)
    api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
    pg = {
        "lc": {
            "process_id": "load_collection",
            "arguments": {"id": "S2"},
        },
        "process": {
            "process_id": process_id,
            "arguments": {"data": {"from_node": "lc"}, "factor": 5},
            "result": True,
        },
    }
    request = {"process": {"process_graph": pg}}
    res = api100.post("/result", json=request).assert_status_code(200)
    assert res.json == 123
    # Only the expected backend should have handled the request.
    assert (b1_mock.call_count, b2_mock.call_count) == call_counts
    assert caplog.messages == expected_warnings


class TestBatchJobs:

Expand Down

0 comments on commit f761060

Please sign in to comment.