diff --git a/prereise/gather/__init__.py b/prereise/gather/__init__.py
index 30dda1d94..f520c45d7 100644
--- a/prereise/gather/__init__.py
+++ b/prereise/gather/__init__.py
@@ -5,4 +5,5 @@
"hydrodata",
"get_monthly_net_generation",
"trim_eia_form_923",
+ "request_util",
]
diff --git a/prereise/gather/request_util.py b/prereise/gather/request_util.py
new file mode 100644
index 000000000..0b11d0485
--- /dev/null
+++ b/prereise/gather/request_util.py
@@ -0,0 +1,61 @@
+import functools
+import time
+from datetime import timedelta
+from urllib.error import HTTPError
+
+
class RateLimit:
    """Provides a way to call an arbitrary function at most once per interval.

    :param int/float interval: the amount of time in seconds to wait between actions
    """

    def __init__(self, interval=None):
        """Constructor"""
        self.interval = interval
        # Seed the timestamp one full interval in the past so the very first
        # invocation never has to wait.
        self.last_run_at = None if interval is None else time.time() - interval

    def invoke(self, action):
        """Call the action and return its value, waiting if necessary

        :param callable action: the thing to do
        :return: (*Any*) -- the return value of the action
        """
        if self.interval is None:
            return action()
        elapsed = time.time() - self.last_run_at
        if elapsed < self.interval:
            time.sleep(self.interval - elapsed)
        result = action()
        # Timestamp is advanced only after the action completes; if the action
        # raises, the next invocation is not delayed further.
        self.last_run_at = time.time()
        return result


def retry(_func=None, retry_count=5, interval=None, allowed_exceptions=(HTTPError,)):
    """Creates a decorator to handle retry logic.

    :param callable _func: the decorated function, bound implicitly when the
        decorator is used without arguments, i.e. ``@retry``
    :param int retry_count: the max number of attempts
    :param int/float interval: minimum spacing between retries
    :param tuple allowed_exceptions: exceptions for which the function will be
        retried, all others will be surfaced to the caller
    :return: (*callable*) -- the decorator (or the already decorated function
        when used as bare ``@retry``); the wrapped function returns the
        original return value, or None when every attempt failed
    """

    def decorator(func):
        limiter = RateLimit(interval)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for _ in range(retry_count):
                try:
                    return limiter.invoke(lambda: func(*args, **kwargs))
                except allowed_exceptions as e:
                    print(str(e))
            print("Max retries reached!!")

        return wrapper

    if _func is None:
        return decorator
    return decorator(_func)
diff --git a/prereise/gather/solardata/nsrdb/__init__.py b/prereise/gather/solardata/nsrdb/__init__.py
index dcffbf207..3d1c5013d 100644
--- a/prereise/gather/solardata/nsrdb/__init__.py
+++ b/prereise/gather/solardata/nsrdb/__init__.py
@@ -1 +1 @@
-__all__ = ["naive", "sam"]
+__all__ = ["naive", "sam", "nrel_api"]
diff --git a/prereise/gather/solardata/nsrdb/demo/nsrdb_sam_demo.ipynb b/prereise/gather/solardata/nsrdb/demo/nsrdb_sam_demo.ipynb
index 6534e4edb..8f6341aa1 100644
--- a/prereise/gather/solardata/nsrdb/demo/nsrdb_sam_demo.ipynb
+++ b/prereise/gather/solardata/nsrdb/demo/nsrdb_sam_demo.ipynb
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -46,7 +46,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -71,7 +71,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -296,7 +296,7 @@
"[5 rows x 35 columns]"
]
},
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -307,7 +307,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -332,17 +332,19 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "email = input(prompt='email=')\n",
+ "key = getpass(prompt='api_key=')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
"metadata": {},
"outputs": [
- {
- "name": "stdin",
- "output_type": "stream",
- "text": [
- "email= jon.hagg@breakthroughenergy.org\n",
- "api_key= ········································\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
@@ -359,19 +361,255 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|██████████| 341/341 [17:25<00:00, 3.86s/it]\n"
+ " 2%|▏ | 8/341 [00:23<15:54, 2.87s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 4%|▍ | 13/341 [00:37<15:16, 2.80s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 10%|█ | 35/341 [01:37<15:14, 2.99s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 11%|█ | 38/341 [01:44<13:23, 2.65s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 14%|█▍ | 47/341 [02:10<14:39, 2.99s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 16%|█▌ | 53/341 [02:26<13:05, 2.73s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 23%|██▎ | 80/341 [03:56<17:05, 3.93s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 25%|██▍ | 85/341 [04:12<14:51, 3.48s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 50%|████▉ | 169/341 [08:32<08:21, 2.92s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 59%|█████▉ | 202/341 [10:08<06:33, 2.83s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 75%|███████▍ | 255/341 [12:33<03:43, 2.60s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 80%|███████▉ | 272/341 [13:17<02:58, 2.59s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 81%|████████ | 276/341 [13:28<02:53, 2.68s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 88%|████████▊ | 299/341 [14:27<01:46, 2.54s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 89%|████████▉ | 303/341 [14:38<01:44, 2.76s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 94%|█████████▍| 321/341 [15:27<00:51, 2.58s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 99%|█████████▉| 337/341 [16:09<00:10, 2.56s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "HTTP Error 429: Too Many Requests\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 341/341 [16:20<00:00, 2.73s/it]\n"
]
}
],
"source": [
- "email = input(prompt='email=')\n",
- "key = getpass(prompt='api_key=')\n",
- "data = sam.retrieve_data(solar_plant, email, key)"
+ "data = sam.retrieve_data(solar_plant, email, key, rate_limit=.5)"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -570,7 +808,7 @@
"19 0.0 493 2016-01-01 1"
]
},
- "execution_count": 6,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -581,7 +819,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -675,7 +913,7 @@
"max 4.277410e+02 1.372600e+04 8.784000e+03"
]
},
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -694,7 +932,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -712,7 +950,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -721,7 +959,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -748,7 +986,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -775,7 +1013,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -800,7 +1038,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -809,7 +1047,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -1042,7 +1280,7 @@
"[5 rows x 670 columns]"
]
},
- "execution_count": 14,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -1068,7 +1306,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.5"
+ "version": "3.8.3"
}
},
"nbformat": 4,
diff --git a/prereise/gather/solardata/nsrdb/naive.py b/prereise/gather/solardata/nsrdb/naive.py
index 5607619b1..b0a4f486d 100644
--- a/prereise/gather/solardata/nsrdb/naive.py
+++ b/prereise/gather/solardata/nsrdb/naive.py
@@ -3,6 +3,7 @@
from tqdm import tqdm
from prereise.gather.solardata.helpers import get_plant_info_unique_location
+from prereise.gather.solardata.nsrdb.nrel_api import NrelApi
def retrieve_data(solar_plant, email, api_key, year="2016"):
@@ -22,24 +23,15 @@ def retrieve_data(solar_plant, email, api_key, year="2016"):
# Identify unique location
coord = get_plant_info_unique_location(solar_plant)
- base_url = "https://developer.nrel.gov/api/solar/nsrdb_psm3_download.csv"
- payload = {
- "api_key": api_key,
- "names": year,
- "leap_day": "true",
- "interval": "60",
- "utc": "true",
- "email": email,
- "attributes": "ghi",
- }
- qs = "&".join([f"{key}={value}" for key, value in payload.items()])
- url = f"{base_url}?{qs}"
+ api = NrelApi(email, api_key)
data = pd.DataFrame({"Pout": [], "plant_id": [], "ts": [], "ts_id": []})
for key in tqdm(coord.keys(), total=len(coord)):
- query = "wkt=POINT({lon}%20{lat})".format(lon=key[0], lat=key[1])
- data_loc = pd.read_csv(f"{url}&{query}", skiprows=2)
+ lat, lon = key[1], key[0]
+ data_loc = api.get_psm3_at(
+ lat, lon, attributes="ghi", year=year, leap_day=True
+ ).data_resource
ghi = data_loc.GHI.values
data_loc = pd.DataFrame({"Pout": ghi})
data_loc["Pout"] /= max(ghi)
diff --git a/prereise/gather/solardata/nsrdb/nrel_api.py b/prereise/gather/solardata/nsrdb/nrel_api.py
new file mode 100644
index 000000000..7c1a9fe83
--- /dev/null
+++ b/prereise/gather/solardata/nsrdb/nrel_api.py
@@ -0,0 +1,143 @@
+from dataclasses import dataclass
+from datetime import timedelta
+
+import pandas as pd
+
+from prereise.gather.request_util import RateLimit, retry
+
+
@dataclass
class Psm3Data:
    """Wrapper class for PSM3 data retrieved from NREL's API. Contains metadata
    from the first csv row and a data frame representing the remaining time series
    """

    lat: float
    lon: float
    tz: float
    elevation: float
    data_resource: pd.DataFrame

    # Maps nrel-pysam input names to the corresponding csv column headers.
    # NOTE(review): callers pass NSRDB query attributes such as "wind_speed"
    # and "air_temperature" to check_attrs, which are not keys of this dict --
    # confirm which vocabulary this validation is meant to cover.
    allowed_attrs = {
        "dn": "DNI",
        "df": "DHI",
        "wspd": "Wind Speed",
        "tdry": "Temperature",
        "ghi": "GHI",
    }

    @staticmethod
    def check_attrs(attributes):
        """Validate that every requested attribute is supported.

        :param str attributes: comma separated list of attributes to query
        :raises ValueError: if any attribute is not supported
        """
        for a in attributes.split(","):
            # Bug fix: allowed_attrs is a class attribute and is not visible
            # as a bare name inside a staticmethod; qualify with the class.
            if a not in Psm3Data.allowed_attrs:
                raise ValueError(f"Unsupported attribute: {a}")

    def to_dict(self):
        """Convert the data to the format expected by nrel-pysam for running
        SAM simulations

        :return: (*dict*) -- a dictionary which can be passed to the pvwattsv7
            module
        """
        result = {
            "lat": self.lat,
            "lon": self.lon,
            "tz": self.tz,
            "elev": self.elevation,
            "year": self.data_resource.index.year.tolist(),
            "month": self.data_resource.index.month.tolist(),
            "day": self.data_resource.index.day.tolist(),
            "hour": self.data_resource.index.hour.tolist(),
            "minute": self.data_resource.index.minute.tolist(),
        }
        # Bug fix: allowed_attrs was referenced as a bare name here, which
        # raises NameError at runtime; access it through the instance.
        result.update(
            {
                k: self.data_resource[v].tolist()
                for k, v in self.allowed_attrs.items()
                if v in self.data_resource.columns
            }
        )
        return result
