Skip to content

Commit

Permalink
upgrade neuro-config-client (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
zubenkoivan authored Dec 23, 2024
1 parent 1d1fa15 commit d371879
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 38 deletions.
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ install_requires =
lark==1.2.2
marshmallow==3.23.2
neuro-auth-client==24.8.0
neuro-config-client==24.12.2
neuro-config-client==24.12.4
neuro-logging==24.12.1
neuro-sdk==23.2.0
pydantic==2.10.4
pydantic-settings==2.7.0
python-dateutil==2.9.0.post0
Expand Down
14 changes: 2 additions & 12 deletions src/platform_reports/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from decimal import Decimal
from importlib.metadata import version
from pathlib import Path
from tempfile import mktemp
from textwrap import dedent

import aiobotocore.session
Expand Down Expand Up @@ -39,7 +37,6 @@
from neuro_auth_client.security import setup_security
from neuro_config_client.client import ConfigClient
from neuro_logging import init_logging, setup_sentry
from neuro_sdk import Client as ApiClient, Factory as ClientFactory
from yarl import URL

from .auth import AuthService
Expand All @@ -65,6 +62,7 @@
Price,
)
from .metrics_service import GetCreditsUsageRequest, MetricsService
from .platform_api_client import ApiClient
from .prometheus_client import PrometheusClient
from .schema import (
ClientErrorSchema,
Expand Down Expand Up @@ -442,16 +440,8 @@ async def run_task(coro: Awaitable[None]) -> AsyncIterator[None]:

@asynccontextmanager
async def create_api_client(config: PlatformServiceConfig) -> AsyncIterator[ApiClient]:
tmp_config = Path(mktemp())
platform_api_factory = ClientFactory(tmp_config)
await platform_api_factory.login_with_token(url=config.url, token=config.token)
client = None
try:
client = await platform_api_factory.get()
async with ApiClient(url=config.url, token=config.token) as client:
yield client
finally:
if client:
await client.close()


def get_aws_pricing_api_region(region: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/platform_reports/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from multidict import MultiMapping
from neuro_auth_client import AuthClient, Permission
from neuro_sdk import Client as ApiClient

from .platform_api_client import ApiClient
from .prometheus_query_parser import (
InstantVector,
LabelMatcher,
Expand Down Expand Up @@ -382,7 +382,7 @@ async def get_job_permissions(self, job_ids: Iterable[str]) -> list[Permission]:
if job_id in self._job_permissions:
result.append(self._job_permissions[job_id])
else:
job = await self._api_client.jobs.status(job_id)
job = await self._api_client.get_job(job_id)
permission = Permission(uri=str(job.uri), action="read")
self._job_permissions[job_id] = permission
result.append(permission)
Expand Down
68 changes: 68 additions & 0 deletions src/platform_reports/platform_api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from collections.abc import Sequence
from dataclasses import dataclass
from types import TracebackType

import aiohttp
from yarl import URL


@dataclass(frozen=True)
class Job:
id: str
uri: URL


class ApiClient:
_client: aiohttp.ClientSession

def __init__(
self,
url: URL,
token: str | None = None,
timeout: aiohttp.ClientTimeout = aiohttp.client.DEFAULT_TIMEOUT,
trace_configs: Sequence[aiohttp.TraceConfig] = (),
):
super().__init__()

self._base_url = url / "api/v1"
self._token = token
self._timeout = timeout
self._trace_configs = trace_configs

async def __aenter__(self) -> "ApiClient":
self._client = self._create_http_client()
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()

def _create_http_client(self) -> aiohttp.ClientSession:
return aiohttp.ClientSession(
headers=self._create_default_headers(),
timeout=self._timeout,
trace_configs=list(self._trace_configs),
)

async def aclose(self) -> None:
assert self._client
await self._client.close()

def _create_default_headers(self) -> dict[str, str]:
result = {}
if self._token:
result["Authorization"] = f"Bearer {self._token}"
return result

async def get_job(self, id_: str) -> Job:
async with self._client.get(self._base_url / "jobs" / id_) as response:
response.raise_for_status()
response_json = await response.json()
return Job(
id=response_json["id"],
uri=URL(response_json["uri"]),
)
4 changes: 2 additions & 2 deletions tests/integration/conftest_platform_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ async def _get_cluster(request: aiohttp.web.Request) -> aiohttp.web.Response:
"max_size": 1,
"cpu": 1,
"memory": 4096 * 2**20,
"memory_mb": 4096,
"price": "0.0",
"currency": "USD",
}
Expand Down Expand Up @@ -59,8 +58,9 @@ async def _get_cluster(request: aiohttp.web.Request) -> aiohttp.web.Response:
"min_size": 1,
"max_size": 1,
"cpu": 1,
"available_cpu": 1,
"memory": 4096 * 2**20,
"memory_mb": 4096,
"available_memory": 4096 * 2**20,
"price": "0.0",
"currency": "USD",
"cpu_min_watts": 1,
Expand Down
30 changes: 10 additions & 20 deletions tests/unit/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from decimal import Decimal
from unittest import mock

import pytest
from multidict import MultiDict
from neuro_auth_client import AuthClient, Permission
from neuro_sdk import Client as ApiClient, JobDescription as Job
from yarl import URL

from platform_reports.auth import AuthService, Dashboard
from platform_reports.platform_api_client import ApiClient, Job


JOB_ID = "job-00000000-0000-0000-0000-000000000000"
Expand All @@ -21,16 +20,7 @@ def job_factory() -> Callable[[str], Job]:
def _factory(id_: str) -> Job:
return Job(
id=id_,
owner=None, # type: ignore
cluster_name=None, # type: ignore
status=None, # type: ignore
history=None, # type: ignore
container=None, # type: ignore
uri=URL(f"job://default/org/project/{id_}"),
total_price_credits=Decimal("500"),
price_credits_per_hour=Decimal("5"),
pass_config=None, # type: ignore
scheduler_enabled=False,
)

return _factory
Expand All @@ -49,7 +39,7 @@ async def get_job(id_: str) -> Job:
return job_factory(id_)

client = mock.AsyncMock(ApiClient)
client.jobs.status = mock.AsyncMock(side_effect=get_job)
client.get_job = mock.AsyncMock(side_effect=get_job)
return client


Expand Down Expand Up @@ -200,7 +190,7 @@ async def test_check_job_dashboard_with_job_id_permissions(
"user",
[Permission(uri=f"job://default/org/project/{JOB_ID}", action="read")],
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_project_jobs_dashboard_without_project_name_permissions(
self, service: AuthService, auth_client: mock.AsyncMock
Expand Down Expand Up @@ -386,7 +376,7 @@ async def test_check_kube_state_metrics_query_with_pod_permissions(
"user",
[Permission(uri=f"job://default/org/project/{JOB_ID}", action="read")],
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_kube_state_metrics_query_with_service_pod_permissions(
self, service: AuthService, auth_client: mock.AsyncMock
Expand Down Expand Up @@ -455,7 +445,7 @@ async def test_check_kubelet_query_with_pod_permissions(
"user",
[Permission(uri=f"job://default/org/project/{JOB_ID}", action="read")],
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_kubelet_query_with_service_pod_permissions(
self,
Expand Down Expand Up @@ -527,7 +517,7 @@ async def test_check_nvidia_dcgm_exporter_query_with_pod_permissions(
"user",
[Permission(uri=f"job://default/org/project/{JOB_ID}", action="read")],
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_neuro_metrics_exporter_query_without_pod_permissions(
self, service: AuthService, auth_client: mock.AsyncMock
Expand Down Expand Up @@ -582,7 +572,7 @@ async def test_check_neuro_metrics_exporter_query_with_pod_permissions(
"user",
[Permission(uri=f"job://default/org/project/{JOB_ID}", action="read")],
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_without_job_matcher(self, service: AuthService) -> None:
result = await service.check_query_permissions(
Expand Down Expand Up @@ -613,7 +603,7 @@ async def test_check_join_for_job_permissions(
"user",
[Permission(uri=f"job://default/org/project/{JOB_ID}", action="read")],
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_ignoring_join_for_all_jobs_permissions(
self,
Expand All @@ -635,7 +625,7 @@ async def test_check_ignoring_join_for_all_jobs_permissions(
auth_client.get_missing_permissions.assert_awaited_once_with(
"user", [Permission(uri="role://default/manager", action="read")]
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_join_platform_api_called_once(
self,
Expand All @@ -658,7 +648,7 @@ async def test_check_join_platform_api_called_once(
"user",
[Permission(uri=f"job://default/org/project/{JOB_ID}", action="read")],
)
api_client.jobs.status.assert_awaited_once_with(JOB_ID)
api_client.get_job.assert_awaited_once_with(JOB_ID)

async def test_check_join_for_project_jobs_permissions(
self, service: AuthService, auth_client: mock.AsyncMock
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,12 @@ def cluster(self) -> Cluster:
node_pools=[
NodePool(
name="node-pool",
cpu=1,
available_cpu=1,
memory=4 * 2**30,
available_memory=4 * 2**30,
disk_size=100 * 2**30,
available_disk_size=100 * 2**30,
cpu_min_watts=10.5,
cpu_max_watts=110.0,
),
Expand Down

0 comments on commit d371879

Please sign in to comment.