Issue #100 add process based backend candidate selection
soxofaan committed Apr 4, 2023
1 parent a17bc23 commit eb0d5c1
Showing 6 changed files with 253 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ The format is roughly based on [Keep a Changelog](https://keepachangelog.com/en/

- Initial, experimental (client-side) proof of concept for cross-backend processing
([#95](https://github.com/Open-EO/openeo-aggregator/issues/95))
- Consider process availability in backend selection
([#100](https://github.com/Open-EO/openeo-aggregator/issues/100))

### Changed

2 changes: 1 addition & 1 deletion src/openeo_aggregator/about.py
@@ -1 +1 @@
__version__ = "0.8.0a1"
__version__ = "0.8.1a1"
79 changes: 66 additions & 13 deletions src/openeo_aggregator/backend.py
@@ -5,8 +5,20 @@
import pathlib
import re
import time
import typing
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)

import flask
import openeo_driver.util.view_helpers
@@ -198,7 +210,7 @@ def evaluate(backend_id, pg):

def get_backend_candidates_for_collections(self, collections: Iterable[str]) -> List[str]:
"""
Get best backend id providing all given collections
Get backend ids providing all given collections
:param collections: list/set of collection ids
:return:
"""
@@ -217,6 +229,7 @@ def get_backend_candidates_for_collections(self, collections: Iterable[str]) ->
elif len(backend_combos) == 1:
backend_candidates = list(backend_combos.pop())
else:
# TODO: order preservation is not necessary (anymore), which would allow simplifying all this logic.
# Search for common backends in all sets (and preserve order)
intersection = functools.reduce(lambda a, b: [x for x in a if x in b], backend_combos)
if intersection:
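For illustration, the order-preserving intersection above reduces the backend combos pairwise; a minimal standalone sketch with hypothetical backend ids (not part of this commit):

    import functools

    # Backends providing each requested collection (hypothetical).
    backend_combos = [["b1", "b2", "b3"], ["b2", "b3"], ["b3", "b2"]]

    # Keep only backends present in every combo, preserving the order of the first one.
    intersection = functools.reduce(
        lambda a, b: [x for x in a if x in b], backend_combos
    )
    print(intersection)  # ['b2', 'b3']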
@@ -324,15 +337,18 @@ def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> Pr
# TODO: only check for mismatch in major version?
_log.warning(f"API mismatch: requested {api_version} != upstream {self.backends.api_version}")

combined_processes = self._memoizer.get_or_call(
key=("all", str(api_version)),
callback=self._get_merged_process_metadata,
)
combined_processes = self.get_merged_process_metadata()
process_registry = ProcessRegistry()
for pid, spec in combined_processes.items():
process_registry.add_spec(spec=spec)
return process_registry

def get_merged_process_metadata(self) -> Dict[str, dict]:
return self._memoizer.get_or_call(
key=("all", str(self.backends.api_version)),
callback=self._get_merged_process_metadata,
)
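The memoizer call above computes the merged process metadata once per key and serves it from cache afterwards; a minimal dict-based sketch of the `get_or_call` semantics (a hypothetical stand-in, not the project's actual memoizer implementation):

    class DictMemoizer:
        # Toy cache illustrating get_or_call semantics.
        def __init__(self):
            self._cache = {}

        def get_or_call(self, key, callback):
            # Compute and store on first use; return the cached value afterwards.
            if key not in self._cache:
                self._cache[key] = callback()
            return self._cache[key]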

def _get_merged_process_metadata(self) -> dict:
processes_per_backend = {}
for con in self.backends:
@@ -347,6 +363,30 @@ def _get_merged_process_metadata(self) -> dict:
)
return combined_processes

def _get_backend_candidates_for_processes(
self, processes: typing.Collection[str]
) -> Union[Set[str], None]:
"""
Get backend ids providing all given processes
:param processes: collection of process ids
:return:
"""
processes = set(processes)
process_metadata = self.get_merged_process_metadata()
candidates: Union[Set[str], None] = None
for pid in processes:
if pid in process_metadata:
backends = process_metadata[pid][STAC_PROPERTY_FEDERATION_BACKENDS]
if candidates is None:
candidates = set(backends)
else:
candidates = candidates.intersection(backends)
else:
_log.warning(
f"Skipping unknown process {pid!r} in `_get_backend_candidates_for_processes`"
)
return candidates
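A minimal sketch of the intersection logic above, with hypothetical merged process metadata (assuming STAC_PROPERTY_FEDERATION_BACKENDS is the "federation:backends" property, spelled out literally here):

    # Hypothetical merged metadata: process id -> spec listing providing backends.
    process_metadata = {
        "load_collection": {"federation:backends": ["b1", "b2"]},
        "wibble": {"federation:backends": ["b1"]},
    }

    candidates = None
    for pid in ["load_collection", "wibble"]:
        backends = process_metadata[pid]["federation:backends"]
        candidates = set(backends) if candidates is None else candidates & set(backends)
    print(candidates)  # {'b1'}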

def get_backend_for_process_graph(
self, process_graph: dict, api_version: str, job_options: Optional[dict] = None
) -> str:
@@ -367,12 +407,13 @@ def get_backend_for_process_graph(
)
return bid

# TODO: also check used processes?
collections = set()
collection_backend_constraints = []
processes = set()
try:
for pg_node in process_graph.values():
process_id = pg_node["process_id"]
processes.add(process_id)
arguments = pg_node["arguments"]
if process_id == "load_collection":
collections.add(arguments["id"])
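As a rough sketch of what this loop collects from a process graph (hypothetical graph, mirroring the shape used in the tests below):

    pg = {
        "lc": {"process_id": "load_collection", "arguments": {"id": "S2"}},
        "sr": {
            "process_id": "save_result",
            "arguments": {"data": {"from_node": "lc"}, "format": "GTiff"},
            "result": True,
        },
    }
    processes = {node["process_id"] for node in pg.values()}
    collections = {
        node["arguments"]["id"]
        for node in pg.values()
        if node["process_id"] == "load_collection"
    }
    print(processes)    # {'load_collection', 'save_result'} (set order may vary)
    print(collections)  # {'S2'}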
@@ -414,13 +455,25 @@
collection_candidates = self._catalog.get_backend_candidates_for_collections(collections=collections)
backend_candidates = [b for b in backend_candidates if b in collection_candidates]

if collection_backend_constraints:
conditions = self._catalog.generate_backend_constraint_callables(
process_graphs=collection_backend_constraints
)
backend_candidates = [b for b in backend_candidates if all(c(b) for c in conditions)]
if collection_backend_constraints:
conditions = self._catalog.generate_backend_constraint_callables(
process_graphs=collection_backend_constraints
)
backend_candidates = [
b for b in backend_candidates if all(c(b) for c in conditions)
]

if processes:
process_candidates = self._get_backend_candidates_for_processes(processes)
if process_candidates:
backend_candidates = [
b for b in backend_candidates if b in process_candidates
]
else:
# TODO: make this an exception like we do with collections? (BackendLookupFailureException)
_log.warning(f"No process based backend candidates ({processes=})")

if len(backend_candidates) > 1:
# TODO #42 Check `/validation` instead of naively picking first one?
_log.warning(
f"Multiple back-end candidates {backend_candidates} for collections {collections}."
69 changes: 68 additions & 1 deletion src/openeo_aggregator/testing.py
@@ -3,7 +3,7 @@
import json
import pathlib
import time
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union
from unittest import mock

import kazoo
@@ -203,3 +203,70 @@ def build_capabilities(
)

return capabilities


class MetadataBuilder:
"""Helper for building openEO/STAC-style metadata dictionaries"""

def collection(self, id="S2", *, license="proprietary") -> dict:
"""Build collection metadata"""
return {
"id": id,
"license": license,
"stac_version": "1.0.0",
"description": id,
"extent": {
"spatial": {"bbox": [[2, 50, 5, 55]]},
"temporal": {"interval": [["2017-01-01T00:00:00Z", None]]},
},
"links": [
{
"rel": "license",
"href": "https://oeoa.test/licence",
}
],
}

def collections(self, *args) -> dict:
"""Build `GET /collections` metadata"""
collections = []
for arg in args:
if isinstance(arg, str):
collection = self.collection(id=arg)
elif isinstance(arg, dict):
collection = self.collection(**arg)
else:
raise ValueError(arg)
collections.append(collection)

return {"collections": collections, "links": []}

def process(
self,
id,
*,
parameters: Optional[List[dict]] = None,
returns: Optional[dict] = None,
) -> dict:
"""Build process metadata"""
return {
"id": id,
"description": id,
"parameters": parameters or [],
"returns": returns
or {"schema": {"type": "object", "subtype": "raster-cube"}},
}

def processes(self, *args) -> dict:
"""Build `GET /processes` metadata"""
processes = []
for arg in args:
if isinstance(arg, str):
process = self.process(id=arg)
elif isinstance(arg, dict):
process = self.process(**arg)
else:
raise ValueError(arg)
processes.append(process)

return {"processes": processes, "links": []}
46 changes: 37 additions & 9 deletions tests/conftest.py
@@ -12,27 +12,50 @@
MultiBackendConnection,
)
from openeo_aggregator.config import AggregatorConfig
from openeo_aggregator.testing import DummyKazooClient, build_capabilities
from openeo_aggregator.testing import (
DummyKazooClient,
MetadataBuilder,
build_capabilities,
)

_DEFAULT_PROCESSES = [
"load_collection",
"load_result",
"save_result",
"merge_cubes",
"mask",
"load_ml_model",
"add",
"large",
]


@pytest.fixture
def backend1(requests_mock) -> str:
def backend1(requests_mock, bldr) -> str:
domain = "https://b1.test/v1"
# TODO: how to work with different API versions?
requests_mock.get(domain + "/", json=build_capabilities())
requests_mock.get(domain + "/credentials/oidc", json={"providers": [
{"id": "egi", "issuer": "https://egi.test", "title": "EGI"}
]})
requests_mock.get(
domain + "/credentials/oidc",
json={
"providers": [{"id": "egi", "issuer": "https://egi.test", "title": "EGI"}]
},
)
requests_mock.get(domain + "/processes", json=bldr.processes(*_DEFAULT_PROCESSES))
return domain


@pytest.fixture
def backend2(requests_mock) -> str:
def backend2(requests_mock, bldr) -> str:
domain = "https://b2.test/v1"
requests_mock.get(domain + "/", json=build_capabilities())
requests_mock.get(domain + "/credentials/oidc", json={"providers": [
{"id": "egi", "issuer": "https://egi.test", "title": "EGI"}
]})
requests_mock.get(
domain + "/credentials/oidc",
json={
"providers": [{"id": "egi", "issuer": "https://egi.test", "title": "EGI"}]
},
)
requests_mock.get(domain + "/processes", json=bldr.processes(*_DEFAULT_PROCESSES))
return domain


@@ -179,3 +202,8 @@ def catalog(multi_backend_connection, config) -> AggregatorCollectionCatalog:
backends=multi_backend_connection,
config=config
)


@pytest.fixture
def bldr() -> MetadataBuilder:
return MetadataBuilder()
79 changes: 79 additions & 0 deletions tests/test_views.py
@@ -986,6 +986,85 @@ def b1_post_result(request: requests.Request, context):
assert res.json == 111
assert (b1_mock.call_count, b2_mock.call_count) == (1, 0)

@pytest.mark.parametrize(
["process_id", "call_counts", "expected_warnings"],
[
(
"blargh",
(1, 0),
[
RegexMatcher(
"Multiple back-end candidates.*Naively picking first one"
)
],
),
("wibble", (1, 0), []),
("snorfle", (0, 1), []),
(
"frobnicate",
(1, 0),
[
RegexMatcher("Skipping unknown process 'frobnicate'"),
RegexMatcher(
"Multiple back-end candidates.*Naively picking first one"
),
],
),
],
)
def test_result_backend_by_process(
self,
api100,
requests_mock,
backend1,
backend2,
process_id,
call_counts,
caplog,
bldr,
expected_warnings,
):
requests_mock.get(backend1 + "/collections", json=bldr.collections("S2"))
common_processes = [{"id": "load_collection"}, {"id": "blargh"}]
requests_mock.get(
backend1 + "/processes",
json={"processes": common_processes + [{"id": "wibble"}]},
)
requests_mock.get(backend2 + "/collections", json=bldr.collections("S2"))
requests_mock.get(
backend2 + "/processes",
json={"processes": common_processes + [{"id": "snorfle"}]},
)

def post_result(request: requests.Request, context):
assert (
request.headers["Authorization"]
== TEST_USER_AUTH_HEADER["Authorization"]
)
assert request.json()["process"]["process_graph"] == pg
context.headers["Content-Type"] = "application/json"
return 123

b1_mock = requests_mock.post(backend1 + "/result", json=post_result)
b2_mock = requests_mock.post(backend2 + "/result", json=post_result)
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
pg = {
"lc": {
"process_id": "load_collection",
"arguments": {"id": "S2"},
},
"process": {
"process_id": process_id,
"arguments": {"data": {"from_node": "lc"}, "factor": 5},
"result": True,
},
}
request = {"process": {"process_graph": pg}}
res = api100.post("/result", json=request).assert_status_code(200)
assert res.json == 123
assert (b1_mock.call_count, b2_mock.call_count) == call_counts
assert caplog.messages == expected_warnings


class TestBatchJobs:
