Skip to content

Commit

Permalink
add update kube token task
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan Zubenko committed Jul 17, 2023
1 parent 5b01fe7 commit 95d94ac
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 34 deletions.
74 changes: 40 additions & 34 deletions platform_secrets/kube_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import json
import logging
import ssl
from contextlib import suppress
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
Expand Down Expand Up @@ -54,6 +56,7 @@ def __init__(
auth_cert_key_path: Optional[str] = None,
token: Optional[str] = None,
token_path: Optional[str] = None,
token_update_interval_s: int = 300,
conn_timeout_s: int = 300,
read_timeout_s: int = 100,
conn_pool_size: int = 100,
Expand All @@ -70,13 +73,15 @@ def __init__(
self._auth_cert_key_path = auth_cert_key_path
self._token = token
self._token_path = token_path
self._token_update_interval_s = token_update_interval_s

self._conn_timeout_s = conn_timeout_s
self._read_timeout_s = read_timeout_s
self._conn_pool_size = conn_pool_size
self._trace_configs = trace_configs

self._client: Optional[aiohttp.ClientSession] = None
self._token_updater_task: Optional[asyncio.Task[None]] = None

self._dummy_secret_key = SECRET_DUMMY_KEY

Expand All @@ -98,34 +103,36 @@ def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
return ssl_context

async def init(self) -> None:
self._client = await self.create_http_client()

async def init_if_needed(self) -> None:
if not self._client or self._client.closed:
await self.init()

async def create_http_client(self) -> aiohttp.ClientSession:
connector = aiohttp.TCPConnector(
limit=self._conn_pool_size, ssl=self._create_ssl_context()
)
if self._auth_type == KubeClientAuthType.TOKEN:
token = self._token
if not token:
assert self._token_path is not None
token = Path(self._token_path).read_text()
headers = {"Authorization": "Bearer " + token}
else:
headers = {}
if self._token_path:
self._token = Path(self._token_path).read_text()
self._token_updater_task = asyncio.create_task(self._start_token_updater())
timeout = aiohttp.ClientTimeout(
connect=self._conn_timeout_s, total=self._read_timeout_s
)
return aiohttp.ClientSession(
self._client = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers=headers,
trace_configs=self._trace_configs,
)

async def _start_token_updater(self) -> None:
if not self._token_path:
return
while True:
try:
token = Path(self._token_path).read_text()
if token != self._token:
self._token = token
logger.info("Kube token was refreshed")
except asyncio.CancelledError:
raise
except Exception as exc:
logger.exception("Failed to update kube token: %s", exc)
await asyncio.sleep(self._token_update_interval_s)

@property
def namespace(self) -> str:
return self._namespace
Expand All @@ -134,6 +141,11 @@ async def close(self) -> None:
if self._client:
await self._client.close()
self._client = None
if self._token_updater_task:
self._token_updater_task.cancel()
with suppress(asyncio.CancelledError):
await self._token_updater_task
self._token_updater_task = None

async def __aenter__(self) -> "KubeClient":
await self.init()
Expand All @@ -160,23 +172,22 @@ def _generate_secret_url(
all_secrets_url = self._generate_all_secrets_url(namespace_name)
return f"{all_secrets_url}/{secret_name}"

def _create_headers(
self, headers: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
headers = dict(headers) if headers else {}
if self._auth_type == KubeClientAuthType.TOKEN and self._token:
headers["Authorization"] = "Bearer " + self._token
return headers

async def _request(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
await self.init_if_needed()
headers = self._create_headers(kwargs.pop("headers", None))
assert self._client, "client is not initialized"
doing_retry = kwargs.pop("doing_retry", False)

async with self._client.request(*args, **kwargs) as response:
async with self._client.request(*args, headers=headers, **kwargs) as response:
payload = await response.json()
try:
logging.debug("k8s response payload: %s", payload)
self._raise_for_status(payload)
return payload
except KubeClientUnauthorized:
if doing_retry:
raise
# K8s SA's token might be stale, need to refresh it and retry
await self._reload_http_client()
kwargs["doing_retry"] = True
return await self._request(*args, **kwargs)

def _raise_for_status(self, payload: dict[str, Any]) -> None:
kind = payload["kind"]
Expand All @@ -196,11 +207,6 @@ def _raise_for_status(self, payload: dict[str, Any]) -> None:
raise ResourceConflict(payload["message"])
raise KubeClientException(payload["message"])

async def _reload_http_client(self) -> None:
await self.close()
self._token = None
await self.init()

async def create_secret(
self,
secret_name: str,
Expand Down
78 changes: 78 additions & 0 deletions tests/integration/test_kube_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import asyncio
import os
import tempfile
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from typing import Any

import aiohttp
import aiohttp.web
import pytest

from platform_secrets.config import KubeClientAuthType
from platform_secrets.kube_client import KubeClient

from .conftest import create_local_app_server


class TestKubeClientTokenUpdater:
@pytest.fixture
async def kube_app(self) -> aiohttp.web.Application:
async def _get_secrets(request: aiohttp.web.Request) -> aiohttp.web.Response:
auth = request.headers["Authorization"]
token = auth.split()[-1]
app["token"]["value"] = token
return aiohttp.web.json_response({"kind": "SecretList", "items": []})

app = aiohttp.web.Application()
app["token"] = {"value": ""}
app.router.add_routes(
[aiohttp.web.get("/api/v1/namespaces/default/secrets", _get_secrets)]
)
return app

@pytest.fixture
async def kube_server(
self, kube_app: aiohttp.web.Application, unused_tcp_port_factory: Any
) -> AsyncIterator[str]:
async with create_local_app_server(
kube_app, port=unused_tcp_port_factory()
) as address:
yield f"http://{address.host}:{address.port}"

@pytest.fixture
def kube_token_path(self) -> Iterator[str]:
_, path = tempfile.mkstemp()
Path(path).write_text("token-1")
yield path
os.remove(path)

@pytest.fixture
async def kube_client(
self, kube_server: str, kube_token_path: str
) -> AsyncIterator[KubeClient]:
async with KubeClient(
base_url=kube_server,
namespace="default",
auth_type=KubeClientAuthType.TOKEN,
token_path=kube_token_path,
token_update_interval_s=1,
) as client:
yield client

async def test_token_periodically_updated(
self,
kube_app: aiohttp.web.Application,
kube_client: KubeClient,
kube_token_path: str,
) -> None:
await kube_client.list_secrets()
assert kube_app["token"]["value"] == "token-1"

Path(kube_token_path).write_text("token-2")
await asyncio.sleep(2)

await kube_client.list_secrets()
assert kube_app["token"]["value"] == "token-2"

0 comments on commit 95d94ac

Please sign in to comment.