Skip to content

Commit

Permalink
Extract client cred access token logic from ElasticJobRegistry
Browse files Browse the repository at this point in the history
to allow reuse (e.g. ETL API)
related to Open-EO/openeo-geopyspark-driver#531
  • Loading branch information
soxofaan committed Oct 20, 2023
1 parent 68b0c2b commit 1bc83b6
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 128 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ and start a new "In Progress" section above it.

## In progress


## 0.72.0

- Start returning "OpenEO-Costs-experimental" header on synchronous processing responses
- Extract client credentials access token fetch logic from ElasticJobRegistry
into `ClientCredentialsAccessTokenHelper` to make it reusable (e.g. for ETL API as well)
([Open-EO/openeo-geopyspark-driver#531](https://github.com/Open-EO/openeo-geopyspark-driver/issues/531))

## 0.71.0

Expand Down
2 changes: 1 addition & 1 deletion openeo_driver/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.71.0a1"
__version__ = "0.72.0a1"
106 changes: 33 additions & 73 deletions openeo_driver/jobregistry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import contextlib
import datetime as dt
import json
import logging
import os
Expand All @@ -9,17 +8,15 @@
import time
import typing
from decimal import Decimal
from typing import Any, Dict, List, NamedTuple, Optional, Union, Sequence
from typing import Any, Dict, List, Optional, Sequence, Union

import requests
from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo
from openeo.rest.connection import url_join
from openeo.util import TimingLogger, repr_truncate, rfc3339

import openeo_driver._version
from openeo_driver.datastructs import secretive_repr
from openeo_driver.errors import InternalException, JobNotFoundException
from openeo_driver.util.caching import TtlCache
from openeo_driver.util.auth import ClientCredentials, ClientCredentialsAccessTokenHelper
from openeo_driver.util.logging import just_log_exceptions
from openeo_driver.utils import generate_unique_id

Expand Down Expand Up @@ -167,41 +164,35 @@ def from_response(cls, response: requests.Response) -> "EjrHttpError":
)


class ElasticJobRegistryCredentials(NamedTuple):
"""Container of Elastic Job Registry related credentials."""

oidc_issuer: str
client_id: str
client_secret: str
__repr__ = __str__ = secretive_repr()

@classmethod
def from_mapping(cls, data: typing.Mapping, *, strict: bool = True) -> Union["ElasticJobRegistryCredentials", None]:
"""Build from mapping/dict/config"""
args = {"oidc_issuer", "client_id", "client_secret"}
try:
return cls(**{a: data[a] for a in args})
except KeyError:
if strict:
missing = args.difference(data.keys())
raise EjrError(f"Failed building {cls.__name__} from mapping: missing {missing!r}") from None

class ElasticJobRegistryCredentials(ClientCredentials):
# Legacy alias/wrapper for ClientCredentials+get_ejr_credentials_from_env
# TODO remove when unused
@classmethod
def from_env(
cls, env: Optional[typing.Mapping] = None, *, strict: bool = True
) -> Union["ElasticJobRegistryCredentials", None]:
env = env or os.environ
env_var_mapping = {
"oidc_issuer": "OPENEO_EJR_OIDC_ISSUER",
"client_id": "OPENEO_EJR_OIDC_CLIENT_ID",
"client_secret": "OPENEO_EJR_OIDC_CLIENT_SECRET",
}
try:
return cls(**{a: env[e] for a, e in env_var_mapping.items()})
except KeyError:
if strict:
missing = set(env_var_mapping.values()).difference(env.keys())
raise EjrError(f"Failed building {cls.__name__} from env: missing {missing!r}") from None
) -> Union[ClientCredentials, None]:
return get_ejr_credentials_from_env(env=env, strict=strict)


def get_ejr_credentials_from_env(
env: Optional[typing.Mapping] = None, *, strict: bool = True
) -> Union[ClientCredentials, None]:
# TODO only really used in openeo-geopyspark-driver atm
# TODO Generalize this functionality (map env vars to NamedTuple) in some way?
env = env or os.environ
env_var_mapping = {
"oidc_issuer": "OPENEO_EJR_OIDC_ISSUER",
"client_id": "OPENEO_EJR_OIDC_CLIENT_ID",
"client_secret": "OPENEO_EJR_OIDC_CLIENT_SECRET",
}
try:
kwargs = {a: env[e] for a, e in env_var_mapping.items()}
except KeyError:
if strict:
missing = set(env_var_mapping.values()).difference(env.keys())
raise EjrError(f"Failed building {ClientCredentials.__name__} from env: missing {missing!r}") from None
else:
return ClientCredentials(**kwargs)


