Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support aws spot prices #84

Merged
merged 1 commit into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ CURRENT_TIME = $(shell date -u '+%Y-%m-%dT%H:%M:%SZ')

AWS_ACCOUNT_ID ?= 771188043543
AWS_REGION ?= us-east-1

AZURE_RG_NAME ?= dev
AZURE_ACR_NAME ?= crc570d91c95c6aac0ea80afb1019a0c6f

TAG ?= latest
Expand Down
1 change: 1 addition & 0 deletions minikube.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ function minikube::start {
minikube start --kubernetes-version=v1.16.15 --wait=all --wait-timeout=5m
kubectl config use-context minikube
kubectl label node minikube \
topology.kubernetes.io/zone=minikube-zone \
node.kubernetes.io/instance-type=minikube \
platform.neuromation.io/nodepool=minikube-node-pool
}
Expand Down
18 changes: 16 additions & 2 deletions platform_reports/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,14 @@ async def _init_app(app: aiohttp.web.Application) -> AsyncIterator[None]:

kube_client = await exit_stack.enter_async_context(KubeClient(config.kube))
node = await kube_client.get_node(config.node_name)
zone = (
node.metadata.labels.get("failure-domain.beta.kubernetes.io/zone")
or node.metadata.labels.get("topology.kubernetes.io/zone")
or ""
)
app["zone"] = zone
logger.info("Node is in zone %s", zone)