+
+
class NrelApi:
    """Provides an interface to the NREL API for PSM3 data. It supports
    downloading this data in csv format, which we use to calculate solar output
    of a set of plants. The user will need to provide an API key.

    :param str email: email used for API key
        `sign up <https://developer.nrel.gov/signup/>`_.
    :param str api_key: API key.
    :param int/float rate_limit: minimum seconds to wait between requests to NREL
    """

    def __init__(self, email, api_key, rate_limit=None):
        """Constructor"""
        if email is None:
            raise ValueError("Email is required")
        if api_key is None:
            raise ValueError("API key is required")

        self.email = email
        self.api_key = api_key
        self.interval = rate_limit
        # Bug fix: build the retrying downloader once so its internal
        # RateLimit instance is shared by every request made through this
        # object. Previously fresh @retry decorators (and thus fresh, already
        # expired rate limiters) were created inside each get_psm3_at call,
        # so successive requests were never actually spaced apart.
        self._download = retry(interval=rate_limit)(pd.read_csv)

    def _build_url(self, lat, lon, attributes, year="2016", leap_day=False):
        """Construct url with formatted query string for downloading psm3
        (physical solar model) data

        :param str lat: latitude of the plant
        :param str lon: longitude of the plant
        :param str attributes: comma separated list of attributes to query
        :param str year: the year
        :param bool leap_day: whether to use a leap day
        :return: (*str*) -- the url to download csv data
        """
        base_url = "https://developer.nrel.gov/api/solar/nsrdb_psm3_download.csv"
        payload = {
            "api_key": self.api_key,
            "names": year,
            "leap_day": str(leap_day).lower(),
            "interval": "60",
            "utc": "true",
            "email": self.email,
            "attributes": attributes,
            # %20 is a pre-encoded space inside the WKT point literal
            "wkt": f"POINT({lon}%20{lat})",
        }
        query = "&".join([f"{key}={value}" for key, value in payload.items()])
        return f"{base_url}?{query}"

    def get_psm3_at(self, lat, lon, attributes, year, leap_day, dates=None):
        """Get PSM3 data at a given point for the specified year.

        :param str lat: latitude of the plant
        :param str lon: longitude of the plant
        :param str attributes: comma separated list of attributes to query
        :param str year: the year
        :param bool leap_day: whether to use a leap day
        :param pd.DatetimeIndex dates: if provided, use to index the downloaded
            data frame

        :return: (*prereise.gather.solardata.nsrdb.nrel_api.Psm3Data*) -- a data
            class containing metadata and time series for the given year and
            location
        """
        Psm3Data.check_attrs(attributes)
        url = self._build_url(lat, lon, attributes, year, leap_day)

        # NOTE(review): retry returns None once retries are exhausted, in
        # which case the indexing below raises TypeError rather than a
        # descriptive error -- confirm the desired failure mode.
        info = self._download(url, nrows=1)
        tz, elevation = info["Local Time Zone"], info["Elevation"]

        data_resource = self._download(url, dtype=float, skiprows=2)

        if dates is not None:
            # Offset the caller supplied dates by the site's UTC offset
            # before using them as the index.
            data_resource.set_index(
                dates + timedelta(hours=int(tz.values[0])), inplace=True
            )
        return Psm3Data(
            float(lat), float(lon), float(tz), float(elevation), data_resource
        )
