Skip to content

Commit

Permalink
support aws spot prices (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
zubenkoivan authored Nov 9, 2020
1 parent 1b30da0 commit f6a9d2a
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 20 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ CURRENT_TIME = $(shell date -u '+%Y-%m-%dT%H:%M:%SZ')

AWS_ACCOUNT_ID ?= 771188043543
AWS_REGION ?= us-east-1

AZURE_RG_NAME ?= dev
AZURE_ACR_NAME ?= crc570d91c95c6aac0ea80afb1019a0c6f

TAG ?= latest
Expand Down
1 change: 1 addition & 0 deletions minikube.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ function minikube::start {
minikube start --kubernetes-version=v1.16.15 --wait=all --wait-timeout=5m
kubectl config use-context minikube
kubectl label node minikube \
topology.kubernetes.io/zone=minikube-zone \
node.kubernetes.io/instance-type=minikube \
platform.neuromation.io/nodepool=minikube-node-pool
}
Expand Down
18 changes: 16 additions & 2 deletions platform_reports/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,14 @@ async def _init_app(app: aiohttp.web.Application) -> AsyncIterator[None]:

kube_client = await exit_stack.enter_async_context(KubeClient(config.kube))
node = await kube_client.get_node(config.node_name)
zone = (
node.metadata.labels.get("failure-domain.beta.kubernetes.io/zone")
or node.metadata.labels.get("topology.kubernetes.io/zone")
or ""
)
app["zone"] = zone
logger.info("Node is in zone %s", zone)

instance_type = (
node.metadata.labels.get("node.kubernetes.io/instance-type")
or node.metadata.labels.get("beta.kubernetes.io/instance-type")
Expand All @@ -380,8 +388,7 @@ async def _init_app(app: aiohttp.web.Application) -> AsyncIterator[None]:
app["instance_type"] = instance_type
logger.info("Node instance type is %s", instance_type)

preemptible = node.metadata.labels.get(config.node_preemptible_label, "")
is_preemptible = preemptible.lower() == "true"
is_preemptible = config.node_preemptible_label in node.metadata.labels
if is_preemptible:
logger.info("Node is preemptible")
else:
Expand All @@ -393,18 +400,25 @@ async def _init_app(app: aiohttp.web.Application) -> AsyncIterator[None]:

if config.cloud_provider == "aws":
assert config.region
assert zone
assert instance_type
session = aiobotocore.get_session()
pricing_client = await exit_stack.enter_async_context(
session.create_client(
"pricing", get_aws_pricing_api_region(config.region)
)
)
ec2_client = await exit_stack.enter_async_context(
session.create_client("ec2", config.region)
)
node_price_collector = await exit_stack.enter_async_context(
AWSNodePriceCollector(
pricing_client=pricing_client,
ec2_client=ec2_client,
region=config.region,
zone=zone,
instance_type=instance_type,
is_spot=is_preemptible,
)
)
elif config.cloud_provider == "gcp":
Expand Down
29 changes: 29 additions & 0 deletions platform_reports/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from importlib.resources import path
from pathlib import Path
from types import TracebackType
Expand Down Expand Up @@ -96,15 +97,21 @@ class AWSNodePriceCollector(Collector[Price]):
def __init__(
self,
pricing_client: AioBaseClient,
ec2_client: AioBaseClient,
region: str,
instance_type: str,
zone: str,
is_spot: bool,
interval_s: float = 3600,
) -> None:
super().__init__(Price(), interval_s)
self._pricing_client = pricing_client
self._ec2_client = ec2_client
self._region = region
self._region_long_name = ""
self._zone = zone
self._instance_type = instance_type
self._is_spot = is_spot

async def __aenter__(self) -> Collector[Price]:
await super().__aenter__()
Expand Down Expand Up @@ -134,6 +141,11 @@ def _get_region_long_names(self) -> Dict[str, str]:
return result

async def get_latest_value(self) -> Price:
if self._is_spot:
return await self._get_latest_spot_price()
return await self._get_latest_on_demand_price()

async def _get_latest_on_demand_price(self) -> Price:
response = await self._pricing_client.get_products(
ServiceCode="AmazonEC2",
FormatVersion="aws_v1",
Expand Down Expand Up @@ -170,6 +182,23 @@ async def get_latest_value(self) -> Price:
def _create_filter(self, field: str, value: str) -> Dict[str, str]:
return {"Type": "TERM_MATCH", "Field": field, "Value": value}

async def _get_latest_spot_price(self) -> Price:
response = await self._ec2_client.describe_spot_price_history(
AvailabilityZone=self._zone,
InstanceTypes=[self._instance_type],
ProductDescriptions=["Linux/UNIX"],
StartTime=datetime.utcnow(),
)
history = response["SpotPriceHistory"]
if len(history) == 0:
logger.warning(
"AWS didn't return spot price history for %s instance in %s zone",
self._instance_type,
self._zone,
)
return Price()
return Price(currency="USD", value=history[0]["SpotPrice"])


class AzureNodePriceCollector(Collector[Price]):
def __init__(
Expand Down
1 change: 1 addition & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ async def _create(metrics_config: MetricsConfig) -> AsyncIterator[URL]:
app=app,
port=metrics_config.server.port,
) as address:
assert app["zone"] == "minikube-zone"
assert app["instance_type"] == "minikube"
assert app["node_pool_name"] == "minikube-node-pool"
yield URL.build(scheme="http", host=address.host, port=address.port)
Expand Down
92 changes: 74 additions & 18 deletions tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import json
from contextlib import contextmanager, suppress
from contextlib import asynccontextmanager, contextmanager, suppress
from pathlib import Path
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
Expand Down Expand Up @@ -79,20 +80,35 @@ def pricing_client(self) -> mock.AsyncMock:
return mock.AsyncMock()

@pytest.fixture
async def price_collector(
self, pricing_client: AioBaseClient
) -> AsyncIterator[AWSNodePriceCollector]:
async with AWSNodePriceCollector(
pricing_client=pricing_client,
region="us-east-1",
instance_type="p2.xlarge",
interval_s=0.1,
) as result:
assert isinstance(result, AWSNodePriceCollector)
yield result
def ec2_client(self) -> mock.AsyncMock:
return mock.AsyncMock()

@pytest.fixture
def collector_factory(
self, pricing_client: AioBaseClient, ec2_client: AioBaseClient
) -> Callable[..., AsyncContextManager[AWSNodePriceCollector]]:
@asynccontextmanager
async def _create(
is_spot: bool = False,
) -> AsyncIterator[AWSNodePriceCollector]:
async with AWSNodePriceCollector(
pricing_client=pricing_client,
ec2_client=ec2_client,
region="us-east-1",
zone="us-east-1a",
instance_type="p2.xlarge",
is_spot=is_spot,
interval_s=0.1,
) as result:
assert isinstance(result, AWSNodePriceCollector)
yield result

return _create

async def test_get_latest_price_per_hour(
self, price_collector: AWSNodePriceCollector, pricing_client: mock.AsyncMock
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
pricing_client: mock.AsyncMock,
) -> None:
pricing_client.get_products.return_value = {
"PriceList": [
Expand All @@ -114,7 +130,8 @@ async def test_get_latest_price_per_hour(
]
}

result = await price_collector.get_latest_value()
async with collector_factory() as collector:
result = await collector.get_latest_value()

pricing_client.get_products.assert_awaited_once_with(
ServiceCode="AmazonEC2",
Expand All @@ -137,7 +154,9 @@ async def test_get_latest_price_per_hour(
assert result == Price(currency="USD", value=0.1)

async def test_get_latest_price_per_hour_with_multiple_prices(
self, price_collector: AWSNodePriceCollector, pricing_client: mock.AsyncMock
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
pricing_client: mock.AsyncMock,
) -> None:
price_item = {
"terms": {
Expand All @@ -156,12 +175,15 @@ async def test_get_latest_price_per_hour_with_multiple_prices(
"PriceList": [json.dumps(price_item), json.dumps(price_item)]
}

result = await price_collector.get_latest_value()
async with collector_factory() as collector:
result = await collector.get_latest_value()

assert result == Price()

async def test_get_latest_price_per_hour_with_unsupported_currency(
self, price_collector: AWSNodePriceCollector, pricing_client: mock.AsyncMock
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
pricing_client: mock.AsyncMock,
) -> None:
pricing_client.get_products.return_value = {
"PriceList": [
Expand All @@ -183,7 +205,41 @@ async def test_get_latest_price_per_hour_with_unsupported_currency(
]
}

result = await price_collector.get_latest_value()
async with collector_factory() as collector:
result = await collector.get_latest_value()

assert result == Price()

async def test_get_latest_spot_price_per_hour(
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
ec2_client: mock.AsyncMock,
) -> None:
ec2_client.describe_spot_price_history.return_value = {
"SpotPriceHistory": [{"SpotPrice": 0.27}]
}

async with collector_factory(is_spot=True) as collector:
result = await collector.get_latest_value()

assert result == Price(currency="USD", value=0.27)

ec2_client.describe_spot_price_history.assert_awaited_once_with(
AvailabilityZone="us-east-1a",
InstanceTypes=["p2.xlarge"],
ProductDescriptions=["Linux/UNIX"],
StartTime=mock.ANY,
)

async def test_get_latest_spot_price_per_hour_no_history(
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
ec2_client: mock.AsyncMock,
) -> None:
ec2_client.describe_spot_price_history.return_value = {"SpotPriceHistory": []}

async with collector_factory(is_spot=True) as collector:
result = await collector.get_latest_value()

assert result == Price()

Expand Down

0 comments on commit f6a9d2a

Please sign in to comment.