diff --git a/platform_secrets/kube_client.py b/platform_secrets/kube_client.py index fe9dade..62a28e6 100644 --- a/platform_secrets/kube_client.py +++ b/platform_secrets/kube_client.py @@ -1,6 +1,8 @@ +import asyncio import json import logging import ssl +from contextlib import suppress from io import BytesIO from pathlib import Path from typing import Any, Optional @@ -54,6 +56,7 @@ def __init__( auth_cert_key_path: Optional[str] = None, token: Optional[str] = None, token_path: Optional[str] = None, + token_update_interval_s: int = 300, conn_timeout_s: int = 300, read_timeout_s: int = 100, conn_pool_size: int = 100, @@ -70,6 +73,7 @@ def __init__( self._auth_cert_key_path = auth_cert_key_path self._token = token self._token_path = token_path + self._token_update_interval_s = token_update_interval_s self._conn_timeout_s = conn_timeout_s self._read_timeout_s = read_timeout_s @@ -77,6 +81,7 @@ def __init__( self._trace_configs = trace_configs self._client: Optional[aiohttp.ClientSession] = None + self._token_updater_task: Optional[asyncio.Task[None]] = None self._dummy_secret_key = SECRET_DUMMY_KEY @@ -98,34 +103,36 @@ def _create_ssl_context(self) -> Optional[ssl.SSLContext]: return ssl_context async def init(self) -> None: - self._client = await self.create_http_client() - - async def init_if_needed(self) -> None: - if not self._client or self._client.closed: - await self.init() - - async def create_http_client(self) -> aiohttp.ClientSession: connector = aiohttp.TCPConnector( limit=self._conn_pool_size, ssl=self._create_ssl_context() ) - if self._auth_type == KubeClientAuthType.TOKEN: - token = self._token - if not token: - assert self._token_path is not None - token = Path(self._token_path).read_text() - headers = {"Authorization": "Bearer " + token} - else: - headers = {} + if self._token_path: + self._token = Path(self._token_path).read_text() + self._token_updater_task = asyncio.create_task(self._start_token_updater()) timeout = aiohttp.ClientTimeout( connect=self._conn_timeout_s, total=self._read_timeout_s ) - return aiohttp.ClientSession( + self._client = aiohttp.ClientSession( connector=connector, timeout=timeout, - headers=headers, trace_configs=self._trace_configs, ) + async def _start_token_updater(self) -> None: + if not self._token_path: + return + while True: + try: + token = Path(self._token_path).read_text() + if token != self._token: + self._token = token + logger.info("Kube token was refreshed") + except asyncio.CancelledError: + raise + except Exception as exc: + logger.exception("Failed to update kube token: %s", exc) + await asyncio.sleep(self._token_update_interval_s) + @property def namespace(self) -> str: return self._namespace @@ -134,6 +141,11 @@ async def close(self) -> None: if self._client: await self._client.close() self._client = None + if self._token_updater_task: + self._token_updater_task.cancel() + with suppress(asyncio.CancelledError): + await self._token_updater_task + self._token_updater_task = None async def __aenter__(self) -> "KubeClient": await self.init() @@ -160,23 +172,22 @@ def _generate_secret_url( all_secrets_url = self._generate_all_secrets_url(namespace_name) return f"{all_secrets_url}/{secret_name}" + def _create_headers( + self, headers: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: + headers = dict(headers) if headers else {} + if self._auth_type == KubeClientAuthType.TOKEN and self._token: + headers["Authorization"] = "Bearer " + self._token + return headers + async def _request(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - await self.init_if_needed() + headers = self._create_headers(kwargs.pop("headers", None)) assert self._client, "client is not initialized" - doing_retry = kwargs.pop("doing_retry", False) - - async with self._client.request(*args, **kwargs) as response: + async with self._client.request(*args, headers=headers, **kwargs) as response: payload = await response.json() - try: + logging.debug("k8s response payload: %s", payload) self._raise_for_status(payload) return payload - except KubeClientUnauthorized: - if doing_retry: - raise - # K8s SA's token might be stale, need to refresh it and retry - await self._reload_http_client() - kwargs["doing_retry"] = True - return await self._request(*args, **kwargs) def _raise_for_status(self, payload: dict[str, Any]) -> None: kind = payload["kind"] @@ -196,11 +207,6 @@ def _raise_for_status(self, payload: dict[str, Any]) -> None: raise ResourceConflict(payload["message"]) raise KubeClientException(payload["message"]) - async def _reload_http_client(self) -> None: - await self.close() - self._token = None - await self.init() - async def create_secret( self, secret_name: str, diff --git a/tests/integration/test_kube_client.py b/tests/integration/test_kube_client.py new file mode 100644 index 0000000..fbb10a3 --- /dev/null +++ b/tests/integration/test_kube_client.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import asyncio +import os +import tempfile +from collections.abc import AsyncIterator, Iterator +from pathlib import Path +from typing import Any + +import aiohttp +import aiohttp.web +import pytest + +from platform_secrets.config import KubeClientAuthType +from platform_secrets.kube_client import KubeClient + +from .conftest import create_local_app_server + + +class TestKubeClientTokenUpdater: + @pytest.fixture + async def kube_app(self) -> aiohttp.web.Application: + async def _get_secrets(request: aiohttp.web.Request) -> aiohttp.web.Response: + auth = request.headers["Authorization"] + token = auth.split()[-1] + app["token"]["value"] = token + return aiohttp.web.json_response({"kind": "SecretList", "items": []}) + + app = aiohttp.web.Application() + app["token"] = {"value": ""} + app.router.add_routes( + [aiohttp.web.get("/api/v1/namespaces/default/secrets", _get_secrets)] + ) + return app + + @pytest.fixture + async def kube_server( + self, kube_app: aiohttp.web.Application, unused_tcp_port_factory: Any + ) -> AsyncIterator[str]: + async with create_local_app_server( + kube_app, port=unused_tcp_port_factory() + ) as address: + yield f"http://{address.host}:{address.port}" + + @pytest.fixture + def kube_token_path(self) -> Iterator[str]: + _, path = tempfile.mkstemp() + Path(path).write_text("token-1") + yield path + os.remove(path) + + @pytest.fixture + async def kube_client( + self, kube_server: str, kube_token_path: str + ) -> AsyncIterator[KubeClient]: + async with KubeClient( + base_url=kube_server, + namespace="default", + auth_type=KubeClientAuthType.TOKEN, + token_path=kube_token_path, + token_update_interval_s=1, + ) as client: + yield client + + async def test_token_periodically_updated( + self, + kube_app: aiohttp.web.Application, + kube_client: KubeClient, + kube_token_path: str, + ) -> None: + await kube_client.list_secrets() + assert kube_app["token"]["value"] == "token-1" + + Path(kube_token_path).write_text("token-2") + await asyncio.sleep(2) + + await kube_client.list_secrets() + assert kube_app["token"]["value"] == "token-2"