instance_type = (
node.metadata.labels.get("node.kubernetes.io/instance-type")
or node.metadata.labels.get("beta.kubernetes.io/instance-type")
Expand All @@ -380,8 +388,7 @@ async def _init_app(app: aiohttp.web.Application) -> AsyncIterator[None]:
app["instance_type"] = instance_type
logger.info("Node instance type is %s", instance_type)

preemptible = node.metadata.labels.get(config.node_preemptible_label, "")
is_preemptible = preemptible.lower() == "true"
is_preemptible = config.node_preemptible_label in node.metadata.labels
if is_preemptible:
logger.info("Node is preemptible")
else:
Expand All @@ -393,18 +400,25 @@ async def _init_app(app: aiohttp.web.Application) -> AsyncIterator[None]:

if config.cloud_provider == "aws":
assert config.region
assert zone
assert instance_type
session = aiobotocore.get_session()
pricing_client = await exit_stack.enter_async_context(
session.create_client(
"pricing", get_aws_pricing_api_region(config.region)
)
)
ec2_client = await exit_stack.enter_async_context(
session.create_client("ec2", config.region)
)
node_price_collector = await exit_stack.enter_async_context(
AWSNodePriceCollector(
pricing_client=pricing_client,
ec2_client=ec2_client,
region=config.region,
zone=zone,
instance_type=instance_type,
is_spot=is_preemptible,
)
)
elif config.cloud_provider == "gcp":
Expand Down
29 changes: 29 additions & 0 deletions platform_reports/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from importlib.resources import path
from pathlib import Path
from types import TracebackType
Expand Down Expand Up @@ -93,15 +94,21 @@ class AWSNodePriceCollector(Collector[Price]):
def __init__(
self,
pricing_client: AioBaseClient,
ec2_client: AioBaseClient,
region: str,
instance_type: str,
zone: str,
is_spot: bool,
interval_s: float = 3600,
) -> None:
super().__init__(Price(), interval_s)
self._pricing_client = pricing_client
self._ec2_client = ec2_client
self._region = region
self._region_long_name = ""
self._zone = zone
self._instance_type = instance_type
self._is_spot = is_spot

async def __aenter__(self) -> Collector[Price]:
await super().__aenter__()
Expand Down Expand Up @@ -131,6 +138,11 @@ def _get_region_long_names(self) -> Dict[str, str]:
return result

async def get_latest_value(self) -> Price:
if self._is_spot:
return await self._get_latest_spot_price()
return await self._get_latest_on_demand_price()

async def _get_latest_on_demand_price(self) -> Price:
response = await self._pricing_client.get_products(
ServiceCode="AmazonEC2",
FormatVersion="aws_v1",
Expand Down Expand Up @@ -167,6 +179,23 @@ async def get_latest_value(self) -> Price:
def _create_filter(self, field: str, value: str) -> Dict[str, str]:
return {"Type": "TERM_MATCH", "Field": field, "Value": value}

async def _get_latest_spot_price(self) -> Price:
response = await self._ec2_client.describe_spot_price_history(
AvailabilityZone=self._zone,
InstanceTypes=[self._instance_type],
ProductDescriptions=["Linux/UNIX"],
StartTime=datetime.utcnow(),
)
history = response["SpotPriceHistory"]
if len(history) == 0:
logger.warning(
"AWS didn't return spot price history for %s instance in %s zone",
self._instance_type,
self._zone,
)
return Price()
return Price(currency="USD", value=history[0]["SpotPrice"])


class AzureNodePriceCollector(Collector[Price]):
def __init__(self, instance_type: str, interval_s: float = 3600) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ async def _create(metrics_config: MetricsConfig) -> AsyncIterator[URL]:
app=app,
port=metrics_config.server.port,
) as address:
assert app["zone"] == "minikube-zone"
assert app["instance_type"] == "minikube"
assert app["node_pool_name"] == "minikube-node-pool"
yield URL.build(scheme="http", host=address.host, port=address.port)
Expand Down
92 changes: 74 additions & 18 deletions tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import json
from contextlib import contextmanager, suppress
from contextlib import asynccontextmanager, contextmanager, suppress
from pathlib import Path
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Callable,
ContextManager,
Expand Down Expand Up @@ -75,20 +76,35 @@ def pricing_client(self) -> mock.AsyncMock:
return mock.AsyncMock()

@pytest.fixture
async def price_collector(
self, pricing_client: AioBaseClient
) -> AsyncIterator[AWSNodePriceCollector]:
async with AWSNodePriceCollector(
pricing_client=pricing_client,
region="us-east-1",
instance_type="p2.xlarge",
interval_s=0.1,
) as result:
assert isinstance(result, AWSNodePriceCollector)
yield result
def ec2_client(self) -> mock.AsyncMock:
return mock.AsyncMock()

@pytest.fixture
def collector_factory(
self, pricing_client: AioBaseClient, ec2_client: AioBaseClient
) -> Callable[..., AsyncContextManager[AWSNodePriceCollector]]:
@asynccontextmanager
async def _create(
is_spot: bool = False,
) -> AsyncIterator[AWSNodePriceCollector]:
async with AWSNodePriceCollector(
pricing_client=pricing_client,
ec2_client=ec2_client,
region="us-east-1",
zone="us-east-1a",
instance_type="p2.xlarge",
is_spot=is_spot,
interval_s=0.1,
) as result:
assert isinstance(result, AWSNodePriceCollector)
yield result

return _create

async def test_get_latest_price_per_hour(
self, price_collector: AWSNodePriceCollector, pricing_client: mock.AsyncMock
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
pricing_client: mock.AsyncMock,
) -> None:
pricing_client.get_products.return_value = {
"PriceList": [
Expand All @@ -110,7 +126,8 @@ async def test_get_latest_price_per_hour(
]
}

result = await price_collector.get_latest_value()
async with collector_factory() as collector:
result = await collector.get_latest_value()

pricing_client.get_products.assert_awaited_once_with(
ServiceCode="AmazonEC2",
Expand All @@ -133,7 +150,9 @@ async def test_get_latest_price_per_hour(
assert result == Price(currency="USD", value=0.1)

async def test_get_latest_price_per_hour_with_multiple_prices(
self, price_collector: AWSNodePriceCollector, pricing_client: mock.AsyncMock
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
pricing_client: mock.AsyncMock,
) -> None:
price_item = {
"terms": {
Expand All @@ -152,12 +171,15 @@ async def test_get_latest_price_per_hour_with_multiple_prices(
"PriceList": [json.dumps(price_item), json.dumps(price_item)]
}

result = await price_collector.get_latest_value()
async with collector_factory() as collector:
result = await collector.get_latest_value()

assert result == Price()

async def test_get_latest_price_per_hour_with_unsupported_currency(
self, price_collector: AWSNodePriceCollector, pricing_client: mock.AsyncMock
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
pricing_client: mock.AsyncMock,
) -> None:
pricing_client.get_products.return_value = {
"PriceList": [
Expand All @@ -179,7 +201,41 @@ async def test_get_latest_price_per_hour_with_unsupported_currency(
]
}

result = await price_collector.get_latest_value()
async with collector_factory() as collector:
result = await collector.get_latest_value()

assert result == Price()

async def test_get_latest_spot_price_per_hour(
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
ec2_client: mock.AsyncMock,
) -> None:
ec2_client.describe_spot_price_history.return_value = {
"SpotPriceHistory": [{"SpotPrice": 0.27}]
}

async with collector_factory(is_spot=True) as collector:
result = await collector.get_latest_value()

assert result == Price(currency="USD", value=0.27)

ec2_client.describe_spot_price_history.assert_awaited_once_with(
AvailabilityZone="us-east-1a",
InstanceTypes=["p2.xlarge"],
ProductDescriptions=["Linux/UNIX"],
StartTime=mock.ANY,
)

async def test_get_latest_spot_price_per_hour_no_history(
self,
collector_factory: Callable[..., AsyncContextManager[AWSNodePriceCollector]],
ec2_client: mock.AsyncMock,
) -> None:
ec2_client.describe_spot_price_history.return_value = {"SpotPriceHistory": []}

async with collector_factory(is_spot=True) as collector:
result = await collector.get_latest_value()

assert result == Price()

Expand Down