Skip to content

Commit

Permalink
Merge branch 'master' into 5-federation-of-collections
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Sep 16, 2022
2 parents 69a3225 + f334bcb commit c969aff
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 178 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make sure user id (prefix) is logged in JSON logs
- Updated (generous fallback) "FreeTier" user role to 30DayTrial (more strict)
- Use EODC dev instance in aggregator dev config
- Update EGI issuer URL to new Keycloak one (keep old provider under "egi-old")
- Update EGI issuer URL to new Keycloak one (keep old provider under "egi-legacy")

### Fixed

- Properly rewrite model id in `load_ml_model` ([#70](https://github.com/Open-EO/openeo-aggregator/issues/70))

## [0.4.x]

### Added
Expand Down
57 changes: 25 additions & 32 deletions conf/aggregator.dev.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from openeo_aggregator.config import AggregatorConfig
from openeo_driver.users.oidc import OidcProvider

DEFAULT_OIDC_CLIENT_EGI = {
_DEFAULT_OIDC_CLIENT_EGI = {
"id": "openeo-platform-default-client",
"grant_types": [
"authorization_code+pkce",
Expand All @@ -14,43 +14,36 @@
"https://editor.openeo.org",
]
}

_DEFAULT_EGI_SCOPES = [
"openid",
"email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
]

configured_oidc_providers = [
OidcProvider(
id="egi",
issuer="https://aai.egi.eu/auth/realms/egi/",
scopes=[
"openid", "email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
],
title="EGI Check-in",
default_client=DEFAULT_OIDC_CLIENT_EGI, # TODO: remove this legacy experimental field
default_clients=[DEFAULT_OIDC_CLIENT_EGI],
issuer="https://aai.egi.eu/auth/realms/egi/",
scopes=_DEFAULT_EGI_SCOPES,
default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
),
OidcProvider(
id="egi-old",
issuer="https://aai.egi.eu/oidc/", # TODO: remove old EGI provider refs (issuer https://aai.egi.eu/oidc/)
scopes=[
"openid", "email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
],
id="egi-legacy",
title="EGI Check-in (legacy)",
default_client=DEFAULT_OIDC_CLIENT_EGI, # TODO: remove this legacy experimental field
default_clients=[DEFAULT_OIDC_CLIENT_EGI],
issuer="https://aai.egi.eu/oidc/", # TODO: remove old EGI provider refs (issuer https://aai.egi.eu/oidc/)
scopes=_DEFAULT_EGI_SCOPES,
default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
),
OidcProvider(
id="egi-dev",
title="EGI Check-in (dev)",
issuer="https://aai-dev.egi.eu/auth/realms/egi/",
scopes=_DEFAULT_EGI_SCOPES,
default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
),
# OidcProvider(
# id="egi-dev",
# issuer="https://aai-dev.egi.eu/oidc/",
# scopes=[
# "openid", "email",
# "eduperson_entitlement",
# "eduperson_scoped_affiliation",
# ],
# title="EGI Check-in (dev)",
# default_client=_DEFAULT_OIDC_CLIENT_EGI, # TODO: remove this legacy experimental field
# default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
# ),
]

config = AggregatorConfig(
Expand All @@ -59,9 +52,9 @@
"vito": "https://openeo-dev.vito.be/openeo/1.0/",
"eodc": "https://openeo-dev.eodc.eu/v1.0/",
# internal version of https://openeo.creo.vito.be/openeo/1.0/
"creo": "https://openeo.creo.vgt.vito.be/openeo/1.0",
"creo": "https://openeo-dev.creo.vito.be/openeo/1.0",
# Sentinel Hub OpenEO by Sinergise
"sentinelhub": "https://w0j9yieg9l.execute-api.eu-central-1.amazonaws.com/testing",
"sentinelhub": "https://openeo.sentinel-hub.com/production/",
},
auth_entitlement_check={"oidc_issuer_whitelist": {
"https://aai.egi.eu/auth/realms/egi/",
Expand Down
33 changes: 15 additions & 18 deletions conf/aggregator.prod.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from openeo_aggregator.config import AggregatorConfig
from openeo_driver.users.oidc import OidcProvider

DEFAULT_OIDC_CLIENT_EGI = {
_DEFAULT_OIDC_CLIENT_EGI = {
"id": "openeo-platform-default-client",
"grant_types": [
"authorization_code+pkce",
Expand All @@ -15,30 +15,27 @@
]
}

_DEFAULT_EGI_SCOPES = [
"openid",
"email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
]

configured_oidc_providers = [
OidcProvider(
id="egi",
issuer="https://aai.egi.eu/auth/realms/egi/",
scopes=[
"openid", "email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
],
title="EGI Check-in",
default_client=DEFAULT_OIDC_CLIENT_EGI, # TODO: remove this legacy experimental field
default_clients=[DEFAULT_OIDC_CLIENT_EGI],
issuer="https://aai.egi.eu/auth/realms/egi/",
scopes=_DEFAULT_EGI_SCOPES,
default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
),
OidcProvider(
id="egi-old",
issuer="https://aai.egi.eu/oidc/", # TODO: remove old EGI provider refs (issuer https://aai.egi.eu/oidc/)
scopes=[
"openid", "email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
],
id="egi-legacy",
title="EGI Check-in (legacy)",
default_client=DEFAULT_OIDC_CLIENT_EGI, # TODO: remove this legacy experimental field
default_clients=[DEFAULT_OIDC_CLIENT_EGI],
issuer="https://aai.egi.eu/oidc/", # TODO: remove old EGI provider refs (issuer https://aai.egi.eu/oidc/)
scopes=_DEFAULT_EGI_SCOPES,
default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
),
]

Expand Down
40 changes: 32 additions & 8 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from openeo_aggregator.partitionedjobs import PartitionedJob
from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter
from openeo_aggregator.partitionedjobs.tracking import PartitionedJobConnection, PartitionedJobTracker
from openeo_aggregator.utils import TtlCache, MultiDictGetter, subdict, dict_merge
from openeo_aggregator.utils import TtlCache, MultiDictGetter, subdict, dict_merge, normalize_issuer_url
from openeo_driver.ProcessGraphDeserializer import SimpleProcessing
from openeo_driver.backend import OpenEoBackendImplementation, AbstractCollectionCatalog, LoadParameters, Processing, \
OidcProvider, BatchJobs, BatchJobMetadata
Expand Down Expand Up @@ -432,6 +432,10 @@ def get_backend_for_process_graph(self, process_graph: dict, api_version: str) -
aggregator_job_id=arguments["id"]
)
backend_candidates = [b for b in backend_candidates if b == job_backend_id]
elif process_id == "load_ml_model":
model_backend_id = self._process_load_ml_model(arguments)[0]
if model_backend_id:
backend_candidates = [b for b in backend_candidates if b == model_backend_id]
except Exception as e:
_log.error(f"Failed to parse process graph: {e!r}", exc_info=True)
raise ProcessGraphInvalidException()
Expand Down Expand Up @@ -497,13 +501,33 @@ def preprocess(node: Any) -> Any:
assert job_backend_id == backend_id, f"{job_backend_id} != {backend_id}"
# Create new load_result node dict with updated job id
return dict_merge(node, arguments=dict_merge(arguments, id=job_id))
if process_id == "load_ml_model":
model_id = self._process_load_ml_model(arguments, expected_backend=backend_id)[1]
if model_id:
return dict_merge(node, arguments=dict_merge(arguments, id=model_id))
return {k: preprocess(v) for k, v in node.items()}
elif isinstance(node, list):
return [preprocess(x) for x in node]
return node

return preprocess(process_graph)

def _process_load_ml_model(
self, arguments: dict, expected_backend: Optional[str] = None
) -> Tuple[Union[str, None], str]:
"""Handle load_ml_model: detect/strip backend_id from model_id if it is a job_id"""
model_id = arguments.get("id")
if model_id and not model_id.startswith("http"):
# TODO: load_ml_model's `id` could also be file path (see https://github.com/Open-EO/openeo-processes/issues/384)
job_id, job_backend_id = JobIdMapping.parse_aggregator_job_id(
backends=self.backends,
aggregator_job_id=model_id
)
if expected_backend and job_backend_id != expected_backend:
raise BackendLookupFailureException(f"{job_backend_id} != {expected_backend}")
return job_backend_id, job_id
return None, model_id


class AggregatorBatchJobs(BatchJobs):

Expand Down Expand Up @@ -735,6 +759,7 @@ def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
)
self._cache = TtlCache(default_ttl=CACHE_TTL_DEFAULT, name="General")
self._backends.on_connections_change.add(self._cache.flush_all)
self._configured_oidc_providers: List[OidcProvider] = config.configured_oidc_providers
self._auth_entitlement_check: Union[bool, dict] = config.auth_entitlement_check

# Shorter HTTP cache TTL to adapt quicker to changed back-end configurations
Expand All @@ -743,11 +768,7 @@ def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
)

def oidc_providers(self) -> List[OidcProvider]:
# TODO: openeo-python-driver (HttpAuthHandler) currently does support changes in
# the set of oidc_providers ids (id mapping is statically established at startup time)
return self._cache.get_or_call(
key="oidc_providers", callback=self._backends.get_oidc_providers, log_on_miss=True,
)
return self._configured_oidc_providers

def file_formats(self) -> dict:
return self._cache.get_or_call(key="file_formats", callback=self._file_formats, log_on_miss=True)
Expand Down Expand Up @@ -777,10 +798,13 @@ def merge(formats: dict, to_add: dict):
def user_access_validation(self, user: User, request: flask.Request) -> User:
if self._auth_entitlement_check:
int_data = user.internal_auth_data
issuer_whitelist = self._auth_entitlement_check.get("oidc_issuer_whitelist", [])
issuer_whitelist = [
normalize_issuer_url(u)
for u in self._auth_entitlement_check.get("oidc_issuer_whitelist", [])
]
if not (
int_data["authentication_method"] == "OIDC"
and int_data["oidc_issuer"].rstrip("/").lower() in issuer_whitelist
and normalize_issuer_url(int_data["oidc_issuer"]) in issuer_whitelist
):
user_message = "An EGI account is required for using openEO Platform."
_log.warning(f"user_access_validation failure: %r %r", user_message, {
Expand Down
37 changes: 11 additions & 26 deletions src/openeo_aggregator/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def __init__(

self.default_timeout = default_timeout

def __repr__(self):
return f"<{type(self).__name__} {self.id}: {self._root_url}>"

def _get_auth(self) -> Union[None, OpenEoApiAuthBase]:
return None if self._auth_locked else self._auth

Expand All @@ -92,8 +95,7 @@ def _build_oidc_provider_map(self, configured_providers: List[OidcProvider]) ->
pid_map[agg_provider.id] = targets[0]
return pid_map

@property
def oidc_provider_map(self) -> Dict[str, str]:
def get_oidc_provider_map(self) -> Dict[str, str]:
return self._oidc_provider_map

def _get_bearer(self, request: flask.Request) -> str:
Expand All @@ -106,9 +108,13 @@ def _get_bearer(self, request: flask.Request) -> str:
return auth.partition("Bearer ")[2]
elif auth.startswith("Bearer oidc/"):
_, pid, token = auth.split("/")
if pid not in self._oidc_provider_map:
_log.warning(f"OIDC provider mapping failure: {pid} not in {self._oidc_provider_map}.")
backend_pid = self._oidc_provider_map.get(pid, pid)
try:
backend_pid = self._oidc_provider_map[pid]
except KeyError:
_log.error(f"Back-end {self} lacks OIDC provider support: {pid!r} not in {self._oidc_provider_map}.")
raise OpenEOApiException(
code="OidcSupportError", message=f"Back-end {self.id!r} does not support OIDC provider {pid!r}."
)
return f"oidc/{backend_pid}/{token}"
else:
raise AuthenticationSchemeInvalidException
Expand Down Expand Up @@ -289,27 +295,6 @@ def map(self, callback: Callable[[BackendConnection], Any]) -> Iterator[Tuple[st
# TODO: customizable exception handling: skip, warn, re-raise?
yield con.id, res

def get_oidc_providers(self) -> List[OidcProvider]:
"""
Determine OIDC providers to use in aggregator (based on OIDC issuers supported by all backends)
and set up provider id mapping in the backend connections
:param configured_providers: OIDC providers dedicated/configured for the aggregator
:return: list of actual OIDC providers to use (configured for aggregator and supported by all backends)
"""
# Get intersection of aggregator OIDC provider ids
agg_pids_per_backend = [set(c.oidc_provider_map.keys()) for c in self.get_connections()]
intersection: Set[str] = functools.reduce((lambda x, y: x.intersection(y)), agg_pids_per_backend)
_log.debug(f"OIDC provider intersection: {intersection}")
if len(intersection) == 0:
_log.error(f"Emtpy OIDC provider intersection. Issuers per backend: {agg_pids_per_backend}")

# Take configured providers for common issuers.
agg_providers = [p for p in self._configured_oidc_providers if p.id in intersection]
_log.info(f"Actual aggregator OIDC providers: {agg_providers}")

return agg_providers


def streaming_flask_response(
backend_response: requests.Response,
Expand Down
4 changes: 4 additions & 0 deletions src/openeo_aggregator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,7 @@ def timestamp_to_rfc3339(timestamp: float) -> str:
"""Convert unix epoch timestamp to RFC3339 datetime string"""
dt = datetime.datetime.utcfromtimestamp(timestamp)
return rfc3339.datetime(dt)


def normalize_issuer_url(url: str) -> str:
return url.rstrip("/").lower()
13 changes: 11 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,18 @@ def backend2(requests_mock) -> str:


@pytest.fixture
def configured_oidc_providers() -> List[OidcProvider]:
def main_test_oidc_issuer() -> str:
"""
Main OIDC issuer URL.
As a fixture to make it overridable with `pytest.mark.parametrize` for certain tests.
"""
return "https://egi.test"


@pytest.fixture
def configured_oidc_providers(main_test_oidc_issuer: str) -> List[OidcProvider]:
return [
OidcProvider(id="egi", issuer="https://egi.test", title="EGI"),
OidcProvider(id="egi", issuer=main_test_oidc_issuer, title="EGI"),
OidcProvider(id="x-agg", issuer="https://x.test", title="X (agg)"),
OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)"),
OidcProvider(id="z-agg", issuer="https://z.test", title="Z (agg)"),
Expand Down
38 changes: 4 additions & 34 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,15 @@
class TestAggregatorBackendImplementation:

def test_oidc_providers(self, multi_backend_connection, config, backend1, backend2, requests_mock):
requests_mock.get(backend1 + "/credentials/oidc", json={"providers": [
{"id": "x", "issuer": "https://x.test", "title": "X"},
{"id": "y", "issuer": "https://y.test", "title": "YY"},
]})
requests_mock.get(backend2 + "/credentials/oidc", json={"providers": [
{"id": "y", "issuer": "https://y.test", "title": "YY"},
{"id": "z", "issuer": "https://z.test", "title": "ZZZ"},
]})
implementation = AggregatorBackendImplementation(backends=multi_backend_connection, config=config)
providers = implementation.oidc_providers()
assert providers == [
OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)")
OidcProvider(id='egi', issuer='https://egi.test', title='EGI'),
OidcProvider(id='x-agg', issuer='https://x.test', title='X (agg)'),
OidcProvider(id='y-agg', issuer='https://y.test', title='Y (agg)'),
OidcProvider(id='z-agg', issuer='https://z.test', title='Z (agg)'),
]

def test_oidc_providers_caching(self, multi_backend_connection, config, backend1, backend2, requests_mock):
m1 = requests_mock.get(backend1 + "/credentials/oidc", json={"providers": [
{"id": "x", "issuer": "https://x.test", "title": "X"},
{"id": "y", "issuer": "https://y.test", "title": "YY"},
]})
m2 = requests_mock.get(backend2 + "/credentials/oidc", json={"providers": [
{"id": "y", "issuer": "https://y.test", "title": "YY"},
{"id": "z", "issuer": "https://z.test", "title": "ZZZ"},
]})
implementation = AggregatorBackendImplementation(backends=multi_backend_connection, config=config)
assert (m1.call_count, m2.call_count) == (0, 0)
providers = implementation.oidc_providers()
assert providers == [OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)")]
assert (m1.call_count, m2.call_count) == (1, 1)
providers = implementation.oidc_providers()
assert providers == [OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)")]
assert (m1.call_count, m2.call_count) == (1, 1)

MultiBackendConnection._clock = itertools.count(time.time() + 1000).__next__
implementation._cache.flush_all()

providers = implementation.oidc_providers()
assert providers == [OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)")]
assert (m1.call_count, m2.call_count) == (2, 2)

def test_file_formats_simple(self, multi_backend_connection, config, backend1, backend2, requests_mock):
just_geotiff = {
"input": {"GTiff": {"gis_data_types": ["raster"], "parameters": {}, "title": "GeoTiff"}},
Expand Down
Loading

0 comments on commit c969aff

Please sign in to comment.