class ElasticJobRegistry(JobRegistryInterface):
Expand All @@ -225,8 +216,7 @@ def __init__(
self.logger.info(f"Creating ElasticJobRegistry with {backend_id=} and {api_url=}")
self._backend_id: Optional[str] = backend_id
self._api_url = api_url
self._authenticator: Optional[OidcClientCredentialsAuthenticator] = None
self._cache = TtlCache(default_ttl=60 * 60)
self._access_token_helper = ClientCredentialsAccessTokenHelper(session=session)

if session:
self._session = session
Expand Down Expand Up @@ -260,34 +250,9 @@ def backend_id(self) -> str:
assert self._backend_id
return self._backend_id

def setup_auth_oidc_client_credentials(
self, credentials: ElasticJobRegistryCredentials
) -> None:
def setup_auth_oidc_client_credentials(self, credentials: ClientCredentials) -> None:
"""Set up OIDC client credentials authentication."""
self.logger.info(
f"Setting up EJR OIDC Client Credentials Authentication with {credentials.client_id=}, {credentials.oidc_issuer=}, {len(credentials.client_secret)=}"
)
oidc_provider = OidcProviderInfo(
issuer=credentials.oidc_issuer, requests_session=self._session
)
client_info = OidcClientInfo(
client_id=credentials.client_id,
provider=oidc_provider,
client_secret=credentials.client_secret,
)
self._authenticator = OidcClientCredentialsAuthenticator(
client_info=client_info, requests_session=self._session
)

def _get_access_token(self) -> str:
if not self._authenticator:
raise EjrError("No authentication set up")
with TimingLogger(
title=f"Requesting EJR OIDC access_token ({self._authenticator.__class__.__name__})",
logger=self.logger.info,
):
tokens = self._authenticator.get_tokens()
return tokens.access_token
self._access_token_helper.setup_credentials(credentials)

def _do_request(
self,
Expand All @@ -301,12 +266,7 @@ def _do_request(
with TimingLogger(logger=self.logger.debug, title=f"EJR Request `{method} {path}`"):
headers = {}
if use_auth:
access_token = self._cache.get_or_call(
key="api_access_token",
callback=self._get_access_token,
# TODO: finetune/optimize caching TTL? Detect TTl/expiry from JWT access token itself?
ttl=30 * 60,
)
access_token = self._access_token_helper.get_access_token()
headers["Authorization"] = f"Bearer {access_token}"

url = url_join(self._api_url, path)
Expand Down Expand Up @@ -672,7 +632,7 @@ def _get_job_registry(
" through environment variable `VAULT_TOKEN`"
" or local file `~/.vault-token` (e.g. created with `vault login -method=ldap username=john`)."
)
credentials = ElasticJobRegistryCredentials.from_mapping(secret["data"]["data"])
credentials = ClientCredentials.from_mapping(secret["data"]["data"])
ejr.setup_auth_oidc_client_credentials(credentials=credentials)
return ejr

Expand Down
109 changes: 109 additions & 0 deletions openeo_driver/util/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

import logging
import time
from typing import Mapping, NamedTuple, Optional, Union

import requests
from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo

from openeo_driver.datastructs import secretive_repr

_log = logging.getLogger(__name__)


class ClientCredentials(NamedTuple):
"""
Necessary bits for doing OIDC client credentials flow:
issuer URL, client id and secret.
"""

oidc_issuer: str
client_id: str
client_secret: str
__repr__ = __str__ = secretive_repr()

@classmethod
def from_mapping(cls, data: Mapping, *, strict: bool = True) -> Union[ClientCredentials, None]:
"""Build from mapping/dict/config"""
keys = {"oidc_issuer", "client_id", "client_secret"}
try:
kwargs = {k: data[k] for k in keys}
except KeyError:
if strict:
missing = keys.difference(data.keys())
raise ValueError(f"Failed building {cls.__name__} from mapping: missing {missing!r}") from None
else:
return cls(**kwargs)


class _AccessTokenCache(NamedTuple):
access_token: str
expires_at: float


class ClientCredentialsAccessTokenHelper:
"""
Helper to get OIDC access tokens using client credentials flow, e.g. to interact with an API (like EJR, ETL, ...)
Caches access token too.
Usage:
- add an `OidcClientCredentialsHelper` instance to your class (e.g. in __init__)
- call `setup_credentials()` with `ClientCredentials` instance (or do this directly from __init__)
- call `get_access_token()` to get an access token where necessary
"""

__slots__ = ("_authenticator", "_session", "_cache", "_default_ttl")

def __init__(
self,
*,
credentials: Optional[ClientCredentials] = None,
session: Optional[requests.Session] = None,
default_ttl: float = 20 * 60,
):
self._session = session
self._authenticator: Optional[OidcClientCredentialsAuthenticator] = None
self._cache = _AccessTokenCache("", 0)
self._default_ttl = default_ttl

if credentials:
self.setup_credentials(credentials)

def setup_credentials(self, credentials: ClientCredentials) -> None:
"""
Set up an `OidcClientCredentialsAuthenticator`
(that allows to fetch access tokens)
using the given client credentials and OIDC issuer configuration.
"""
# TODO: eliminate need for this separate `setup` and just do it always from `__init__`?
self._cache = _AccessTokenCache("", 0)
_log.debug(f"Setting up {self.__class__.__name__} with {credentials!r}")
oidc_provider = OidcProviderInfo(
issuer=credentials.oidc_issuer,
requests_session=self._session,
)
client_info = OidcClientInfo(
client_id=credentials.client_id,
provider=oidc_provider,
client_secret=credentials.client_secret,
)
self._authenticator = OidcClientCredentialsAuthenticator(
client_info=client_info, requests_session=self._session
)

def _get_access_token(self) -> str:
"""Get an access token using the configured authenticator."""
if not self._authenticator:
raise RuntimeError("No authentication set up")
_log.debug(f"{self.__class__.__name__} getting access token")
tokens = self._authenticator.get_tokens()
return tokens.access_token

def get_access_token(self) -> str:
"""Get an access token using the configured authenticator."""
if time.time() > self._cache.expires_at:
access_token = self._get_access_token()
# TODO: get expiry from access token itself?
self._cache = _AccessTokenCache(access_token, time.time() + self._default_ttl)
return self._cache.access_token
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"flask",
"werkzeug>=1.0.1,<2.3.0", # https://github.com/Open-EO/openeo-python-driver/issues/187
"requests>=2.28.0",
"openeo>=0.24.0.a2.dev",
"openeo>=0.24.0.a3.dev",
"openeo_processes==0.0.4", # 0.0.4 is special build/release, also see https://github.com/Open-EO/openeo-python-driver/issues/152
"gunicorn>=20.0.1",
"numpy>=1.22.0",
Expand Down
Loading

0 comments on commit 1bc83b6

Please sign in to comment.