diff --git a/prereise/gather/solardata/nsrdb/sam.py b/prereise/gather/solardata/nsrdb/sam.py
index 1d329f13d..7461226bf 100644
--- a/prereise/gather/solardata/nsrdb/sam.py
+++ b/prereise/gather/solardata/nsrdb/sam.py
@@ -1,5 +1,3 @@
-from datetime import timedelta
-
import numpy as np
import pandas as pd
import PySAM.Pvwattsv7 as PVWatts
@@ -12,13 +10,14 @@
from tqdm import tqdm
from prereise.gather.solardata.helpers import get_plant_info_unique_location
+from prereise.gather.solardata.nsrdb.nrel_api import NrelApi
from prereise.gather.solardata.pv_tracking import (
get_pv_tracking_data,
get_pv_tracking_ratio_state,
)
-def retrieve_data(solar_plant, email, api_key, year="2016"):
+def retrieve_data(solar_plant, email, api_key, year="2016", rate_limit=0.5):
"""Retrieves irradiance data from NSRDB and calculate the power output using
the System Adviser Model (SAM).
@@ -28,6 +27,7 @@ def retrieve_data(solar_plant, email, api_key, year="2016"):
`sign up `_.
:param str api_key: API key.
:param str year: year.
+ :param int/float rate_limit: minimum seconds to wait between requests to NREL
:return: (*pandas.DataFrame*) -- data frame with *'Pout'*, *'plant_id'*,
*'ts'* and *'ts_id'* as columns. The power output is in MWh.
"""
@@ -46,19 +46,6 @@ def retrieve_data(solar_plant, email, api_key, year="2016"):
# Identify unique location
coord = get_plant_info_unique_location(solar_plant)
- base_url = "https://developer.nrel.gov/api/solar/nsrdb_psm3_download.csv"
- payload = {
- "api_key": api_key,
- "names": year,
- "leap_day": "false",
- "interval": "60",
- "utc": "true",
- "email": email,
- "attributes": "dhi,dni,wind_speed,air_temperature",
- }
- qs = "&".join([f"{key}={value}" for key, value in payload.items()])
- url = f"{base_url}?{qs}"
-
data = pd.DataFrame({"Pout": [], "plant_id": [], "ts": [], "ts_id": []})
# PV tracking ratios
@@ -77,37 +64,21 @@ def retrieve_data(solar_plant, email, api_key, year="2016"):
# Inverter Loading Ratio
ilr = 1.25
+ api = NrelApi(email, api_key, rate_limit)
for key in tqdm(coord.keys(), total=len(coord)):
- query = "wkt=POINT({lon}%20{lat})".format(lon=key[0], lat=key[1])
- current_url = f"{url}&{query}"
-
- info = pd.read_csv(current_url, nrows=1)
- tz, elevation = info["Local Time Zone"], info["Elevation"]
-
- data_resource = pd.read_csv(current_url, dtype=float, skiprows=2)
- data_resource.set_index(
- dates + timedelta(hours=int(tz.values[0])), inplace=True
- )
+ lat, lon = key[1], key[0]
+ solar_data = api.get_psm3_at(
+ lat,
+ lon,
+ attributes="dhi,dni,wind_speed,air_temperature",
+ year=year,
+ leap_day=False,
+ dates=dates,
+ ).to_dict()
- # SAM
ssc = pssc.PySSC()
- solar_data = {
- "lat": float(key[1]),
- "lon": float(key[0]),
- "tz": float(tz),
- "elev": float(elevation),
- "year": data_resource.index.year.tolist(),
- "month": data_resource.index.month.tolist(),
- "day": data_resource.index.day.tolist(),
- "hour": data_resource.index.hour.tolist(),
- "minute": data_resource.index.minute.tolist(),
- "dn": data_resource["DNI"].tolist(),
- "df": data_resource["DHI"].tolist(),
- "wspd": data_resource["Wind Speed"].tolist(),
- "tdry": data_resource["Temperature"].tolist(),
- }
for i in coord[key]:
data_site = pd.DataFrame(
{
diff --git a/prereise/gather/tests/__init__.py b/prereise/gather/tests/__init__.py
index f99070801..6352de78c 100644
--- a/prereise/gather/tests/__init__.py
+++ b/prereise/gather/tests/__init__.py
@@ -1 +1,6 @@
-__all__ = ["mock_generation_data_frame", "test_get_monthly_net_generation"]
+__all__ = [
+ "mock_generation_data_frame",
+ "test_get_monthly_net_generation",
+ "test_rate_limit",
+ "test_retry",
+]
diff --git a/prereise/gather/tests/test_rate_limit.py b/prereise/gather/tests/test_rate_limit.py
new file mode 100644
index 000000000..6421dda23
--- /dev/null
+++ b/prereise/gather/tests/test_rate_limit.py
@@ -0,0 +1,37 @@
+import time
+
+import pytest
+
+from prereise.gather.request_util import RateLimit
+
+
class SleepCounter:
    """Fake clock used to monkeypatch the time module: sleeping is recorded
    instead of actually blocking, and the reported time advances accordingly.
    """

    def __init__(self):
        """Constructor"""
        # total seconds "slept" so far
        self.time_sleeping = 0
        # real wall clock time at creation, used as the simulated epoch
        self.init_time = time.time()

    def time(self):
        """Drop-in replacement for time.time.

        :return: (*float*) -- creation time advanced by all recorded sleeps
        """
        return self.init_time + self.time_sleeping

    def sleep(self, seconds):
        """Drop-in replacement for time.sleep; records instead of blocking.

        :param int/float seconds: the amount of simulated sleep to record
        """
        self.time_sleeping += seconds
+
+
@pytest.fixture
def sleepless(monkeypatch):
    """Patch time.time and time.sleep with a SleepCounter based fake so tests
    can verify sleeping behavior without actually waiting.

    :param pytest.MonkeyPatch monkeypatch: pytest's monkeypatching fixture
    :return: (*SleepCounter*) -- the counter backing the patched functions
    """
    counter = SleepCounter()
    monkeypatch.setattr(time, "sleep", counter.sleep)
    monkeypatch.setattr(time, "time", counter.time)
    return counter
+
+
def test_default_no_limit(sleepless):
    """With no interval configured, invoke must never sleep."""
    limiter = RateLimit()
    for _ in range(10):
        limiter.invoke(lambda: "foo")
    assert sleepless.time_sleeping == 0
+
+
def test_sleep_occurrs(sleepless):
    """Repeated invocations must be spaced out by the configured interval."""
    limiter = RateLimit(24)
    for _ in range(10):
        limiter.invoke(lambda: "foo")
    # the first invocation runs immediately, so only 9 waits are expected
    assert sleepless.time_sleeping >= 240 - 24
diff --git a/prereise/gather/tests/test_retry.py b/prereise/gather/tests/test_retry.py
new file mode 100644
index 000000000..8cfd36211
--- /dev/null
+++ b/prereise/gather/tests/test_retry.py
@@ -0,0 +1,43 @@
+import pytest
+
+from prereise.gather.request_util import retry
+
+
class CustomException(Exception):
    """Sentinel exception raised by the tests below to drive retry logic."""

    pass
+
+
def test_max_times_reached():
    """The decorated function is attempted exactly retry_count times."""
    attempts = []

    @retry(retry_count=8, allowed_exceptions=CustomException)
    def always_raises(log):
        log.append(len(log))
        raise CustomException()

    always_raises(attempts)
    assert len(attempts) == 8
+
+
def test_return_value():
    """A successful call passes its return value through the wrapper."""

    @retry()
    def return_something():
        return 42

    result = return_something()
    assert result == 42
+
+
def test_decorate_without_call():
    """Bare @retry (no parentheses) also produces a working wrapper."""

    @retry
    def still_works():
        return 42

    assert still_works() == 42
+
+
def test_unhandled_exception():
    """Exceptions outside allowed_exceptions propagate instead of retrying."""
    @retry()
    def fail():
        raise Exception()

    with pytest.raises(Exception):
        fail()