diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cbdc6cf..2e0adde5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,13 @@ and start a new "In Progress" section above it. ## In progress + +## 0.72.0 + - Start returning "OpenEO-Costs-experimental" header on synchronous processing responses +- Extract client credentials access token fetch logic from ElasticJobRegistry + into `ClientCredentialsAccessTokenHelper` to make it reusable (e.g. for ETL API as well) + ([Open-EO/openeo-geopyspark-driver#531](https://github.com/Open-EO/openeo-geopyspark-driver/issues/531)) ## 0.71.0 diff --git a/openeo_driver/_version.py b/openeo_driver/_version.py index 5c0aeeb2..e6bc0395 100644 --- a/openeo_driver/_version.py +++ b/openeo_driver/_version.py @@ -1 +1 @@ -__version__ = "0.71.0a1" +__version__ = "0.72.0a1" diff --git a/openeo_driver/jobregistry.py b/openeo_driver/jobregistry.py index 50662bd7..85b43d9e 100644 --- a/openeo_driver/jobregistry.py +++ b/openeo_driver/jobregistry.py @@ -1,6 +1,5 @@ import argparse import contextlib -import datetime as dt import json import logging import os @@ -9,17 +8,15 @@ import time import typing from decimal import Decimal -from typing import Any, Dict, List, NamedTuple, Optional, Union, Sequence +from typing import Any, Dict, List, Optional, Sequence, Union import requests -from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo from openeo.rest.connection import url_join from openeo.util import TimingLogger, repr_truncate, rfc3339 import openeo_driver._version -from openeo_driver.datastructs import secretive_repr from openeo_driver.errors import InternalException, JobNotFoundException -from openeo_driver.util.caching import TtlCache +from openeo_driver.util.auth import ClientCredentials, ClientCredentialsAccessTokenHelper from openeo_driver.util.logging import just_log_exceptions from openeo_driver.utils import generate_unique_id @@ -167,41 +164,35 @@ def from_response(cls, response: 
requests.Response) -> "EjrHttpError": ) -class ElasticJobRegistryCredentials(NamedTuple): - """Container of Elastic Job Registry related credentials.""" - - oidc_issuer: str - client_id: str - client_secret: str - __repr__ = __str__ = secretive_repr() - - @classmethod - def from_mapping(cls, data: typing.Mapping, *, strict: bool = True) -> Union["ElasticJobRegistryCredentials", None]: - """Build from mapping/dict/config""" - args = {"oidc_issuer", "client_id", "client_secret"} - try: - return cls(**{a: data[a] for a in args}) - except KeyError: - if strict: - missing = args.difference(data.keys()) - raise EjrError(f"Failed building {cls.__name__} from mapping: missing {missing!r}") from None - +class ElasticJobRegistryCredentials(ClientCredentials): + # Legacy alias/wrapper for ClientCredentials+get_ejr_credentials_from_env + # TODO remove when unused @classmethod def from_env( cls, env: Optional[typing.Mapping] = None, *, strict: bool = True - ) -> Union["ElasticJobRegistryCredentials", None]: - env = env or os.environ - env_var_mapping = { - "oidc_issuer": "OPENEO_EJR_OIDC_ISSUER", - "client_id": "OPENEO_EJR_OIDC_CLIENT_ID", - "client_secret": "OPENEO_EJR_OIDC_CLIENT_SECRET", - } - try: - return cls(**{a: env[e] for a, e in env_var_mapping.items()}) - except KeyError: - if strict: - missing = set(env_var_mapping.values()).difference(env.keys()) - raise EjrError(f"Failed building {cls.__name__} from env: missing {missing!r}") from None + ) -> Union[ClientCredentials, None]: + return get_ejr_credentials_from_env(env=env, strict=strict) + + +def get_ejr_credentials_from_env( + env: Optional[typing.Mapping] = None, *, strict: bool = True +) -> Union[ClientCredentials, None]: + # TODO only really used in openeo-geopyspark-driver atm + # TODO Generalize this functionality (map env vars to NamedTuple) in some way? 
+ env = env or os.environ + env_var_mapping = { + "oidc_issuer": "OPENEO_EJR_OIDC_ISSUER", + "client_id": "OPENEO_EJR_OIDC_CLIENT_ID", + "client_secret": "OPENEO_EJR_OIDC_CLIENT_SECRET", + } + try: + kwargs = {a: env[e] for a, e in env_var_mapping.items()} + except KeyError: + if strict: + missing = set(env_var_mapping.values()).difference(env.keys()) + raise EjrError(f"Failed building {ClientCredentials.__name__} from env: missing {missing!r}") from None + else: + return ClientCredentials(**kwargs) class ElasticJobRegistry(JobRegistryInterface): @@ -225,8 +216,7 @@ def __init__( self.logger.info(f"Creating ElasticJobRegistry with {backend_id=} and {api_url=}") self._backend_id: Optional[str] = backend_id self._api_url = api_url - self._authenticator: Optional[OidcClientCredentialsAuthenticator] = None - self._cache = TtlCache(default_ttl=60 * 60) + self._access_token_helper = ClientCredentialsAccessTokenHelper(session=session) if session: self._session = session @@ -260,34 +250,9 @@ def backend_id(self) -> str: assert self._backend_id return self._backend_id - def setup_auth_oidc_client_credentials( - self, credentials: ElasticJobRegistryCredentials - ) -> None: + def setup_auth_oidc_client_credentials(self, credentials: ClientCredentials) -> None: """Set up OIDC client credentials authentication.""" - self.logger.info( - f"Setting up EJR OIDC Client Credentials Authentication with {credentials.client_id=}, {credentials.oidc_issuer=}, {len(credentials.client_secret)=}" - ) - oidc_provider = OidcProviderInfo( - issuer=credentials.oidc_issuer, requests_session=self._session - ) - client_info = OidcClientInfo( - client_id=credentials.client_id, - provider=oidc_provider, - client_secret=credentials.client_secret, - ) - self._authenticator = OidcClientCredentialsAuthenticator( - client_info=client_info, requests_session=self._session - ) - - def _get_access_token(self) -> str: - if not self._authenticator: - raise EjrError("No authentication set up") - with 
TimingLogger( - title=f"Requesting EJR OIDC access_token ({self._authenticator.__class__.__name__})", - logger=self.logger.info, - ): - tokens = self._authenticator.get_tokens() - return tokens.access_token + self._access_token_helper.setup_credentials(credentials) def _do_request( self, @@ -301,12 +266,7 @@ def _do_request( with TimingLogger(logger=self.logger.debug, title=f"EJR Request `{method} {path}`"): headers = {} if use_auth: - access_token = self._cache.get_or_call( - key="api_access_token", - callback=self._get_access_token, - # TODO: finetune/optimize caching TTL? Detect TTl/expiry from JWT access token itself? - ttl=30 * 60, - ) + access_token = self._access_token_helper.get_access_token() headers["Authorization"] = f"Bearer {access_token}" url = url_join(self._api_url, path) @@ -672,7 +632,7 @@ def _get_job_registry( " through environment variable `VAULT_TOKEN`" " or local file `~/.vault-token` (e.g. created with `vault login -method=ldap username=john`)." ) - credentials = ElasticJobRegistryCredentials.from_mapping(secret["data"]["data"]) + credentials = ClientCredentials.from_mapping(secret["data"]["data"]) ejr.setup_auth_oidc_client_credentials(credentials=credentials) return ejr diff --git a/openeo_driver/util/auth.py b/openeo_driver/util/auth.py new file mode 100644 index 00000000..d685a882 --- /dev/null +++ b/openeo_driver/util/auth.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +import time +from typing import Mapping, NamedTuple, Optional, Union + +import requests +from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo + +from openeo_driver.datastructs import secretive_repr + +_log = logging.getLogger(__name__) + + +class ClientCredentials(NamedTuple): + """ + Necessary bits for doing OIDC client credentials flow: + issuer URL, client id and secret. 
+ """ + + oidc_issuer: str + client_id: str + client_secret: str + __repr__ = __str__ = secretive_repr() + + @classmethod + def from_mapping(cls, data: Mapping, *, strict: bool = True) -> Union[ClientCredentials, None]: + """Build from mapping/dict/config""" + keys = {"oidc_issuer", "client_id", "client_secret"} + try: + kwargs = {k: data[k] for k in keys} + except KeyError: + if strict: + missing = keys.difference(data.keys()) + raise ValueError(f"Failed building {cls.__name__} from mapping: missing {missing!r}") from None + else: + return cls(**kwargs) + + +class _AccessTokenCache(NamedTuple): + access_token: str + expires_at: float + + +class ClientCredentialsAccessTokenHelper: + """ + Helper to get OIDC access tokens using client credentials flow, e.g. to interact with an API (like EJR, ETL, ...) + Caches access token too. + + Usage: + - add a `ClientCredentialsAccessTokenHelper` instance to your class (e.g. in __init__) + - call `setup_credentials()` with a `ClientCredentials` instance (or do this directly from __init__) + - call `get_access_token()` to get an access token where necessary + """ + + __slots__ = ("_authenticator", "_session", "_cache", "_default_ttl") + + def __init__( + self, + *, + credentials: Optional[ClientCredentials] = None, + session: Optional[requests.Session] = None, + default_ttl: float = 20 * 60, + ): + self._session = session + self._authenticator: Optional[OidcClientCredentialsAuthenticator] = None + self._cache = _AccessTokenCache("", 0) + self._default_ttl = default_ttl + + if credentials: + self.setup_credentials(credentials) + + def setup_credentials(self, credentials: ClientCredentials) -> None: + """ + Set up an `OidcClientCredentialsAuthenticator` + (that allows fetching access tokens) + using the given client credentials and OIDC issuer configuration. + """ + # TODO: eliminate need for this separate `setup` and just do it always from `__init__`? 
+ self._cache = _AccessTokenCache("", 0) + _log.debug(f"Setting up {self.__class__.__name__} with {credentials!r}") + oidc_provider = OidcProviderInfo( + issuer=credentials.oidc_issuer, + requests_session=self._session, + ) + client_info = OidcClientInfo( + client_id=credentials.client_id, + provider=oidc_provider, + client_secret=credentials.client_secret, + ) + self._authenticator = OidcClientCredentialsAuthenticator( + client_info=client_info, requests_session=self._session + ) + + def _get_access_token(self) -> str: + """Fetch a fresh access token using the configured authenticator.""" + if not self._authenticator: + raise RuntimeError("No authentication set up") + _log.debug(f"{self.__class__.__name__} getting access token") + tokens = self._authenticator.get_tokens() + return tokens.access_token + + def get_access_token(self) -> str: + """Get an access token, reusing the cached one until it expires.""" + if time.time() > self._cache.expires_at: + access_token = self._get_access_token() + # TODO: get expiry from access token itself? 
+ self._cache = _AccessTokenCache(access_token, time.time() + self._default_ttl) + return self._cache.access_token diff --git a/setup.py b/setup.py index 6d20212d..ae87a505 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ "flask", "werkzeug>=1.0.1,<2.3.0", # https://github.com/Open-EO/openeo-python-driver/issues/187 "requests>=2.28.0", - "openeo>=0.24.0.a2.dev", + "openeo>=0.24.0.a3.dev", "openeo_processes==0.0.4", # 0.0.4 is special build/release, also see https://github.com/Open-EO/openeo-python-driver/issues/152 "gunicorn>=20.0.1", "numpy>=1.22.0", diff --git a/tests/test_jobregistry.py b/tests/test_jobregistry.py index 15a57a00..cbda5078 100644 --- a/tests/test_jobregistry.py +++ b/tests/test_jobregistry.py @@ -1,7 +1,5 @@ import logging -import urllib.parse from typing import Callable, List, Optional, Sequence, Union -from unittest import mock import pytest import requests @@ -16,9 +14,10 @@ EjrError, EjrHttpError, ElasticJobRegistry, - ElasticJobRegistryCredentials, + get_ejr_credentials_from_env, ) from openeo_driver.testing import DictSubSet, IgnoreOrder, ListSubSet, RegexMatcher, caplog_with_custom_formatter +from openeo_driver.util.auth import ClientCredentials DUMMY_PROCESS = { "summary": "calculate 3+5, please", @@ -32,56 +31,24 @@ } -class TestElasticJobRegistryCredentials: - def test_basic(self): - creds = ElasticJobRegistryCredentials( - oidc_issuer="https://oidc.test/", client_id="c123", client_secret="@#$" - ) - assert creds.oidc_issuer == "https://oidc.test/" - assert creds.client_id == "c123" - assert creds.client_secret == "@#$" - assert creds == ("https://oidc.test/", "c123", "@#$") - - def test_repr(self): - creds = ElasticJobRegistryCredentials( - oidc_issuer="https://oidc.test/", client_id="c123", client_secret="@#$" - ) - expected = "ElasticJobRegistryCredentials(oidc_issuer='https://oidc.test/', client_id='c123', client_secret='***')" - assert repr(creds) == expected - assert str(creds) == expected +def 
test_get_ejr_credentials_from_env(monkeypatch): + monkeypatch.setenv("OPENEO_EJR_OIDC_ISSUER", "https://id.example") + monkeypatch.setenv("OPENEO_EJR_OIDC_CLIENT_ID", "c-9876") + monkeypatch.setenv("OPENEO_EJR_OIDC_CLIENT_SECRET", "!@#$%%") + creds = get_ejr_credentials_from_env() + assert creds == ("https://id.example", "c-9876", "!@#$%%") - def test_get_from_mapping(self): - creds = ElasticJobRegistryCredentials.from_mapping( - {"oidc_issuer": "https://oidc.test/", "client_id": "c456789", "client_secret": "s3cr3t"}, - ) - assert creds == ("https://oidc.test/", "c456789", "s3cr3t") - def test_get_from_mapping_strictness(self): - data = {"oidc_issuer": "https://oidc.test/", "client_id": "c456789"} - with pytest.raises( - EjrError, match="Failed building ElasticJobRegistryCredentials from mapping: missing {'client_secret'}" - ): - _ = ElasticJobRegistryCredentials.from_mapping(data) - creds = ElasticJobRegistryCredentials.from_mapping(data, strict=False) - assert creds is None - - def test_get_from_env(self, monkeypatch): - monkeypatch.setenv("OPENEO_EJR_OIDC_ISSUER", "https://id.example") - monkeypatch.setenv("OPENEO_EJR_OIDC_CLIENT_ID", "c-9876") - monkeypatch.setenv("OPENEO_EJR_OIDC_CLIENT_SECRET", "!@#$%%") - creds = ElasticJobRegistryCredentials.from_env() - assert creds == ("https://id.example", "c-9876", "!@#$%%") - - def test_get_from_env_strictness(self, monkeypatch): - monkeypatch.setenv("OPENEO_EJR_OIDC_ISSUER", "https://id.example") - monkeypatch.setenv("OPENEO_EJR_OIDC_CLIENT_ID", "c-9876") - with pytest.raises( - EjrError, - match="Failed building ElasticJobRegistryCredentials from env: missing {'OPENEO_EJR_OIDC_CLIENT_SECRET'}", - ): - _ = ElasticJobRegistryCredentials.from_env() - creds = ElasticJobRegistryCredentials.from_env(strict=False) - assert creds is None +def test_get_ejr_credentials_from_env_strictness(monkeypatch): + monkeypatch.setenv("OPENEO_EJR_OIDC_ISSUER", "https://id.example") + monkeypatch.setenv("OPENEO_EJR_OIDC_CLIENT_ID", 
"c-9876") + with pytest.raises( + EjrError, + match="Failed building ClientCredentials from env: missing {'OPENEO_EJR_OIDC_CLIENT_SECRET'}", + ): + _ = get_ejr_credentials_from_env() + creds = get_ejr_credentials_from_env(strict=False) + assert creds is None class TestElasticJobRegistry: @@ -109,7 +76,7 @@ def oidc_mock(self, requests_mock) -> OidcMock: def ejr(self, oidc_mock) -> ElasticJobRegistry: """ElasticJobRegistry set up with authentication""" ejr = ElasticJobRegistry(api_url=self.EJR_API_URL, backend_id="unittests") - credentials = ElasticJobRegistryCredentials( + credentials = ClientCredentials( oidc_issuer=self.OIDC_CLIENT_INFO["oidc_issuer"], client_id=self.OIDC_CLIENT_INFO["client_id"], client_secret=self.OIDC_CLIENT_INFO["client_secret"], @@ -243,7 +210,7 @@ def post_token(request, context): ejr = ElasticJobRegistry( api_url=self.EJR_API_URL, backend_id="unittests", session=session ) - credentials = ElasticJobRegistryCredentials( + credentials = ClientCredentials( oidc_issuer=self.OIDC_CLIENT_INFO["oidc_issuer"], client_id=self.OIDC_CLIENT_INFO["client_id"], client_secret=self.OIDC_CLIENT_INFO["client_secret"], diff --git a/tests/util/test_auth.py b/tests/util/test_auth.py new file mode 100644 index 00000000..d813151f --- /dev/null +++ b/tests/util/test_auth.py @@ -0,0 +1,78 @@ +import logging + +import pytest +from openeo.rest.auth.testing import OidcMock + +from openeo_driver.util.auth import ClientCredentials, ClientCredentialsAccessTokenHelper + + +class TestClientCredentials: + def test_basic(self): + creds = ClientCredentials(oidc_issuer="https://oidc.test/", client_id="c123", client_secret="@#$") + assert creds.oidc_issuer == "https://oidc.test/" + assert creds.client_id == "c123" + assert creds.client_secret == "@#$" + assert creds == ("https://oidc.test/", "c123", "@#$") + + def test_repr(self): + creds = ClientCredentials(oidc_issuer="https://oidc.test/", client_id="c123", client_secret="@#$") + expected = 
"ClientCredentials(oidc_issuer='https://oidc.test/', client_id='c123', client_secret='***')" + assert repr(creds) == expected + assert str(creds) == expected + + def test_get_from_mapping(self): + creds = ClientCredentials.from_mapping( + {"oidc_issuer": "https://oidc.test/", "client_id": "c456789", "client_secret": "s3cr3t"}, + ) + assert creds == ("https://oidc.test/", "c456789", "s3cr3t") + + def test_get_from_mapping_strictness(self): + data = {"oidc_issuer": "https://oidc.test/", "client_id": "c456789"} + with pytest.raises( + ValueError, match="Failed building ClientCredentials from mapping: missing {'client_secret'}" + ): + _ = ClientCredentials.from_mapping(data) + creds = ClientCredentials.from_mapping(data, strict=False) + assert creds is None + + +class TestClientCredentialsAccessTokenHelper: + @pytest.fixture + def credentials(self) -> ClientCredentials: + return ClientCredentials(oidc_issuer="https://oidc.test", client_id="client123", client_secret="s3cr3t") + + @pytest.fixture + def oidc_mock(self, requests_mock, credentials) -> OidcMock: + oidc_mock = OidcMock( + requests_mock=requests_mock, + oidc_issuer=credentials.oidc_issuer, + expected_grant_type="client_credentials", + expected_client_id=credentials.client_id, + expected_fields={"client_secret": credentials.client_secret, "scope": "openid"}, + ) + return oidc_mock + + def test_basic(self, credentials, oidc_mock: OidcMock): + helper = ClientCredentialsAccessTokenHelper(credentials=credentials) + assert helper.get_access_token() == oidc_mock.state["access_token"] + + def test_caching(self, credentials, oidc_mock: OidcMock): + helper = ClientCredentialsAccessTokenHelper(credentials=credentials) + assert oidc_mock.mocks["token_endpoint"].call_count == 0 + assert helper.get_access_token() == oidc_mock.state["access_token"] + assert oidc_mock.mocks["token_endpoint"].call_count == 1 + assert helper.get_access_token() == oidc_mock.state["access_token"] + assert 
oidc_mock.mocks["token_endpoint"].call_count == 1 + + def test_secret_logging(self, credentials, oidc_mock: OidcMock, caplog): + """Check that secret is not logged""" + caplog.set_level(logging.DEBUG) + helper = ClientCredentialsAccessTokenHelper(credentials=credentials) + assert helper.get_access_token() == oidc_mock.state["access_token"] + (setup_log,) = [ + log for log in caplog.messages if log.startswith("Setting up ClientCredentialsAccessTokenHelper") + ] + assert ( + setup_log + == "Setting up ClientCredentialsAccessTokenHelper with ClientCredentials(oidc_issuer='https://oidc.test', client_id='client123', client_secret='***')" + )