diff --git a/conftest.py b/conftest.py index b3c44f06..e5c24250 100644 --- a/conftest.py +++ b/conftest.py @@ -36,8 +36,8 @@ def sat_data_source(sat_filename: Path): return SatelliteDataSource( image_size_pixels=pytest.IMAGE_SIZE_PIXELS, filename=sat_filename, - history_len=0, - forecast_len=1, + history_minutes=0, + forecast_minutes=5, channels=('HRV', ), n_timesteps_per_batch=2, convert_to_numpy=True) diff --git a/notebooks/2021-08-16/Sun Angles.ipynb b/notebooks/2021-08/2021-08-16/Sun Angles.ipynb similarity index 100% rename from notebooks/2021-08-16/Sun Angles.ipynb rename to notebooks/2021-08/2021-08-16/Sun Angles.ipynb diff --git a/notebooks/2021-08-16/angles.png b/notebooks/2021-08/2021-08-16/angles.png similarity index 100% rename from notebooks/2021-08-16/angles.png rename to notebooks/2021-08/2021-08-16/angles.png diff --git a/notebooks/2021-08-20/data_exploration.ipynb b/notebooks/2021-08/2021-08-20/data_exploration.ipynb similarity index 100% rename from notebooks/2021-08-20/data_exploration.ipynb rename to notebooks/2021-08/2021-08-20/data_exploration.ipynb diff --git a/notebooks/2021-08-20/data_exploration2.ipynb b/notebooks/2021-08/2021-08-20/data_exploration2.ipynb similarity index 100% rename from notebooks/2021-08-20/data_exploration2.ipynb rename to notebooks/2021-08/2021-08-20/data_exploration2.ipynb diff --git a/notebooks/2021-08-20/staticmap.py b/notebooks/2021-08/2021-08-20/staticmap.py similarity index 100% rename from notebooks/2021-08-20/staticmap.py rename to notebooks/2021-08/2021-08-20/staticmap.py diff --git a/notebooks/2021-08-25/video.py b/notebooks/2021-08/2021-08-25/video.py similarity index 99% rename from notebooks/2021-08-25/video.py rename to notebooks/2021-08/2021-08-25/video.py index 26e05c78..51544767 100644 --- a/notebooks/2021-08-25/video.py +++ b/notebooks/2021-08/2021-08-25/video.py @@ -8,7 +8,7 @@ import cv2 DATA_PATH = "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v4/" -TEMP_PATH = "." +TEMP_PATH = "" # set up data generator train_dataset = NetCDFDataset(24_900, os.path.join(DATA_PATH, "train"), os.path.join(TEMP_PATH, "train")) diff --git a/notebooks/2021-08-26/video.py b/notebooks/2021-08/2021-08-26/video.py similarity index 99% rename from notebooks/2021-08-26/video.py rename to notebooks/2021-08/2021-08-26/video.py index e49f29e5..2da61ce0 100644 --- a/notebooks/2021-08-26/video.py +++ b/notebooks/2021-08/2021-08-26/video.py @@ -19,7 +19,7 @@ ############## DATA_PATH = "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v4/" -TEMP_PATH = "." 
+TEMP_PATH = "" # set up data generator train_dataset = NetCDFDataset(24_900, os.path.join(DATA_PATH, "train"), os.path.join(TEMP_PATH, "train")) diff --git a/notebooks/2021-09/2021-09-01/GSP data.ipynb b/notebooks/2021-09/2021-09-01/GSP data.ipynb new file mode 100644 index 00000000..9ab5d253 --- /dev/null +++ b/notebooks/2021-09/2021-09-01/GSP data.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3407b6ba", + "metadata": {}, + "outputs": [], + "source": [ + "# get data\n", + "import urllib\n", + "import json\n", + "url = 'https://data.nationalgrideso.com/api/3/action/datastore_search?resource_id=bbe2cc72-a6c6-46e6-8f4e-48b879467368&limit=400'\n", + "fileobj = urllib.request.urlopen(url)\n", + "d = json.loads(fileobj.read())\n", + "\n", + "print(d)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a58b2c7", + "metadata": {}, + "outputs": [], + "source": [ + "# load data\n", + "import pandas as pd\n", + "results = d['result']['records']\n", + "\n", + "data_df = pd.DataFrame(results)\n", + "print(len(data_df))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb615160", + "metadata": {}, + "outputs": [], + "source": [ + "# plot data\n", + "\n", + "import plotly.graph_objects as go\n", + "\n", + "fig = go.Figure(data=go.Scatter(x=data_df['gsp_lon'], y=data_df['gsp_lat'], mode='markers'))\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ddfc9698", + "metadata": {}, + "outputs": [], + "source": [ + "# plot on static map\n", + "\n", + "import staticmaps\n", + "from staticmaps.marker import Marker\n", + "\n", + "\n", + "context = staticmaps.Context()\n", + "\n", + "# make bottom left and top right bounding box for map, happens to be Oxford and Norwich\n", + "bottom_left = staticmaps.create_latlng(50, -8)\n", + "top_right = staticmaps.create_latlng(59, 3)\n", + "\n", + "for i in range(len(data_df)-4):\n", + " row = data_df.iloc[i]\n", + " context.add_object(Marker(staticmaps.create_latlng(row.gnode_lat, row.gnode_lon), size=3))\n", + " \n", + " \n", + "# make clean map\n", + "m = context.make_clean_map_from_bounding_box(bottom_left=bottom_left, top_right=top_right, width=1000, height=1000)\n", + "\n", + "m.show()\n", + "m.save('GSP.png')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5e08e2d", + "metadata": {}, + "outputs": [], + "source": [ + "# plot using plotly\n", + "import numpy as np\n", + "import plotly.graph_objects as go\n", + "\n", + "m = np.array(m)\n", + "print(m.shape)\n", + "\n", + "x = (data_df['gnode_lon'] +8) / 11 * m.shape[1]\n", + "y = (59-data_df['gnode_lat']) / 9 * m.shape[0]\n", + "\n", + "trace_map = go.Image(z=m)\n", + "\n", + "layout = go.Layout(\n", + " paper_bgcolor='rgba(0,0,0,0)',\n", + " plot_bgcolor='rgba(0,0,0,0)'\n", + ")\n", + "\n", + "fig = go.Figure(data=[trace_map, go.Scatter(x=x, y=y, mode='markers', text=data_df['gnode_name'], \n", + " marker=dict(color='LightSkyBlue'))],\n", + " layout=layout)\n", + "\n", + "fig.update_yaxes(showticklabels=False)\n", + "fig.update_xaxes(showticklabels=False)\n", + "\n", + "\n", + "\n", + "fig.show()\n", + "\n", + "### This doesnt quite work, due to curitual of the world\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4bffdb2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2548a7da", + "metadata": {}, + "outputs": [], + "source": [] + }, + 
{ + "cell_type": "code", + "execution_count": null, + "id": "3ad608b5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9087e598", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/2021-09/2021-09-01/GSP solar data.ipynb b/notebooks/2021-09/2021-09-01/GSP solar data.ipynb new file mode 100644 index 00000000..f3efadac --- /dev/null +++ b/notebooks/2021-09/2021-09-01/GSP solar data.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6e5860ea", + "metadata": {}, + "outputs": [], + "source": [ + "# get gsp metadata\n", + "import urllib\n", + "import json\n", + "import pandas as pd\n", + "\n", + "from pvlive_api import PVLive\n", + "from datetime import datetime, timedelta\n", + "import pytz\n", + "import plotly.graph_objects as go\n", + "\n", + "# call ESO website\n", + "url = 'https://data.nationalgrideso.com/api/3/action/datastore_search?resource_id=bbe2cc72-a6c6-46e6-8f4e-48b879467368&limit=400'\n", + "fileobj = urllib.request.urlopen(url)\n", + "d = json.loads(fileobj.read())\n", + "\n", + "# make dataframe\n", + "results = d['result']['records']\n", + "metadata_df = pd.DataFrame(results)\n", + "\n", + "print(metadata_df)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2ab08c5", + "metadata": {}, + "outputs": [], + "source": [ + "# plot one day of data\n", + "# https://github.com/SheffieldSolar/PV_Live-API, use this repo\n", + "\n", + "pvl = PVLive()\n", + "\n", + "# test getting some data\n", + "start = datetime(2019,1,1,tzinfo=pytz.utc)\n", + "end = datetime(2019,1,2,tzinfo=pytz.utc)\n", + "\n", + "one_day_gsp_data_df = pvl.between(start, end, entity_type=\"gsp\", entity_id=0, extra_fields=\"\", dataframe=True)\n", + "\n", + "one_day_gsp_data_df = one_day_gsp_data_df.sort_values(by=['datetime_gmt'])\n", + "\n", + "fig = go.Figure(data=go.Scatter(x=one_gsp_data_df['datetime_gmt'], y=one_gsp_data_df['generation_mw']))\n", + "fig.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52f31238", + "metadata": {}, + "outputs": [], + "source": [ + "# plot one month of data\n", + "\n", + "pvl = PVLive()\n", + "\n", + "# test getting some data\n", + "start = datetime(2019,1,1,tzinfo=pytz.utc)\n", + "end = datetime(2019,2,1, tzinfo=pytz.utc)\n", + "\n", + "one_month_gsp_data_df = pvl.between(start, end, entity_type=\"gsp\", entity_id=0, extra_fields=\"\", dataframe=True)\n", + "\n", + "one_month_gsp_data_df = one_month_gsp_data_df.sort_values(by=['datetime_gmt'])\n", + "\n", + "fig = go.Figure(data=go.Scatter(x=one_month_gsp_data_df['datetime_gmt'], y=one_month_gsp_data_df['generation_mw']))\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4107dab0", + "metadata": {}, + "outputs": [], + "source": [ + "# plot one month of data, for 10 sites\n", + "from datetime import datetime, timedelta\n", + "pvl = PVLive()\n", + "\n", + "# test getting some data\n", + "start = datetime(2019,1,1,tzinfo=pytz.utc)\n", + "end = datetime(2019,6,1, 
tzinfo=pytz.utc)\n", + "\n", + "N_gsp_ids = 20\n", + "\n", + "one_month_gsp_data_df = []\n", + "for i in range(0,20):\n", + " start_chunk = start\n", + " end_chunk = start_chunk + timedelta(days=30)\n", + " while end_chunk < end:\n", + " print(f'Getting data for id {i} from {start_chunk} to {end_chunk}')\n", + " one_month_gsp_data_df.append(pvl.between(start=start_chunk, \n", + " end=end_chunk, \n", + " entity_type=\"gsp\", \n", + " entity_id=i, \n", + " extra_fields=\"\", \n", + " dataframe=True))\n", + " \n", + " start_chunk = start_chunk + timedelta(days=30)\n", + " end_chunk = end_chunk + timedelta(days=30)\n", + " \n", + " \n", + "one_month_gsp_data_df = pd.concat(one_month_gsp_data_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3520f01e", + "metadata": {}, + "outputs": [], + "source": [ + "one_month_gsp_data_df = one_month_gsp_data_df.sort_values(by=['gsp_id','datetime_gmt'])\n", + "\n", + "fig = go.Figure()\n", + "for i in range(0,N_gsp_ids):\n", + " temp_df = one_month_gsp_data_df[one_month_gsp_data_df['gsp_id'] == i]\n", + " \n", + " fig.add_trace(go.Scatter(x=temp_df['datetime_gmt'], \n", + " y=temp_df['generation_mw'], name=metadata_df.loc[i].gnode_name))\n", + "\n", + "fig.update_layout(title='GSP solar data')\n", + "fig.update_yaxes(title='MW')\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6744491f", + "metadata": {}, + "outputs": [], + "source": [ + "print(metadata_df.loc[1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00e093dd", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/2021-09/2021-09-01/GSP.png b/notebooks/2021-09/2021-09-01/GSP.png new file mode 100644 index 00000000..a378b11d Binary files /dev/null and b/notebooks/2021-09/2021-09-01/GSP.png differ diff --git a/notebooks/2021-09/2021-09-02/GSP Data Analysis.ipynb b/notebooks/2021-09/2021-09-02/GSP Data Analysis.ipynb new file mode 100644 index 00000000..d5022005 --- /dev/null +++ b/notebooks/2021-09/2021-09-02/GSP Data Analysis.ipynb @@ -0,0 +1,148 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "7d30935b", + "metadata": {}, + "outputs": [], + "source": [ + "from nowcasting_dataset.data_sources.pv_gsp_data_source import load_solar_pv_gsp_data_from_gcs\n", + "import plotly.graph_objects as go\n", + "from datetime import datetime\n", + "import numpy as np\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cea418b", + "metadata": {}, + "outputs": [], + "source": [ + "# lets load the data from gcp\n", + "# filename = \"/Users/peterdudfield/Documents/Github/nowcasting_dataset/notebooks/2021-09/2021-09-02/pv_gsp.zarr\"\n", + "filename = \"gs://solar-pv-nowcasting-data/PV/GSP/v0/pv_gsp.zarr/\"\n", + "\n", + "data = load_solar_pv_gsp_data_from_gcs(\n", + " from_gcs=False,\n", + " filename=filename,\n", + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0662e70c", + "metadata": {}, + "outputs": [], + "source": [ + "# plot first 10 systems\n", + "fig = go.Figure()\n", + "for i in 
range(1, 10):\n", + " fig.add_trace(go.Scatter(x=data.index, y=data[i]))\n", + "fig.update_layout(\n", + " title=\"GSP PV of 10 systems\",\n", + " yaxis_title=\"GSP PV [MW]\",\n", + " xaxis_title=\"Time\",\n", + ")\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b135024a", + "metadata": {}, + "outputs": [], + "source": [ + "# lets pick out one day and plot all the systems in that day\n", + "start_dt = datetime(2019, 4, 1)\n", + "end_dt = datetime(2019, 4, 2)\n", + "data_one_day = data[(data.index <= end_dt) & (data.index >= start_dt)]\n", + "\n", + "# plot\n", + "fig = go.Figure()\n", + "for col in data_one_day.columns:\n", + " fig.add_trace(go.Scatter(x=data_one_day.index, y=data_one_day[col]))\n", + "fig.update_layout(\n", + " title=\"GSP PV on 2019-04-01\",\n", + " yaxis_title=\"GSP PV [MW]\",\n", + " xaxis_title=\"Time\",\n", + ")\n", + "fig.show()\n", + "# shows one day, with the max about 350 MW" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e638d0cc", + "metadata": {}, + "outputs": [], + "source": [ + "# lets look at the distributions of the peaks on that day\n", + "max_pv = data_one_day.max()\n", + "fig = go.Figure(data=[go.Histogram(x=max_pv)])\n", + "fig.update_layout(\n", + " title=\"Historgram of max GSP PV on 2019-04-01\",\n", + " xaxis_title=\"GSP PV [MW]\",\n", + " yaxis_title=\"count\",\n", + ")\n", + "fig.show()\n", + "\n", + "# 60% of gsp systems are producing less than 5 MW\n", + "\n", + "# lets plot cdf\n", + "from statsmodels.distributions.empirical_distribution import ECDF\n", + "\n", + "fig = go.Figure()\n", + "fig.add_scatter(x=np.unique(max_pv), y=ECDF(max_pv)(np.unique(max_pv)), line_shape='hv')\n", + "fig.update_layout(\n", + " title=\"CDF of max GSP PV on 2019-04-01\",\n", + " xaxis_title=\"GSP PV [MW]\",\n", + " yaxis_title=\"CDF\",\n", + ")\n", + "fig.show()\n", + "\n", + "# 60% of gsp systems are producing less than 5 MW\n", + "# 70% of gsp systems are producing less than 10 MW\n", + "# 80% of gsp systems are producing less than 36 MW\n", + "# 90% of gsp systems are producing less than 78 MW\n", + "# means 10 % of gsp systems ~38 produce around 8000 MW, average of ~200MW each\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "853d671a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/2021-09/2021-09-06/gsp.py b/notebooks/2021-09/2021-09-06/gsp.py new file mode 100644 index 00000000..e1058996 --- /dev/null +++ b/notebooks/2021-09/2021-09-06/gsp.py @@ -0,0 +1,42 @@ +import json +from urllib.request import urlopen +import geopandas as gpd +import plotly.graph_objects as go + +WGS84_CRS = "EPSG:4326" + +# get file +url = ( + "https://data.nationalgrideso.com/backend/dataset/2810092e-d4b2-472f-b955-d8bea01f9ec0/resource/" + "a3ed5711-407a-42a9-a63a-011615eea7e0/download/gsp_regions_20181031.geojson" +) + +with urlopen(url) as response: + shapes_gdp = gpd.read_file(response).to_crs(WGS84_CRS) + +# set z axis +shapes_gdp["Amount"] = 0 + +# get dict shapes +shapes_dict = json.loads(shapes_gdp.to_json()) + +# 
plot it +fig = go.Figure() +fig.add_trace( + go.Choroplethmapbox(geojson=shapes_dict, locations=shapes_gdp.index, z=shapes_gdp.Amount, colorscale="Viridis") +) + +fig.update_layout(mapbox_style="carto-positron", mapbox_zoom=4, mapbox_center={"lat": 55, "lon": 0}) +fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0}) +fig.show(renderer="browser") +fig.write_html('gsp.html') + + +# find out if point is in gsp +from shapely.geometry import Point, Polygon + +_pnts = [Point(3, 3), Point(8, 8), Point(0, 51.38)] +pnts = gpd.GeoDataFrame(geometry=_pnts, index=['A', 'B', 'C']) + +# useful way to see if a point is in a polygon +pnts = pnts.assign(**{str(key): pnts.within(geom['geometry']) for key, geom in shapes_gdp.iterrows()}) diff --git a/notebooks/2021-09/2021-09-07/gsp.py b/notebooks/2021-09/2021-09-07/gsp.py new file mode 100644 index 00000000..b4fd2820 --- /dev/null +++ b/notebooks/2021-09/2021-09-07/gsp.py @@ -0,0 +1,68 @@ +import json +from urllib.request import urlopen +import geopandas as gpd +import plotly.graph_objects as go +import numpy as np +import itertools + +WGS84_CRS = "EPSG:4326" + +# get file +url = ( + "https://data.nationalgrideso.com/backend/dataset/2810092e-d4b2-472f-b955-d8bea01f9ec0/resource/" + "a3ed5711-407a-42a9-a63a-011615eea7e0/download/gsp_regions_20181031.geojson" +) + +with urlopen(url) as response: + shapes_gdp = gpd.read_file(response).to_crs(WGS84_CRS) + +# set z axis +shapes_gdp["Amount"] = 0 + +# get dict shapes +shapes_dict = json.loads(shapes_gdp.to_json()) + +# plot it +fig = go.Figure() +fig.add_trace( + go.Choroplethmapbox(geojson=shapes_dict, locations=shapes_gdp.index, z=shapes_gdp.Amount, colorscale="Viridis") +) + +fig.update_layout(mapbox_style="carto-positron", mapbox_zoom=4, mapbox_center={"lat": 55, "lon": 0}) +fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0}) +fig.show(renderer="browser") +fig.write_html('gsp.html') + + +# find out if point is in gsp +from shapely.geometry import Point, Polygon + +_pnts = [Point(3, 3), Point(8, 8), Point(0, 51.38)] +pnts = gpd.GeoDataFrame(geometry=_pnts, index=['A', 'B', 'C']) + +# useful way to see if a point is in a polygon +pnts = pnts.assign(**{str(key): pnts.within(geom['geometry']) for key, geom in shapes_gdp.iterrows()}) + +# create lat and long array, same system as Example, +sat_x_coords = np.array(range(480000, 600000, 2000)) / 10000 +sat_y_coords = np.array(range(330000, 200000, -2000)) / 10000 + +sat_x_y_coords = np.zeros((sat_x_coords.size, sat_y_coords.size,2)) +_pnts = [] +for i, j in itertools.product(range(sat_x_coords.size), range(sat_y_coords.size)): + _pnts.append(Point(sat_x_coords[i], sat_y_coords[j])) + pnts = gpd.GeoDataFrame(geometry=_pnts) + sat_x_y_coords[i,j] = np.array([sat_x_coords[i], sat_y_coords[j]]) + +pnts = pnts.assign(**{str(key): pnts.within(geom['geometry']) for key, geom in shapes_gdp.iterrows()}) + + + + + + + + + + + diff --git a/notebooks/2021-09/2021-09-07/sat_data.py b/notebooks/2021-09/2021-09-07/sat_data.py new file mode 100644 index 00000000..79d28f52 --- /dev/null +++ b/notebooks/2021-09/2021-09-07/sat_data.py @@ -0,0 +1,25 @@ +from datetime import datetime + +from nowcasting_dataset.data_sources.satellite_data_source import SatelliteDataSource + +s = SatelliteDataSource( + filename="gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/" + "all_zarr_int16_single_timestep.zarr", + history_len=6, + forecast_len=12, + convert_to_numpy=True, + image_size_pixels=64, + meters_per_pixel=2000, + n_timesteps_per_batch=32, +) + +s.open() 
+start_dt = datetime.fromisoformat("2019-01-01 00:00:00.000+00:00") +end_dt = datetime.fromisoformat("2019-01-02 00:00:00.000+00:00") + +data_xarray = s._data +data_xarray = data_xarray.sel(time=slice(start_dt, end_dt)) +data_xarray = data_xarray.sel(variable=["HRV"]) +data_xarray = data_xarray.sel(x=slice(122000, 122001)) + +data_df = data_xarray.to_dataframe() diff --git a/notebooks/2021-09/2021-09-13/remove_hash.py b/notebooks/2021-09/2021-09-13/remove_hash.py new file mode 100644 index 00000000..e516eedc --- /dev/null +++ b/notebooks/2021-09/2021-09-13/remove_hash.py @@ -0,0 +1,36 @@ +from nowcasting_dataset.cloud.gcp import get_all_filenames_in_path, rename_file +import random + + +# batch files are save with hash id and then batch and we want to remove the hashing +# {xxxxxx}_{batch_idx}.nc is the format of the file + +train_path = 'gs://solar-pv-nowcasting-data/prepared_ML_training_data/v5/train/' +validation_path = 'gs://solar-pv-nowcasting-data/prepared_ML_training_data/v5/validation/' + +train_filenames = get_all_filenames_in_path(remote_path=train_path)[1:] +validation_filenames = get_all_filenames_in_path(remote_path=validation_path)[1:] + +random.shuffle(train_filenames) +random.shuffle(validation_filenames) + +train_filenames = [file for file in train_filenames if '_' in file.split('/')[-1]] +validation_filenames = [file for file in validation_filenames if '_' in file.split('/')[-1]] + + +for filenames in [train_filenames, validation_filenames]: + for file in train_filenames: + + print(file) + + filename = file.split('/')[-1] + if '_' in filename: + path = '/'.join(file.split('/')[:-1]) + '/' + new_filename = path + filename.split('_')[-1] + + try: + rename_file(remote_file=file, new_filename=new_filename) + except Exception as e: + pass + else: + print(f'Skipping {filename}') \ No newline at end of file diff --git a/nowcasting_dataset/cloud/gcp.py b/nowcasting_dataset/cloud/gcp.py index 3c87e8e1..9751fa00 100644 --- a/nowcasting_dataset/cloud/gcp.py +++ b/nowcasting_dataset/cloud/gcp.py @@ -4,7 +4,7 @@ import gcsfs -from nowcasting_dataset.cloud.local import delete_all_files_in_temp_path +from nowcasting_dataset.cloud.local import delete_all_files_and_folder_in_temp_path _LOG = logging.getLogger(__name__) @@ -26,7 +26,7 @@ def gcp_upload_and_delete_local_files(dst_path: str, local_path: Path): _LOG.info("Uploading to GCS!") gcs = gcsfs.GCSFileSystem() gcs.put(str(local_path), dst_path, recursive=True) - delete_all_files_in_temp_path(local_path) + delete_all_files_and_folder_in_temp_path(local_path) def gcp_download_to_local(remote_filename: str, local_filename: str, gcs: gcsfs.GCSFileSystem = None): @@ -53,3 +53,17 @@ def get_all_filenames_in_path(remote_path) -> List[str]: gcs = gcsfs.GCSFileSystem() return gcs.ls(remote_path) + + +def rename_file(remote_file: str, new_filename: str): + """ + Rename file + + Args: + remote_file: The file name in gcs + new_filename: What the file should be renamed too + + """ + gcs = gcsfs.GCSFileSystem() + + gcs.mv(remote_file, new_filename) diff --git a/nowcasting_dataset/config/example.yaml b/nowcasting_dataset/config/example.yaml index 2c5cb24a..fafec554 100644 --- a/nowcasting_dataset/config/example.yaml +++ b/nowcasting_dataset/config/example.yaml @@ -12,8 +12,8 @@ output_data: filepath: prepared_ML_training_data/v4/ process: batch_size: 32 - forecast_length: 12 - history_length: 6 + forecast_minutes: 60 + history_minutes: 30 image_size_pixels: 64 nwp_channels: - t diff --git a/nowcasting_dataset/config/gcp.yaml 
b/nowcasting_dataset/config/gcp.yaml new file mode 100644 index 00000000..bf302f18 --- /dev/null +++ b/nowcasting_dataset/config/gcp.yaml @@ -0,0 +1,43 @@ +general: + description: example configuration + name: example +input_data: + bucket: solar-pv-nowcasting-data + npw_base_path: NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr + satelite_filename: satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr + solar_pv_data_filename: UK_PV_timeseries_batch.nc + solar_pv_metadata_filename: UK_PV_metadata.csv + solar_pv_path: PV/PVOutput.org +output_data: + filepath: solar-pv-nowcasting-data/prepared_ML_training_data/v5/ +process: + batch_size: 32 + forecast_minutes: 60 + history_minutes: 30 + image_size_pixels: 64 + nwp_channels: + - t + - dswrf + - prate + - r + - sde + - si10 + - vis + - lcc + - mcc + - hcc + prcesion: 16 + sat_channels: + - HRV + - IR_016 + - IR_039 + - IR_087 + - IR_097 + - IR_108 + - IR_120 + - IR_134 + - VIS006 + - VIS008 + - WV_062 + - WV_073 + val_check_interval: 1000 diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index 5d0084f2..6d745397 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -26,18 +26,20 @@ class InputData(BaseModel): description="TODO", ) + gsp_filename: str = Field("PV/GSP/v0/pv_gsp.zarr") + class OutputData(BaseModel): - filepath: str = Field("prepared_ML_training_data/v4/", description="Where the data is saved") + filepath: str = Field("prepared_ML_training_data/v5/", description="Where the data is saved") class Process(BaseModel): batch_size: int = Field(32, description="the batch size of the data") - forecast_length: int = Field(12, description="how many time steps to forecast in the future") - history_length: int = Field(6, description="how many historic times teps are used") - image_size_pixels: int = Field(64, description="the size of the satelite images") + forecast_minutes: int = Field(60, description="how many minutes to forecast in the future") + history_minutes: int = Field(30, description="how many historic minutes are used") + image_size_pixels: int = Field(64, description="the size of the satellite images") - sat_channels: tuple = Field(SAT_VARIABLE_NAMES, description="the satelite channels that are used") + sat_channels: tuple = Field(SAT_VARIABLE_NAMES, description="the satellite channels that are used") nwp_channels: tuple = Field(NWP_VARIABLE_NAMES, description="the channels used in the nwp data") precision: int = Field(16, description="what precision to use") diff --git a/nowcasting_dataset/consts.py b/nowcasting_dataset/consts.py index 6ec4105d..1cf8e4f1 100644 --- a/nowcasting_dataset/consts.py +++ b/nowcasting_dataset/consts.py @@ -22,3 +22,19 @@ # Typing Array = Union[xr.DataArray, np.ndarray] +PV_SYSTEM_ID: str = 'pv_system_id' +PV_SYSTEM_ROW_NUMBER = 'pv_system_row_number' +PV_SYSTEM_X_COORDS = 'pv_system_x_coords' +PV_SYSTEM_Y_COORDS = 'pv_system_y_coords' +PV_AZIMUTH_ANGLE = 'pv_azimuth_angle' +PV_ELEVATION_ANGLE = 'pv_elevation_angle' +PV_YIELD = 'pv_yield' +DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE = 128 +GSP_ID: str = "gsp_id" +GSP_YIELD = "gsp_yield" +GSP_X_COORDS = "gsp_x_coords" +GSP_Y_COORDS = "gsp_y_coords" +GSP_DATETIME_INDEX = "gsp_datetime_index" +DEFAULT_N_GSP_PER_EXAMPLE = 32 +OBJECT_AT_CENTER = "object_at_center" +DATETIME_FEATURE_NAMES = ("hour_of_day_sin", "hour_of_day_cos", "day_of_year_sin", "day_of_year_cos") \ No newline at end of file diff --git 
a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index 2154b807..2d48c9d6 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -1,13 +1,16 @@ from numbers import Number import pandas as pd import numpy as np -from nowcasting_dataset.example import Example, to_numpy +from nowcasting_dataset.dataset.example import Example, to_numpy from nowcasting_dataset import square import nowcasting_dataset.time as nd_time from dataclasses import dataclass, InitVar from typing import List, Tuple, Iterable import xarray as xr import itertools +import logging + +logger = logging.getLogger(__name__) @dataclass @@ -15,26 +18,41 @@ class DataSource: """Abstract base class. Attributes: - history_len: Number of timesteps of history to include in each example. + history_minutes: Number of minutes of history to include in each example. Does NOT include t0. That is, if history_len = 0 then the example will start at t0. - forecast_len: Number of timesteps of forecast to include in each example. + forecast_minutes: Number of minutes of forecast to include in each example. Does NOT include t0. If forecast_len = 0 then the example will end at t0. If both history_len and forecast_len are 0, then the example will consist of a single timestep at t0. convert_to_numpy: Whether or not to convert each example to numpy. + sample_period_minutes: The time delta between each data point """ - history_len: int - forecast_len: int + + history_minutes: int + forecast_minutes: int convert_to_numpy: bool def __post_init__(self): + + self.sample_period_minutes = self._get_sample_period_minutes() + + self.history_len = self.history_minutes // self.sample_period_minutes + self.forecast_len = self.forecast_minutes // self.sample_period_minutes + assert self.history_len >= 0 assert self.forecast_len >= 0 + assert self.history_minutes % self.sample_period_minutes == 0, \ + f'sample period ({self.sample_period_minutes}) minutes ' \ + f'does not fit into historic minutes ({self.forecast_minutes})' + assert self.forecast_minutes % self.sample_period_minutes == 0, \ + f'sample period ({self.sample_period_minutes}) minutes ' \ + f'does not fit into forecast minutes ({self.forecast_minutes})' + # Plus 1 because neither history_len nor forecast_len include t0. self._total_seq_len = self.history_len + self.forecast_len + 1 - self._history_dur = nd_time.timesteps_to_duration(self.history_len) - self._forecast_dur = nd_time.timesteps_to_duration(self.forecast_len) + self._history_dur = nd_time.timesteps_to_duration(self.history_len, self.sample_period_minutes) + self._forecast_dur = nd_time.timesteps_to_duration(self.forecast_len, self.sample_period_minutes) def _get_start_dt(self, t0_dt: pd.Timestamp) -> pd.Timestamp: return t0_dt - self._history_dur @@ -43,6 +61,15 @@ def _get_end_dt(self, t0_dt: pd.Timestamp) -> pd.Timestamp: return t0_dt + self._forecast_dur # ************* METHODS THAT CAN BE OVERRIDDEN **************************** + def _get_sample_period_minutes(self): + """ + This is the default sample period in minutes. This functions may be overwritten if + the sample period of the data source is not 5 minutes + """ + logging.debug('Getting sample_period_minutes default of 5 minutes. ' + 'This means the data is spaced 5 minutes apart') + return 5 + def open(self): """Open the data source, if necessary. 
@@ -185,8 +212,9 @@ def get_example( f'x_meters_center={x_meters_center}\n' f'y_meters_center={y_meters_center}\n' f't0_dt={t0_dt}\n' + f'times are {selected_data.time}\n' f'expected shape={self._shape_of_example}\n' - f'actual shape {selected_data.shape}') + f'actual shape {selected_data.shape}') return self._put_data_into_example(selected_data) diff --git a/nowcasting_dataset/data_sources/datetime_data_source.py b/nowcasting_dataset/data_sources/datetime_data_source.py index 3e7837af..6c8f3138 100644 --- a/nowcasting_dataset/data_sources/datetime_data_source.py +++ b/nowcasting_dataset/data_sources/datetime_data_source.py @@ -1,5 +1,5 @@ from nowcasting_dataset.data_sources.data_source import DataSource -from nowcasting_dataset.example import Example +from nowcasting_dataset.dataset.example import Example from nowcasting_dataset import time as nd_time from dataclasses import dataclass import pandas as pd diff --git a/nowcasting_dataset/data_sources/gsp/__init__.py b/nowcasting_dataset/data_sources/gsp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nowcasting_dataset/data_sources/gsp/eso.py b/nowcasting_dataset/data_sources/gsp/eso.py new file mode 100644 index 00000000..307bbd00 --- /dev/null +++ b/nowcasting_dataset/data_sources/gsp/eso.py @@ -0,0 +1,100 @@ +""" +This file has a few functions that are used to get GSP (Grid Supply Point) information from National Grid ESO. +ESO - Electricity System Operator. General information can be found here +- https://data.nationalgrideso.com/system/gis-boundaries-for-gb-grid-supply-points + +get_gsp_metadata_from_eso: gets the gsp metadata +get_gsp_shape_from_eso: gets the shape of the gsp regions +get_list_of_gsp_ids: gets a list of gsp_ids, by using 'get_gsp_metadata_from_eso' + +Peter Dudfield +2021-09-13 +""" + +import json +import urllib +import logging +from typing import List, Optional +from urllib.request import urlopen + +import geopandas as gpd +import pandas as pd + +from nowcasting_dataset.geospatial import WGS84_CRS + +logger = logging.getLogger(__name__) + + +def get_gsp_metadata_from_eso() -> pd.DataFrame: + """ + Get the metadata for the gsp, from ESO. + @return: + """ + + # call ESO website. There is a possibility that this API will be replaced and its unclear if this original API will + # will stay operational + url = ( + "https://data.nationalgrideso.com/api/3/action/datastore_search?" + "resource_id=bbe2cc72-a6c6-46e6-8f4e-48b879467368&limit=400" + ) + fileobj = urllib.request.urlopen(url) + d = json.loads(fileobj.read()) + + # make dataframe + results = d["result"]["records"] + metadata = pd.DataFrame(results) + + # drop duplicates + return metadata.drop_duplicates(subset=['gsp_id']) + + +def get_gsp_shape_from_eso() -> gpd.GeoDataFrame: + """ + Get the the gsp shape file + """ + + logger.debug('Loading GSP shape file') + + # call ESO website. There is a possibility that this API will be replaced and its unclear if this original API will + # will stay operational + url = ( + "https://data.nationalgrideso.com/backend/dataset/2810092e-d4b2-472f-b955-d8bea01f9ec0/resource/" + "a3ed5711-407a-42a9-a63a-011615eea7e0/download/gsp_regions_20181031.geojson" + ) + + with urlopen(url) as response: + return gpd.read_file(response).to_crs(WGS84_CRS) + + +def get_list_of_gsp_ids(maximum_number_of_gsp: Optional[int] = None) -> List[int]: + """ + Get list of gsp ids from ESO metadata + + Args: + maximum_number_of_gsp: Truncate list of GSPs to be no larger than this number of GSPs. + Set to None to disable truncation. 
+ + Returns: list of gsp ids + + """ + + # get a lit of gsp ids + metadata = get_gsp_metadata_from_eso() + + # get rid of nans, and duplicates + metadata = metadata[~metadata['gsp_id'].isna()] + metadata.drop_duplicates(subset=['gsp_id'], inplace=True) + + # make into list + gsp_ids = metadata['gsp_id'].to_list() + gsp_ids = [int(gsp_id) for gsp_id in gsp_ids] + + # adjust number of gsp_ids + if maximum_number_of_gsp is None: + maximum_number_of_gsp = len(metadata) + if maximum_number_of_gsp > len(metadata): + logging.warning(f'Only {len(metadata)} gsp available to load') + if maximum_number_of_gsp < len(metadata): + gsp_ids = gsp_ids[0: maximum_number_of_gsp] + + return gsp_ids \ No newline at end of file diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py new file mode 100644 index 00000000..4895c4fc --- /dev/null +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -0,0 +1,372 @@ +import logging + +import xarray as xr + +from typing import Union, Optional, Tuple, List +from pathlib import Path +from datetime import datetime +from dataclasses import dataclass +from numbers import Number +import torch +import numpy as np +import pandas as pd + +from nowcasting_dataset.utils import scale_to_0_to_1, pad_data +from nowcasting_dataset.square import get_bounding_box_mask +from nowcasting_dataset.geospatial import lat_lon_to_osgb +from nowcasting_dataset.dataset.example import Example +from nowcasting_dataset.data_sources.data_source import ImageDataSource +from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso + +from nowcasting_dataset.consts import GSP_ID, GSP_YIELD, GSP_X_COORDS, GSP_Y_COORDS, \ + DEFAULT_N_GSP_PER_EXAMPLE, OBJECT_AT_CENTER + +logger = logging.getLogger(__name__) + + +@dataclass +class GSPDataSource(ImageDataSource): + """ + Data source for GSP PV Data + + 30 mins data is taken from 'PV Live' from https://www.solar.sheffield.ac.uk/pvlive/ + meta data is taken from ESO + """ + + # filename of where the gsp data is stored + filename: Union[str, Path] + # start datetime, this can be None + start_dt: Optional[datetime] = None + # end datetime, this can be None + end_dt: Optional[datetime] = None + # the threshold where we only taken gsp's with a maximum power, above this value. + threshold_mw: int = 20 + # the frequency of the data + sample_period_minutes: int = 30 + # get the data for the gsp at the center too. 
+ # This can be turned off if the center of the bounding box is of a pv system + get_center: bool = True + # the maximum number of gsp's to be loaded for data sample + n_gsp_per_example: int = DEFAULT_N_GSP_PER_EXAMPLE + + def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): + """ + Set random seed and load data + """ + super().__post_init__(image_size_pixels, meters_per_pixel) + seed = torch.initial_seed() + self.rng = np.random.default_rng(seed=seed) + self.load() + + def _get_sample_period_minutes(self): + """ Override the default sample minutes""" + return self.sample_period_minutes + + def load(self): + """ + Load the meta data and load the GSP power data + """ + + # load metadata + self.metadata = get_gsp_metadata_from_eso() + + # make location x,y in osgb + self.metadata["location_x"], self.metadata["location_y"] = lat_lon_to_osgb( + self.metadata["gsp_lat"], self.metadata["gsp_lon"] + ) + + # load gsp data from file / gcp + self.gsp_power = load_solar_gsp_data(self.filename, start_dt=self.start_dt, end_dt=self.end_dt) + + # drop any gsp below 20 MW (or set threshold). This is to get rid of any small GSP where predicting the + # solar output will be harder. + self.gsp_power, self.metadata = drop_gsp_by_threshold( + self.gsp_power, self.metadata, threshold_mw=self.threshold_mw + ) + + # scale from 0 to 1 + self.gsp_power = scale_to_0_to_1(self.gsp_power) + + logger.debug(f'There are {len(self.gsp_power.columns)} GSP') + + def datetime_index(self): + """ + Return the datetimes that are available + """ + return self.gsp_power.index + + def get_locations_for_batch(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: + """ + Get x and y locations for a batch. Assume that all data is available for all GSP. + Random GSP are taken, and the locations of them are returned. This is useful as other datasources need to know + which x,y locations to get + Returns: list of x and y locations + """ + + logger.debug("Getting locations for the batch") + + # Pick a random GSP for each t0_datetime, and then grab + # their geographical location. + x_locations = [] + y_locations = [] + + for t0_dt in t0_datetimes: + + # Choose start and end times + start_dt = self._get_start_dt(t0_dt) + end_dt = self._get_end_dt(t0_dt) + + # remove any nans + gsp_power = self.gsp_power.loc[start_dt:end_dt].dropna(axis="columns", how="any") + + # get random index + random_gsp_id = self.rng.choice(gsp_power.columns) + meta_data = self.metadata[(self.metadata["gsp_id"] == random_gsp_id)] + + # Make sure there is only one. Sometimes there are multiple gsp_ids at one location e.g. 'SELL_1'. + # Further investigation on this may be needed, but going to ignore this for now. + # + metadata_for_gsp = meta_data.iloc[0] + + # Get metadata for GSP + x_locations.append(metadata_for_gsp.location_x) + y_locations.append(metadata_for_gsp.location_y) + + logger.debug( + f"Found locations for GSP id {random_gsp_id} of {metadata_for_gsp.location_x} and " + f"{metadata_for_gsp.location_y}" + ) + + return x_locations, y_locations + + def get_example(self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number) -> Example: + """ + Get data example from one time point (t0_dt) and for x and y coords (x_meters_center), (y_meters_center). + + Get data at the location of x,y and get surrounding GSP power data also. + + Args: + t0_dt: datetime of "now". History and forecast are also returned + x_meters_center: x location of center GSP. + y_meters_center: y location of center GSP. 
+ + Returns: Dictionary with GSP data in it + + """ + logger.debug("Getting example data") + + # get the GSP power, including history and forecast + selected_gsp_power = self._get_time_slice(t0_dt) + + # get the main gsp id, and the ids of the gsp in the bounding box + all_gsp_ids = self._get_gsp_ids_in_roi( + x_meters_center, y_meters_center, selected_gsp_power.columns + ) + if self.get_center: + central_gsp_id = self._get_central_gsp_id( + x_meters_center, y_meters_center, selected_gsp_power.columns + ) + assert central_gsp_id in all_gsp_ids + + # By convention, the 'target' GSP ID (the one in the center + # of the image) must be in the first position of the returned arrays. + all_gsp_ids = all_gsp_ids.drop(central_gsp_id) + all_gsp_ids = all_gsp_ids.insert(loc=0, item=central_gsp_id) + else: + logger.warning('Not getting center GSP') + + # only select at most {n_gsp_per_example} + all_gsp_ids = all_gsp_ids[: self.n_gsp_per_example] + + # select the GSP power output for the selected GSP IDs + selected_gsp_power = selected_gsp_power[all_gsp_ids] + + gsp_x_coords = self.metadata[self.metadata["gsp_id"].isin(all_gsp_ids)].location_x + gsp_y_coords = self.metadata[self.metadata["gsp_id"].isin(all_gsp_ids)].location_y + + # Save data into the Example dict... + example = Example( + gsp_id=all_gsp_ids, + gsp_yield=selected_gsp_power, + x_meters_center=x_meters_center, + y_meters_center=y_meters_center, + gsp_x_coords=gsp_x_coords, + gsp_y_coords=gsp_y_coords, + gsp_datetime_index=selected_gsp_power.index, + ) + + if self.get_center: + example[OBJECT_AT_CENTER] = 'gsp' + + # Pad (if necessary) so returned arrays are always of size n_gsp_per_example. + pad_size = self.n_gsp_per_example - len(all_gsp_ids) + example = pad_data(data=example, + one_dimensional_arrays=[GSP_ID, GSP_X_COORDS, GSP_Y_COORDS], + two_dimensional_arrays=[GSP_YIELD], + pad_size=pad_size) + + return example + + def _get_central_gsp_id( + self, x_meters_center: Number, y_meters_center: Number, gsp_ids_with_data_for_timeslice: pd.Int64Index + ) -> int: + """ + Get the GSP id of the central GSP from coordinates + Args: + x_meters_center: the location of the gsp (x) + y_meters_center: the location of the gsp (y) + gsp_ids_with_data_for_timeslice: List of gsp ids that are available for a certain timeslice + + Returns: GSP id + """ + + logger.debug("Getting Central GSP") + + # If x_meters_center and y_meters_center have been chosen + # by {}.get_locations_for_batch() then we just have + # to find the gsp_ids at that exact location. This is + # super-fast (a few hundred microseconds). We use np.isclose + # instead of the equality operator because floats. + meta_data_index = self.metadata.index[ + np.isclose(self.metadata.location_x, x_meters_center, rtol=1E-05, atol=1E-05) + & np.isclose(self.metadata.location_y, y_meters_center, rtol=1E-05, atol=1E-05) + ] + gsp_ids = self.metadata.loc[meta_data_index].gsp_id.values + + if len(gsp_ids) == 0: + # TODO: Implement finding GSP closest to x_meters_center, + # y_meters_center. This will probably be quite slow, so always + # try finding an exact match first (which is super-fast). + raise NotImplementedError( + "Not yet implemented the ability to find GSP *nearest*" + " (but not at the identical location to) x_meters_center and" + " y_meters_center." 
+ ) + + gsp_ids = gsp_ids_with_data_for_timeslice.intersection(gsp_ids) + + if len(gsp_ids) == 0: + raise NotImplementedError( + f"Could not find GSP id for {x_meters_center}, {y_meters_center} " + f"({gsp_ids}) and {gsp_ids_with_data_for_timeslice}" + ) + + return int(gsp_ids[0]) + + def _get_gsp_ids_in_roi( + self, x_meters_center: Number, y_meters_center: Number, gsp_ids_with_data_for_timeslice: pd.Int64Index + ) -> pd.Int64Index: + """ + Find the GSP IDs for all the GSP within the geospatial region of interest, defined by self.square. + Args: + x_meters_center: center of area of interest (x coords) + y_meters_center: center of area of interest (y coords) + gsp_ids_with_data_for_timeslice: ids that are avialble for a specific time slice + + Returns: list of GSP ids that are in area of interest + + """ + + logger.debug("Getting all gsp in ROI") + + # creating bounding box + bounding_box = self._square.bounding_box_centered_on( + x_meters_center=x_meters_center, y_meters_center=y_meters_center + ) + + # get all x and y locations of gsp + x = self.metadata.location_x + y = self.metadata.location_y + + # make mask of gsp_ids + mask = get_bounding_box_mask(bounding_box, x, y) + + gsp_ids = self.metadata[mask].gsp_id + gsp_ids = gsp_ids_with_data_for_timeslice.intersection(gsp_ids) + + assert len(gsp_ids) > 0 + return gsp_ids + + def _get_time_slice(self, t0_dt: pd.Timestamp) -> [pd.DataFrame]: + """ + Get time slice of GSP power data for give time. + Note the time is extended backwards by history lenght and forward by prediction time + Args: + t0_dt: timestamp of interest + + Returns: pandas data frame of GSP power data + """ + + logger.debug(f'Getting power slice for {t0_dt}') + + # get start and end datetime, takening into account history and forecast length. + start_dt = self._get_start_dt(t0_dt) + end_dt = self._get_end_dt(t0_dt) + + # select power for certain times + power = self.gsp_power.loc[start_dt:end_dt] + + # remove any nans + power = power.dropna(axis="columns", how="any") + + logger.debug(f'Found {len(power.columns)} GSP') + + return power + + +def drop_gsp_by_threshold(gsp_power: pd.DataFrame, meta_data: pd.DataFrame, threshold_mw: int = 20): + """ + Drop GSP where the max power is below a certain threshold + Args: + gsp_power: GSP power data + meta_data: the GSP meta data + threshold_mw: the threshold where we only taken GSP with a maximum power, above this value. 
+ + Returns: power data and metadata + """ + maximum_gsp = gsp_power.max() + + keep_index = maximum_gsp >= threshold_mw + + logger.debug(f"Dropping {sum(~keep_index)} GSPs as maximum is not greater {threshold_mw} MW") + logger.debug(f"Keeping {sum(keep_index)} GSPs as maximum is greater {threshold_mw} MW") + + gsp_power = gsp_power[keep_index.index] + gsp_ids = gsp_power.columns + meta_data = meta_data[meta_data["gsp_id"].isin(gsp_ids)] + + return gsp_power[keep_index.index], meta_data + + +def load_solar_gsp_data( + filename: Union[str, Path], start_dt: Optional[datetime] = None, end_dt: Optional[datetime] = None +) -> pd.DataFrame: + """ + Load solar PV GSP data + + Args: + filename: filename of file to be loaded, can put 'gs://' files in here too + start_dt: the start datetime, which to trim the data to + end_dt: the end datetime, which to trim the data to + + Returns:dataframe of pv data + + """ + + logger.debug(f"Loading Solar GSP Data from GCS {filename} from {start_dt} to {end_dt}") + # Open data - it may be quicker to open byte file first, but decided just to keep it like this at the moment + gsp_power = xr.open_dataset(filename, engine="zarr") + gsp_power = gsp_power.sel(datetime_gmt=slice(start_dt, end_dt)) + gsp_power_df = gsp_power.to_dataframe() + + # Save memory + del gsp_power + + # Process the data a little + gsp_power_df = gsp_power_df.dropna(axis="columns", how="all") + gsp_power_df = gsp_power_df.clip(lower=0, upper=5e7) + + # make column names ints, not strings + gsp_power_df.columns = [int(col) for col in gsp_power_df.columns] + + return gsp_power_df diff --git a/nowcasting_dataset/data_sources/gsp/pvlive.py b/nowcasting_dataset/data_sources/gsp/pvlive.py new file mode 100644 index 00000000..6206acf0 --- /dev/null +++ b/nowcasting_dataset/data_sources/gsp/pvlive.py @@ -0,0 +1,82 @@ +from datetime import datetime, timedelta +import logging +import pandas as pd +from pvlive_api import PVLive + +from nowcasting_dataset.data_sources.gsp.eso import get_list_of_gsp_ids + +logger = logging.getLogger(__name__) + +CHUNK_DURATION = timedelta(days=30) + + +def load_pv_gsp_raw_data_from_pvlive(start: datetime, end: datetime, number_of_gsp: int = None) -> pd.DataFrame: + """ + Load raw pv gsp data from pvlive. Note that each gsp is loaded separately. Also the data is loaded in 30 day chunks. + Args: + start: the start date for gsp data to load + end: the end date for gsp data to load + number_of_gsp: The number of gsp to load. Note that on 2021-09-01 there were 338 to load. + + Returns: Data frame of time series of gsp data. 
Shows PV data for each GSP from {start} to {end} + + """ + + # get a lit of gsp ids + gsp_ids = get_list_of_gsp_ids(maximum_number_of_gsp=number_of_gsp) + + # setup pv Live class, although here we are getting historic data + pvl = PVLive() + + # set the first chunk of data, note that 30 day chunks are used except if the end time is smaller than that + first_start_chunk = start + first_end_chunk = min([first_start_chunk + CHUNK_DURATION, end]) + + gsp_data_df = [] + logger.debug(f'Will be getting data for {len(gsp_ids)} gsp ids') + # loop over gsp ids + for gsp_id in gsp_ids: + + one_gsp_data_df = [] + + # set the first chunk start and end times + start_chunk = first_start_chunk + end_chunk = first_end_chunk + + # loop over 30 days chunks (nice to see progress instead of waiting a long time for one command - this might + # not be the fastest) + while start_chunk <= end: + logger.debug(f"Getting data for gsp id {gsp_id} from {start_chunk} to {end_chunk}") + + one_gsp_data_df.append( + pvl.between( + start=start_chunk, end=end_chunk, entity_type="gsp", entity_id=gsp_id, extra_fields="", dataframe=True + ) + ) + + # add 30 days to the chunk, to get the next chunk + start_chunk = start_chunk + CHUNK_DURATION + end_chunk = end_chunk + CHUNK_DURATION + + if end_chunk > end: + end_chunk = end + + # join together one gsp data, and sort + one_gsp_data_df = pd.concat(one_gsp_data_df) + one_gsp_data_df = one_gsp_data_df.sort_values(by=["gsp_id", "datetime_gmt"]) + + # append to longer list + gsp_data_df.append(one_gsp_data_df) + + gsp_data_df = pd.concat(gsp_data_df) + + # remove any extra data loaded + gsp_data_df = gsp_data_df[gsp_data_df["datetime_gmt"] <= end] + + # remove any duplicates + gsp_data_df.drop_duplicates(inplace=True) + + # format data, remove timezone, + gsp_data_df['datetime_gmt'] = gsp_data_df['datetime_gmt'].dt.tz_localize(None) + + return gsp_data_df diff --git a/nowcasting_dataset/data_sources/nwp_data_source.py b/nowcasting_dataset/data_sources/nwp_data_source.py index b2c03c7b..1a3f84c4 100644 --- a/nowcasting_dataset/data_sources/nwp_data_source.py +++ b/nowcasting_dataset/data_sources/nwp_data_source.py @@ -1,5 +1,5 @@ from nowcasting_dataset.data_sources.data_source import ZarrDataSource -from nowcasting_dataset.example import Example, to_numpy +from nowcasting_dataset.dataset.example import Example, to_numpy from nowcasting_dataset import utils from typing import Iterable, Optional, List import xarray as xr diff --git a/nowcasting_dataset/data_sources/pv_data_source.py b/nowcasting_dataset/data_sources/pv_data_source.py index f8dfe15a..15e4b891 100644 --- a/nowcasting_dataset/data_sources/pv_data_source.py +++ b/nowcasting_dataset/data_sources/pv_data_source.py @@ -1,10 +1,14 @@ +from nowcasting_dataset.consts import PV_SYSTEM_ID, PV_SYSTEM_ROW_NUMBER, PV_SYSTEM_X_COORDS, PV_SYSTEM_Y_COORDS, \ + PV_AZIMUTH_ANGLE, PV_ELEVATION_ANGLE, PV_YIELD, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, OBJECT_AT_CENTER from nowcasting_dataset.data_sources.data_source import ImageDataSource -from nowcasting_dataset.example import Example +from nowcasting_dataset.dataset.example import Example from nowcasting_dataset import geospatial, utils +from nowcasting_dataset.square import get_bounding_box_mask from dataclasses import dataclass import pandas as pd import numpy as np import torch +from tqdm import tqdm from numbers import Number from typing import List, Tuple, Union, Optional import datetime @@ -15,17 +19,10 @@ import functools import logging import time +from concurrent import futures logger = 
logging.getLogger(__name__) -PV_SYSTEM_ID = 'pv_system_id' -PV_SYSTEM_ROW_NUMBER = 'pv_system_row_number' -PV_SYSTEM_X_COORDS = 'pv_system_x_coords' -PV_SYSTEM_Y_COORDS = 'pv_system_y_coords' -PV_AZIMUTH_ANGLE = 'pv_azimuth_angle' -PV_ELEVATION_ANGLE = 'pv_elevation_angle' -PV_YIELD = 'pv_yield' - @dataclass class PVDataSource(ImageDataSource): @@ -36,9 +33,10 @@ class PVDataSource(ImageDataSource): random_pv_system_for_given_location: Optional[bool] = True #: Each example will always have this many PV systems. #: If less than this number exist in the data then pad with NaNs. - n_pv_systems_per_example: int = 128 + n_pv_systems_per_example: int = DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE load_azimuth_and_elevation: bool = False load_from_gcs: bool = True # option to load data from gcs, or local file + get_center: bool = True def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): super().__post_init__(image_size_pixels, meters_per_pixel) @@ -82,6 +80,9 @@ def _load_pv_power(self): logger.debug('Loading PV Power data') + if 'gs://' not in str(self.filename): + self.load_from_gcs = False + pv_power = load_solar_pv_data_from_gcs( self.filename, start_dt=self.start_dt, end_dt=self.end_dt, from_gcs=self.load_from_gcs) @@ -164,18 +165,24 @@ def _get_all_pv_system_ids_in_roi( ) -> pd.Int64Index: """Find the PV system IDs for all the PV systems within the geospatial region of interest, defined by self.square.""" + + logger.debug(f'Getting PV example data for {x_meters_center} and {y_meters_center}') + bounding_box = self._square.bounding_box_centered_on( x_meters_center=x_meters_center, y_meters_center=y_meters_center) x = self.pv_metadata.location_x y = self.pv_metadata.location_y - pv_system_ids = self.pv_metadata.index[ - (x >= bounding_box.left) & - (x <= bounding_box.right) & - (y >= bounding_box.bottom) & - (y <= bounding_box.top)] + + # make mask of pv system ids + mask = get_bounding_box_mask(bounding_box, x, y) + pv_system_ids = self.pv_metadata.index[mask] + pv_system_ids = pv_system_ids_with_data_for_timeslice.intersection( pv_system_ids) - assert len(pv_system_ids) > 0 + + # there may not be any pv systems in a GSP region + # assert len(pv_system_ids) > 0 + return pv_system_ids def get_example( @@ -184,17 +191,20 @@ def get_example( x_meters_center: Number, y_meters_center: Number) -> Example: + logger.debug('Getting PV example data') + selected_pv_power, selected_pv_azimuth_angle, selected_pv_elevation_angle = self._get_time_slice(t0_dt) - central_pv_system_id = self._get_central_pv_system_id( - x_meters_center, y_meters_center, selected_pv_power.columns) all_pv_system_ids = self._get_all_pv_system_ids_in_roi( x_meters_center, y_meters_center, selected_pv_power.columns) + if self.get_center: + central_pv_system_id = self._get_central_pv_system_id( + x_meters_center, y_meters_center, selected_pv_power.columns) - # By convention, the 'target' PV system ID (the one in the center - # of the image) must be in the first position of the returned arrays. - all_pv_system_ids = all_pv_system_ids.drop(central_pv_system_id) - all_pv_system_ids = all_pv_system_ids.insert( - loc=0, item=central_pv_system_id) + # By convention, the 'target' PV system ID (the one in the center + # of the image) must be in the first position of the returned arrays. 
+ all_pv_system_ids = all_pv_system_ids.drop(central_pv_system_id) + all_pv_system_ids = all_pv_system_ids.insert( + loc=0, item=central_pv_system_id) all_pv_system_ids = all_pv_system_ids[:self.n_pv_systems_per_example] @@ -215,30 +225,32 @@ def get_example( x_meters_center=x_meters_center, y_meters_center=y_meters_center, pv_system_x_coords=pv_system_x_coords, - pv_system_y_coords=pv_system_y_coords) + pv_system_y_coords=pv_system_y_coords, + pv_datetime_index=selected_pv_power.index) if self.load_azimuth_and_elevation: example[PV_AZIMUTH_ANGLE] = selected_pv_azimuth_angle example[PV_ELEVATION_ANGLE] = selected_pv_elevation_angle - # Pad (if necessary) so returned arrays are always of size - # n_pv_systems_per_example. + if self.get_center: + example[OBJECT_AT_CENTER] = 'pv' + + # Pad (if necessary) so returned arrays are always of size n_pv_systems_per_example. pad_size = self.n_pv_systems_per_example - len(all_pv_system_ids) - pad_shape = (0, pad_size) # (before, after) + one_dimensional_arrays = [ - PV_SYSTEM_ID, PV_SYSTEM_ROW_NUMBER, - PV_SYSTEM_X_COORDS, PV_SYSTEM_Y_COORDS] - for name in one_dimensional_arrays: - example[name] = utils.pad_nans(example[name], pad_width=pad_shape) + PV_SYSTEM_ID, PV_SYSTEM_ROW_NUMBER, + PV_SYSTEM_X_COORDS, PV_SYSTEM_Y_COORDS] + pad_nans_variables = [PV_YIELD] if self.load_azimuth_and_elevation: pad_nans_variables.append(PV_AZIMUTH_ANGLE) pad_nans_variables.append(PV_ELEVATION_ANGLE) - for variable in pad_nans_variables: - example[variable] = utils.pad_nans( - example[variable], - pad_width=((0, 0), pad_shape)) # (axis0, axis1) + example = utils.pad_data(data=example, + one_dimensional_arrays=one_dimensional_arrays, + two_dimensional_arrays=pad_nans_variables, + pad_size=pad_size) return example @@ -289,32 +301,58 @@ def _calculate_azimuth_and_elevation(self): logger.debug('Calculating azimuth and elevation angles') - datestamps = self.datetime_index().to_pydatetime() + self.pv_azimuth, self.pv_elevation \ + = calculate_azimuth_and_elevation_all_pv_systems(self.datetime_index().to_pydatetime(), self.pv_metadata) + + +def calculate_azimuth_and_elevation_all_pv_systems(datestamps: List[datetime.datetime], pv_metadata: pd.DataFrame) -> (pd.Series, pd.Series): + """ + Calculate the azimuth and elevation angles for each datestamp, for each pv system. + """ + + logger.debug(f'Will be calculating for {len(datestamps)} datestamps and {len(pv_metadata)} pv systems') + + # create array of index datetime, columns of system_id for both azimuth and elevation + pv_azimuth = [] + pv_elevation = [] + + t = time.time() + # loop over all metadata and fine azimuth and elevation angles, + # not sure this is the best method to use, as currently this step takes ~2 minute for 745 pv systems, + # and 235 datestamps (~100,000 point). But this only needs to be done once. + with futures.ThreadPoolExecutor(max_workers=len(pv_metadata)) as executor: + + logger.debug('Setting up jobs') + + # Submit tasks to the executor. 
+ future_azimuth_and_elevation_per_pv_system = [] + for i in tqdm(range(len(pv_metadata))): + future_azimuth_and_elevation = executor.submit( + geospatial.calculate_azimuth_and_elevation_angle, + latitude=pv_metadata.iloc[i].latitude, + longitude=pv_metadata.iloc[i].longitude, + datestamps=datestamps) + future_azimuth_and_elevation_per_pv_system.append([future_azimuth_and_elevation, pv_metadata.iloc[i].name]) - # create array of index datetime, columns of system_id for both azimuth and elevation - pv_azimuth = [] - pv_elevation = [] + logger.debug(f'Getting results') - t = time.time() - # loop over all metadata and fine azimuth and elevation angles, - # not sure this is the best method to use, as currently this step takes ~2 minute for 745 pv systems, - # and 235 datestamps (~100,000 point). But this only needs to be done once. - for i in range(0, len(self.pv_metadata)): + # Collect results from each thread. + for i in tqdm(range(len(future_azimuth_and_elevation_per_pv_system))): + future_azimuth_and_elevation, name = future_azimuth_and_elevation_per_pv_system[i] + azimuth_and_elevation = future_azimuth_and_elevation.result() - row = self.pv_metadata.iloc[i] + azimuth = azimuth_and_elevation.loc[:, 'azimuth'].rename(name) + elevation = azimuth_and_elevation.loc[:, 'elevation'].rename(name) - azimuth_and_elevation \ - = geospatial.calculate_azimuth_and_elevation_angle(latitude=row.latitude, - longitude=row.longitude, - datestamps=datestamps) + pv_azimuth.append(azimuth) + pv_elevation.append(elevation) - pv_azimuth.append(azimuth_and_elevation.loc[:, 'azimuth'].rename(row.name)) - pv_elevation.append(azimuth_and_elevation.loc[:, 'elevation'].rename(row.name)) + pv_azimuth = pd.concat(pv_azimuth, axis=1) + pv_elevation = pd.concat(pv_elevation, axis=1) - self.pv_azimuth = pd.concat(pv_azimuth, axis=1) - self.pv_elevation = pd.concat(pv_elevation, axis=1) + logger.debug(f'Calculated Azimuth and Elevation angles in {time.time() - t} seconds') - logger.debug(f'Calculated Azimuth and Elevation angles in {time.time() - t} seconds') + return pv_azimuth, pv_elevation def load_solar_pv_data_from_gcs( diff --git a/nowcasting_dataset/data_sources/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite_data_source.py index 6e138baa..8e1d5c1e 100644 --- a/nowcasting_dataset/data_sources/satellite_data_source.py +++ b/nowcasting_dataset/data_sources/satellite_data_source.py @@ -1,5 +1,5 @@ from nowcasting_dataset.data_sources.data_source import ZarrDataSource -from nowcasting_dataset.example import Example, to_numpy +from nowcasting_dataset.dataset.example import Example, to_numpy from nowcasting_dataset import utils from typing import Iterable, Optional, List from numbers import Number diff --git a/nowcasting_dataset/dataset/README.md b/nowcasting_dataset/dataset/README.md new file mode 100644 index 00000000..545e24d8 --- /dev/null +++ b/nowcasting_dataset/dataset/README.md @@ -0,0 +1,31 @@ +# Datasets + +This folder contains the following files + +## batch.py + +Functions used to 'play with' batch data, where "batch data" is a List of Example objects; i.e. `List[Example]`. 
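To make the description above concrete, here is a minimal usage sketch (not part of the diff itself) of the helpers that batch.py introduces, `batch_to_dataset` and `write_batch_locally`; the `dataloader` name is a hypothetical stand-in for whatever yields `List[Example]` batches.

from nowcasting_dataset.dataset.batch import write_batch_locally

def save_batches(dataloader):
    # Each `batch` yielded by the dataloader is a List[Example].
    for batch_i, batch in enumerate(dataloader):
        # write_batch_locally() calls batch_to_dataset() to merge the per-example
        # fields into a single xarray Dataset (adding an 'example' dimension),
        # fixes the dtypes, and writes a compressed NetCDF file into LOCAL_TEMP_PATH.
        write_batch_locally(batch, batch_i)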
+ +## datamodule.py + +Contains a class NowcastingDataModule - pl.LightningDataModule +This handles the + - amalgamation of all different data sources, + - making valid datetimes across all the sources, + - splitting into train and validation datasets + + +## datasets.py + +This file contains the following classes + +NetCDFDataset - torch.utils.data.Dataset: Use for loading pre-made batches +NowcastingDataset - torch.utils.data.IterableDataset: Dataset for making batches +ContiguousNowcastingDataset - NowcastingDataset + +## example.py + +The main thing in here is a Typed Dictionary. This is used to store one element of data used for one step in the ML models. +There is also a validation function. See this file for documentation about exactly what data is available in each ML +training Example. + diff --git a/nowcasting_dataset/dataset/batch.py b/nowcasting_dataset/dataset/batch.py new file mode 100644 index 00000000..926c02dc --- /dev/null +++ b/nowcasting_dataset/dataset/batch.py @@ -0,0 +1,155 @@ +from typing import List, Optional +import logging + +import numpy as np +import xarray as xr +from pathlib import Path + +from nowcasting_dataset.consts import GSP_ID, GSP_YIELD, GSP_X_COORDS, GSP_Y_COORDS, \ + DATETIME_FEATURE_NAMES + +from nowcasting_dataset.dataset.example import Example +from nowcasting_dataset.utils import get_netcdf_filename + +_LOG = logging.getLogger(__name__) + +LOCAL_TEMP_PATH = Path('~/temp/').expanduser() + + +def write_batch_locally(batch: List[Example], batch_i: int): + """ + Write a batch to a local file + Args: + batch: batch of data + batch_i: the number of the batch + + """ + dataset = batch_to_dataset(batch) + dataset = fix_dtypes(dataset) + encoding = {name: {"compression": "lzf"} for name in dataset.data_vars} + filename = get_netcdf_filename(batch_i) + local_filename = LOCAL_TEMP_PATH / filename + dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) + + +def fix_dtypes(concat_ds): + """ + Cast the variables in the concatenated dataset to their intended dtypes + """ + ds_dtypes = { + "example": np.int32, + "sat_x_coords": np.int32, + "sat_y_coords": np.int32, + "nwp": np.float32, + "nwp_x_coords": np.float32, + "nwp_y_coords": np.float32, + "pv_system_id": np.float32, + "pv_system_row_number": np.float32, + "pv_system_x_coords": np.float32, + "pv_system_y_coords": np.float32, + } + + for name, dtype in ds_dtypes.items(): + concat_ds[name] = concat_ds[name].astype(dtype) + + assert concat_ds["sat_data"].dtype == np.int16 + return concat_ds + + +def batch_to_dataset(batch: List[Example]) -> xr.Dataset: + """Concat all the individual fields in an Example into a single Dataset. + + Args: + batch: List of Example objects, which together constitute a single batch.
+ """ + datasets = [] + for i, example in enumerate(batch): + try: + individual_datasets = [] + example_dim = {"example": np.array([i], dtype=np.int32)} + for name in ["sat_data", "nwp"]: + ds = example[name].to_dataset(name=name) + short_name = name.replace("_data", "") + if name == "nwp": + ds = ds.rename({"target_time": "time"}) + for dim in ["time", "x", "y"]: + ds = coord_to_range(ds, dim, prefix=short_name) + ds = ds.rename( + { + "variable": f"{short_name}_variable", + "x": f"{short_name}_x", + "y": f"{short_name}_y", + } + ) + individual_datasets.append(ds) + + # Datetime features + for name in DATETIME_FEATURE_NAMES: + ds = example[name].rename(name).to_xarray().to_dataset().rename({"index": "time"}) + ds = coord_to_range(ds, "time", prefix=None) + individual_datasets.append(ds) + + # PV + one_dateset = xr.DataArray(example["pv_yield"], dims=["time", "pv_system"]) + one_dateset = one_dateset.to_dataset(name="pv_yield") + n_pv_systems = len(example["pv_system_id"]) + + # GSP + n_gsp_systems = len(example[GSP_ID]) + one_dateset['gsp_yield'] = xr.DataArray(example[GSP_YIELD], dims=["time_30", "gsp_system"]) + + # This will expand all dataarrays to have an 'example' dim. + # 0D + for name in ["x_meters_center", "y_meters_center"]: + try: + one_dateset[name] = xr.DataArray([example[name]], coords=example_dim, dims=["example"]) + except Exception as e: + _LOG.error(f'Could not make pv_yield data for {name} with example_dim={example_dim}') + if name not in example.keys(): + _LOG.error(f'{name} not in data keys: {example.keys()}') + _LOG.error(e) + raise Exception + + # 1D + for name in ["pv_system_id", "pv_system_row_number", "pv_system_x_coords", "pv_system_y_coords"]: + one_dateset[name] = xr.DataArray( + example[name][None, :], + coords={**example_dim, **{"pv_system": np.arange(n_pv_systems, dtype=np.int32)}}, + dims=["example", "pv_system"], + ) + + # GSP + for name in [GSP_ID, GSP_X_COORDS, GSP_Y_COORDS]: + try: + one_dateset[name] = xr.DataArray( + example[name][None, :], + coords={**example_dim, **{"gsp_system": np.arange(n_gsp_systems, dtype=np.int32)}}, + dims=["example", "gsp_system"], + ) + except Exception as e: + _LOG.debug(f'Could not add {name} to dataset. {example[name].shape}') + _LOG.error(e) + raise e + + individual_datasets.append(one_dateset) + + # Merge + merged_ds = xr.merge(individual_datasets) + datasets.append(merged_ds) + + except Exception as e: + print(e) + _LOG.error(e) + raise Exception + + return xr.concat(datasets, dim="example") + + +def coord_to_range(da: xr.DataArray, dim: str, prefix: Optional[str], dtype=np.int32) -> xr.DataArray: + # TODO: Actually, I think this is over-complicated? I think we can + # just strip off the 'coord' from the dimension. 
+ coord = da[dim] + da[dim] = np.arange(len(coord), dtype=dtype) + if prefix is not None: + da[f"{prefix}_{dim}_coords"] = xr.DataArray(coord, coords=[da[dim]], dims=[dim]) + return da \ No newline at end of file diff --git a/nowcasting_dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py similarity index 54% rename from nowcasting_dataset/datamodule.py rename to nowcasting_dataset/dataset/datamodule.py index 2d8833a6..1b558062 100644 --- a/nowcasting_dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -3,17 +3,21 @@ import pandas as pd from copy import deepcopy import torch +import logging from nowcasting_dataset import data_sources +from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource from nowcasting_dataset import time as nd_time from nowcasting_dataset import utils from nowcasting_dataset import consts -from nowcasting_dataset import dataset +from nowcasting_dataset.dataset import datasets from dataclasses import dataclass import warnings + with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) import pytorch_lightning as pl +logger = logging.getLogger(__name__) @dataclass class NowcastingDataModule(pl.LightningDataModule): @@ -25,19 +29,19 @@ class NowcastingDataModule(pl.LightningDataModule): train_t0_datetimes: pd.DatetimeIndex val_t0_datetimes: pd.DatetimeIndex """ + pv_power_filename: Optional[Union[str, Path]] = None pv_metadata_filename: Optional[Union[str, Path]] = None batch_size: int = 8 n_training_batches_per_epoch: int = 25_000 n_validation_batches_per_epoch: int = 1_000 - history_len: int = 2 #: Number of timesteps of history, not including t0. - forecast_len: int = 12 #: Number of timesteps of forecast, not including t0. + history_minutes: int = 30 #: Number of minutes of history, not including t0. + forecast_minutes: int = 60 #: Number of minutes of forecast, not including t0. sat_filename: Union[str, Path] = consts.SAT_FILENAME - sat_channels: Iterable[str] = ('HRV', ) + sat_channels: Iterable[str] = ("HRV",) normalise_sat: bool = True nwp_base_path: Optional[str] = None - nwp_channels: Optional[Iterable[str]] = ( - 't', 'dswrf', 'prate', 'r', 'sde', 'si10', 'vis', 'lcc', 'mcc', 'hcc') + nwp_channels: Optional[Iterable[str]] = ("t", "dswrf", "prate", "r", "sde", "si10", "vis", "lcc", "mcc", "hcc") image_size_pixels: int = 128 #: Passed to Data Sources. meters_per_pixel: int = 2000 #: Passed to Data Sources. convert_to_numpy: bool = True #: Passed to Data Sources. @@ -46,74 +50,107 @@ class NowcastingDataModule(pl.LightningDataModule): prefetch_factor: int = 64 #: Passed to DataLoader. n_samples_per_timestep: int = 2 #: Passed to NowcastingDataset collate_fn: Callable = torch.utils.data._utils.collate.default_collate #: Passed to NowcastingDataset + gsp_filename: Optional[Union[str, Path]] = None + train_validation_percentage_split: float = 20 + pv_load_azimuth_and_elevation: bool = False - skip_n_train_batches: int = 0 # number of train batches to skip + skip_n_train_batches: int = 0 # number of train batches to skip skip_n_validation_batches: int = 0 # number of validation batches to skip def __post_init__(self): super().__init__() + + self.history_len_30_minutes = self.history_minutes // 30 + self.forecast_len_30_minutes = self.forecast_minutes // 30 + + self.history_len_5_minutes = self.history_minutes // 5 + self.forecast_len_5_minutes = self.forecast_minutes // 5 + # Plus 1 because neither history_len nor forecast_len include t0. 
- self._total_seq_len = self.history_len + self.forecast_len + 1 + self._total_seq_len_5_minutes = self.history_len_5_minutes + self.forecast_len_5_minutes + 1 + self._total_seq_len_30_minutes = self.history_len_30_minutes + self.forecast_len_30_minutes + 1 self.contiguous_dataset = None if self.num_workers == 0: self.prefetch_factor = 2 # Set to default when not using multiprocessing. def prepare_data(self) -> None: # Satellite data - n_timesteps_per_batch = ( - self.batch_size // self.n_samples_per_timestep) + n_timesteps_per_batch = self.batch_size // self.n_samples_per_timestep self.sat_data_source = data_sources.SatelliteDataSource( filename=self.sat_filename, image_size_pixels=self.image_size_pixels, meters_per_pixel=self.meters_per_pixel, - history_len=self.history_len, - forecast_len=self.forecast_len, + history_minutes=self.history_minutes, + forecast_minutes=self.forecast_minutes, channels=self.sat_channels, n_timesteps_per_batch=n_timesteps_per_batch, convert_to_numpy=self.convert_to_numpy, - normalise=self.normalise_sat) + normalise=self.normalise_sat, + ) self.data_sources = [self.sat_data_source] + sat_datetimes = self.sat_data_source.datetime_index() # PV if self.pv_power_filename is not None: - sat_datetimes = self.sat_data_source.datetime_index() self.pv_data_source = data_sources.PVDataSource( filename=self.pv_power_filename, metadata_filename=self.pv_metadata_filename, start_dt=sat_datetimes[0], end_dt=sat_datetimes[-1], - history_len=self.history_len, - forecast_len=self.forecast_len, + history_minutes=self.history_minutes, + forecast_minutes=self.forecast_minutes, convert_to_numpy=self.convert_to_numpy, image_size_pixels=self.image_size_pixels, - meters_per_pixel=self.meters_per_pixel) + meters_per_pixel=self.meters_per_pixel, + get_center=False, + load_azimuth_and_elevation=self.pv_load_azimuth_and_elevation, + ) self.data_sources = [self.pv_data_source, self.sat_data_source] + if self.gsp_filename is not None: + self.gsp_data_source = GSPDataSource( + filename=self.gsp_filename, + start_dt=sat_datetimes[0], + end_dt=sat_datetimes[-1], + history_minutes=self.history_minutes, + forecast_minutes=self.forecast_minutes, + convert_to_numpy=self.convert_to_numpy, + image_size_pixels=self.image_size_pixels, + meters_per_pixel=self.meters_per_pixel, + get_center=True + ) + + # put gsp data source at the start, so data is centered around GSP. This is the current approach, + # but in the future we may take a mix of GSP and PV data as the centroid + self.data_sources = [self.gsp_data_source] + self.data_sources + # NWP data if self.nwp_base_path is not None: self.nwp_data_source = data_sources.NWPDataSource( filename=self.nwp_base_path, image_size_pixels=2, meters_per_pixel=self.meters_per_pixel, - history_len=self.history_len, - forecast_len=self.forecast_len, + history_minutes=self.history_minutes, + forecast_minutes=self.forecast_minutes, channels=self.nwp_channels, n_timesteps_per_batch=n_timesteps_per_batch, - convert_to_numpy=self.convert_to_numpy) + convert_to_numpy=self.convert_to_numpy, + ) self.data_sources.append(self.nwp_data_source) self.datetime_data_source = data_sources.DatetimeDataSource( - history_len=self.history_len, - forecast_len=self.forecast_len, - convert_to_numpy=self.convert_to_numpy) + history_minutes=self.history_minutes, + forecast_minutes=self.forecast_minutes, + convert_to_numpy=self.convert_to_numpy + ) self.data_sources.append(self.datetime_data_source) - def setup(self, stage='fit'): + def setup(self, stage="fit"): """Split data, etc. 
Args: @@ -163,27 +200,30 @@ def setup(self, stage='fit'): self._split_data() # Create datasets - self.train_dataset = dataset.NowcastingDataset( + logger.debug('Making train dataset') + self.train_dataset = datasets.NowcastingDataset( t0_datetimes=self.train_t0_datetimes, data_sources=self.data_sources, skip_batch_index=self.skip_n_train_batches, - n_batches_per_epoch_per_worker=( - self._n_batches_per_epoch_per_worker( - self.n_training_batches_per_epoch)), - **self._common_dataset_params()) - self.val_dataset = dataset.NowcastingDataset( + n_batches_per_epoch_per_worker=(self._n_batches_per_epoch_per_worker(self.n_training_batches_per_epoch)), + **self._common_dataset_params(), + ) + logger.debug('Making validation dataset') + self.val_dataset = datasets.NowcastingDataset( t0_datetimes=self.val_t0_datetimes, data_sources=self.data_sources, skip_batch_index=self.skip_n_validation_batches, - n_batches_per_epoch_per_worker=( - self._n_batches_per_epoch_per_worker( - self.n_validation_batches_per_epoch)), - **self._common_dataset_params()) + n_batches_per_epoch_per_worker=(self._n_batches_per_epoch_per_worker(self.n_validation_batches_per_epoch)), + **self._common_dataset_params(), + ) + logger.debug('Making validation dataset: done') if self.num_workers == 0: self.train_dataset.per_worker_init(worker_id=0) self.val_dataset.per_worker_init(worker_id=0) + logger.debug('Setup: done') + def _n_batches_per_epoch_per_worker(self, n_batches_per_epoch: int) -> int: if self.num_workers > 0: return n_batches_per_epoch // self.num_workers @@ -192,93 +232,134 @@ def _n_batches_per_epoch_per_worker(self, n_batches_per_epoch: int) -> int: def _split_data(self): """Sets self.train_t0_datetimes and self.val_t0_datetimes.""" + + logger.debug('Going to split data') + self._check_has_prepared_data() - all_datetimes = self._get_datetimes() - t0_datetimes = nd_time.get_t0_datetimes( - datetimes=all_datetimes, total_seq_len=self._total_seq_len, - history_len=self.history_len) - del all_datetimes + self.t0_datetimes = self._get_datetimes(interpolate_for_30_minute_data=True) + + logger.debug(f'Got all start times, there are {len(self.t0_datetimes)}') + + # del all_datetimes # Split t0_datetimes into train and test. # TODO: Better way to split into train and val date ranges! 
# See https://github.com/openclimatefix/nowcasting_dataset/issues/7 - assert len(t0_datetimes) > 5 - split = len(t0_datetimes) // 5 + + logger.debug(f'Taking {self.train_validation_percentage_split}% into validation') + + split_number = int(100 / self.train_validation_percentage_split) + assert len(self.t0_datetimes) > split_number + split = len(self.t0_datetimes) // split_number assert split > 0 - split = len(t0_datetimes) - split - self.train_t0_datetimes = t0_datetimes[:split] - self.val_t0_datetimes = t0_datetimes[split:] + split = len(self.t0_datetimes) - split + + # set train and validation times + self.train_t0_datetimes = self.t0_datetimes[:split] + self.val_t0_datetimes = self.t0_datetimes[split:] + + logger.debug(f'Split data done, train has {len(self.train_t0_datetimes)}, ' + f'validation has {len(self.val_t0_datetimes)}') def train_dataloader(self) -> torch.utils.data.DataLoader: - return torch.utils.data.DataLoader( - self.train_dataset, **self._common_dataloader_params()) + return torch.utils.data.DataLoader(self.train_dataset, **self._common_dataloader_params()) def val_dataloader(self) -> torch.utils.data.DataLoader: - return torch.utils.data.DataLoader( - self.val_dataset, **self._common_dataloader_params()) + return torch.utils.data.DataLoader(self.val_dataset, **self._common_dataloader_params()) def contiguous_dataloader(self) -> torch.utils.data.DataLoader: if self.contiguous_dataset is None: pv_data_source = deepcopy(self.pv_data_source) pv_data_source.random_pv_system_for_given_location = False data_sources = [pv_data_source, self.sat_data_source] - self.contiguous_dataset = dataset.ContiguousNowcastingDataset( + self.contiguous_dataset = datasets.ContiguousNowcastingDataset( t0_datetimes=self.val_t0_datetimes, data_sources=data_sources, - n_batches_per_epoch_per_worker=( - self._n_batches_per_epoch_per_worker(32)), - **self._common_dataset_params()) + n_batches_per_epoch_per_worker=(self._n_batches_per_epoch_per_worker(32)), + **self._common_dataset_params(), + ) if self.num_workers == 0: self.contiguous_dataset.per_worker_init(worker_id=0) - return torch.utils.data.DataLoader( - self.contiguous_dataset, **self._common_dataloader_params()) + return torch.utils.data.DataLoader(self.contiguous_dataset, **self._common_dataloader_params()) def _common_dataset_params(self) -> Dict: return dict( - batch_size=self.batch_size, - n_samples_per_timestep=self.n_samples_per_timestep, - collate_fn=self.collate_fn) + batch_size=self.batch_size, n_samples_per_timestep=self.n_samples_per_timestep, collate_fn=self.collate_fn + ) def _common_dataloader_params(self) -> Dict: return dict( pin_memory=self.pin_memory, num_workers=self.num_workers, - worker_init_fn=dataset.worker_init_fn, + worker_init_fn=datasets.worker_init_fn, prefetch_factor=self.prefetch_factor, - # Disable automatic batching because NowcastingDataset.__iter__ # returns complete batches batch_size=None, - batch_sampler=None) + batch_sampler=None, + ) - def _get_datetimes(self) -> pd.DatetimeIndex: + def _get_datetimes(self, interpolate_for_30_minute_data: bool = False, adjust_for_sequence_length: bool = True) -> pd.DatetimeIndex: """Compute the datetime index. + interpolate_for_30_minute_data: If True, + 1. all datetimes from source will be interpolated to 5 min intervals, + 2. the total intersection will be taken + 3. only 30 mins datetimes will be selected + + adjust_for_sequence_length, if true, adjust the datetimes by sequence history and length. 
+ This means that all the datetimes from [datetime - history_delta: datetime + forecast_delta] should be available + + This deals with a mixture of data sources that have 5 mins and 30 min datatime. + Returns the intersection of the datetime indicies of all the data_sources, filtered by daylight hours.""" + logger.debug('Get the datetimes') self._check_has_prepared_data() # Get the intersection of datetimes from all data sources. all_datetime_indexes = [] for data_source in self.data_sources: + logger.debug(f'Getting datetimes for {type(data_source).__name__}') try: - all_datetime_indexes.append(data_source.datetime_index()) + + # get datetimes from data source + datetime_index = data_source.datetime_index() + + if interpolate_for_30_minute_data and type(data_source).__name__ == 'GSPDataSource': + # change 30 min data to 5 mins, only for GSP Data + datetime_index = nd_time.fill_30_minutes_timestamps_to_5_minutes(index=datetime_index) + + all_datetime_indexes.append(datetime_index) except NotImplementedError: pass - datetimes = nd_time.intersection_of_datetimeindexes( - all_datetime_indexes) + datetimes = nd_time.intersection_of_datetimeindexes(all_datetime_indexes) del all_datetime_indexes # save memory # Select datetimes that have at least some sunlight border_locations = self.sat_data_source.geospatial_border() - dt_index = nd_time.select_daylight_datetimes( - datetimes=datetimes, locations=border_locations) + dt_index = nd_time.select_daylight_datetimes(datetimes=datetimes, locations=border_locations) # Sanity check assert len(dt_index) > 2 assert utils.is_monotonically_increasing(dt_index) - return dt_index + + if not adjust_for_sequence_length: + return dt_index + + # get t0 datetime which depend on the sequence length in the dataset + t0_datetimes = nd_time.get_t0_datetimes( + datetimes=dt_index, total_seq_len=self._total_seq_len_5_minutes, history_len=self.history_len_5_minutes + ) + + # only select datetimes for half hours, ignore 5 minute timestamps + if interpolate_for_30_minute_data: + t0_datetimes = [t0 for t0 in t0_datetimes if (t0.minute in [0, 30])] + + del dt_index + + return t0_datetimes def _check_has_prepared_data(self): if not self.has_prepared_data: - raise RuntimeError('Must run prepare_data() first!') + raise RuntimeError("Must run prepare_data() first!") diff --git a/nowcasting_dataset/dataset.py b/nowcasting_dataset/dataset/datasets.py similarity index 93% rename from nowcasting_dataset/dataset.py rename to nowcasting_dataset/dataset/datasets.py index 3ce2cbaa..f662c097 100644 --- a/nowcasting_dataset/dataset.py +++ b/nowcasting_dataset/dataset/datasets.py @@ -1,6 +1,8 @@ import pandas as pd from numbers import Number from typing import List, Tuple, Iterable, Callable + +import nowcasting_dataset.consts from nowcasting_dataset import data_sources from dataclasses import dataclass from concurrent import futures @@ -11,15 +13,23 @@ import numpy as np import xarray as xr from nowcasting_dataset import utils as nd_utils -from nowcasting_dataset import example +from nowcasting_dataset.dataset import example import torch from nowcasting_dataset.cloud.gcp import gcp_download_to_local from nowcasting_dataset.cloud.aws import aws_download_to_local +from nowcasting_dataset.consts import GSP_ID, GSP_YIELD, GSP_X_COORDS, GSP_Y_COORDS, GSP_DATETIME_INDEX from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES +""" +This file contains the following classes +NetCDFDataset- torch.utils.data.Dataset: Use for loading pre-made batches +NowcastingDataset - 
torch.utils.data.IterableDataset: Dataset for making batches +ContiguousNowcastingDataset - NowcastingDataset +""" + SAT_MEAN = xr.DataArray( data=[ 93.23458, 131.71373, 843.7779, 736.6148, 771.1189, 589.66034, @@ -35,7 +45,7 @@ dims=['sat_variable'], coords={'sat_variable': list(SAT_VARIABLE_NAMES)}).astype(np.float32) -_LOG = logging.getLogger('nowcasting_dataset') +_LOG = logging.getLogger(__name__) class NetCDFDataset(torch.utils.data.Dataset): @@ -115,8 +125,9 @@ def __getitem__(self, batch_idx: int) -> example.Example: 'sat_data', 'sat_x_coords', 'sat_y_coords', 'pv_yield', 'pv_system_id', 'pv_system_row_number', 'pv_system_x_coords', 'pv_system_y_coords', - 'x_meters_center', 'y_meters_center' - ] + list(example.DATETIME_FEATURE_NAMES): + 'x_meters_center', 'y_meters_center', + GSP_ID, GSP_YIELD, GSP_X_COORDS, GSP_Y_COORDS, GSP_DATETIME_INDEX + ] + list(nowcasting_dataset.consts.DATETIME_FEATURE_NAMES): try: batch[key] = netcdf_batch[key] except KeyError: @@ -183,6 +194,7 @@ def per_worker_init(self, worker_id: int) -> None: # Initialise each data_source. for data_source in self.data_sources: + _LOG.debug(f'Opening {type(data_source).__name__}') data_source.open() self._per_worker_init_has_run = True @@ -283,3 +295,5 @@ def worker_init_fn(worker_id): # The NowcastingDataset copy in this worker process. dataset_obj = worker_info.dataset dataset_obj.per_worker_init(worker_info.id) + + diff --git a/nowcasting_dataset/dataset/example.py b/nowcasting_dataset/dataset/example.py new file mode 100644 index 00000000..00bd3de5 --- /dev/null +++ b/nowcasting_dataset/dataset/example.py @@ -0,0 +1,165 @@ +from typing import TypedDict +import pandas as pd +from nowcasting_dataset.consts import * +from numbers import Number + + +class Example(TypedDict): + """Simple class for structuring data for each ML example. + + Using typing.TypedDict gives us several advantages: + 1. Single 'source of truth' for the type and documentation of the fields + in each example. + 2. A static type checker can check the types are correct. + + Instead of TypedDict, we could use typing.NamedTuple, + which would provide runtime checks, but the deal-breaker with Tuples is + that they're immutable so we cannot change the values in the transforms. + """ + + # IMAGES + # Shape: [batch_size,] seq_length, width, height, channel + sat_data: Array + sat_x_coords: Array #: OSGB geo-spatial coordinates. + sat_y_coords: Array + + #: PV yield from all PV systems in the region of interest (ROI). + #: Includes central PV system, which will always be the first entry. + #: shape = [batch_size, ] seq_length, n_pv_systems_per_example + pv_yield: Array + + # PV azimuth and elevation angles i.e where the sun is. + #: shape = [batch_size, ] seq_length, n_pv_systems_per_example + pv_azimuth_angle: Array + pv_elevation_angle: Array + + #: PV identification. + #: shape = [batch_size, ] n_pv_systems_per_example + pv_system_id: Array + pv_system_row_number: Array #: In the range [0, len(pv_metadata)]. + + #: PV system geographical location (in OSGB coords). 
+ #: shape = [batch_size, ] n_pv_systems_per_example + pv_system_x_coords: Array + pv_system_y_coords: Array + pv_datetime_index: Array #: shape = [batch_size, ] seq_length + + # Numerical weather predictions (NWPs) + nwp: Array #: Shape: [batch_size,] channel, seq_length, width, height + nwp_x_coords: Array + nwp_y_coords: Array + + # METADATA + x_meters_center: Number #: In OSGB coordinations + y_meters_center: Number #: In OSGB coordinations + + # Datetimes (abbreviated to "dt") + # At 5-minutes past the hour {0, 5, ..., 55} + # *not* the {4, 9, ..., 59} timings of the satellite imagery. + # Datetimes become Unix epochs (UTC) represented as int64 just before being + # passed into the ML model. + # t0_dt is 'now', the most recent observation. + sat_datetime_index: Array + nwp_target_time: Array + hour_of_day_sin: Array #: Shape: [batch_size,] seq_length + hour_of_day_cos: Array + day_of_year_sin: Array + day_of_year_cos: Array + + #: GSP PV yield from all GSP in the region of interest (ROI). + # : Includes central GSP, which will always be the first entry. + gsp_yield: Array #: shape = [batch_size, ] seq_length, n_gsp_systems_per_example + # GSP identification. + gsp_id: Array #: shape = [batch_size, ] n_pv_systems_per_example + #: GSP geographical location (in OSGB coords). + gsp_x_coords: Array #: shape = [batch_size, ] n_pv_systems_per_example + gsp_y_coords: Array #: shape = [batch_size, ] n_pv_systems_per_example + gsp_datetime_index: Array #: shape = [batch_size, ] seq_length + + # if the centroid type is a GSP, or a PV system + object_at_center: str #: shape = [batch_size, ] + + +def to_numpy(example: Example) -> Example: + for key, value in example.items(): + if isinstance(value, xr.DataArray): + # TODO: Use to_numpy() or as_numpy(), introduced in xarray v0.19? 
+ value = value.data + + if isinstance(value, (pd.Series, pd.DataFrame)): + value = value.values + elif isinstance(value, pd.DatetimeIndex): + value = value.values.astype("datetime64[s]").astype(np.int32) + elif isinstance(value, pd.Timestamp): + value = np.int32(value.timestamp()) + elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.datetime64): + value = value.astype("datetime64[s]").astype(np.int32) + + example[key] = value + return example + + +def validate_example( + data: Example, + seq_len_30_minutes: int, + seq_len_5_minutes: int, + sat_image_size: int = 64, + n_sat_channels: int = 1, + nwp_image_size: int = 0, + n_nwp_channels: int = 1, + n_pv_systems_per_example: int = DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, + n_gsp_per_example: int = DEFAULT_N_GSP_PER_EXAMPLE, +): + """ + Validate the size and shape of the data + Args: + data: Typed dictionary of the data + seq_len_30_minutes: the length of the sequence for 30 minutely data + seq_len_5_minutes: the length of the sequence for 5 minutely data + sat_image_size: the satellite image size + n_sat_channels: the number of satellite channels + nwp_image_size: the nwp image size + n_nwp_channels: the number of nwp channels + n_pv_systems_per_example: the number of pv systems with nan padding + n_gsp_per_example: the number of gsp systems with nan padding + """ + + assert len(data[GSP_ID]) == n_gsp_per_example + n_gsp_system_id = len(data[GSP_ID]) + assert data[GSP_YIELD].shape == (seq_len_30_minutes, n_gsp_system_id) + assert len(data[GSP_X_COORDS]) == n_gsp_system_id + assert len(data[GSP_Y_COORDS]) == n_gsp_system_id + assert len(data[GSP_DATETIME_INDEX]) == seq_len_30_minutes + + assert data[OBJECT_AT_CENTER] == "gsp" + assert type(data["x_meters_center"]) == np.float64 + assert type(data["y_meters_center"]) == np.float64 + + n_pv_systems = len(data[PV_SYSTEM_ID][~np.isnan(data[PV_SYSTEM_ID])]) + + assert len(data[PV_SYSTEM_ID]) == n_pv_systems_per_example + assert data[PV_YIELD].shape == (seq_len_5_minutes, n_pv_systems_per_example) + assert len(data[PV_SYSTEM_X_COORDS]) == n_pv_systems_per_example + assert len(data[PV_SYSTEM_Y_COORDS]) == n_pv_systems_per_example + assert len(data[PV_SYSTEM_ROW_NUMBER][~np.isnan(data[PV_SYSTEM_ROW_NUMBER])]) == n_pv_systems + + if PV_AZIMUTH_ANGLE in data.keys(): + assert data[PV_AZIMUTH_ANGLE].shape == (seq_len_5_minutes, n_pv_systems_per_example) + if PV_ELEVATION_ANGLE in data.keys(): + assert data[PV_ELEVATION_ANGLE].shape == (seq_len_5_minutes, n_pv_systems_per_example) + + assert data["sat_data"].shape == (seq_len_5_minutes, sat_image_size, sat_image_size, n_sat_channels) + assert len(data["sat_x_coords"]) == sat_image_size + assert len(data["sat_y_coords"]) == sat_image_size + assert len(data["sat_datetime_index"]) == seq_len_5_minutes + + assert data["nwp"].shape == (n_nwp_channels, seq_len_5_minutes, nwp_image_size, nwp_image_size) + assert len(data["nwp_x_coords"]) == nwp_image_size + assert len(data["nwp_y_coords"]) == nwp_image_size + assert len(data["nwp_target_time"]) == seq_len_5_minutes + + for feature in DATETIME_FEATURE_NAMES: + assert len(data[feature]) == seq_len_5_minutes diff --git a/nowcasting_dataset/example.py b/nowcasting_dataset/example.py deleted file mode 100644 index 6d964b28..00000000 --- a/nowcasting_dataset/example.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import TypedDict -import pandas as pd -import xarray as xr -import numpy as np -from
nowcasting_dataset.consts import Array -from numbers import Number - - -DATETIME_FEATURE_NAMES = ('hour_of_day_sin', 'hour_of_day_cos', - 'day_of_year_sin', 'day_of_year_cos') - - -class Example(TypedDict): - """Simple class for structuring data for each ML example. - - Using typing.TypedDict gives us several advantages: - 1. Single 'source of truth' for the type and documentation of the fields - in each example. - 2. A static type checker can check the types are correct. - - Instead of TypedDict, we could use typing.NamedTuple, - which would provide runtime checks, but the deal-breaker with Tuples is - that they're immutable so we cannot change the values in the transforms. - """ - # IMAGES - # Shape: [batch_size,] seq_length, width, height, channel - sat_data: Array - sat_x_coords: Array #: OSGB geo-spatial coordinates. - sat_y_coords: Array - - #: PV yield from all PV systems in the region of interest (ROI). - #: Includes central PV system, which will always be the first entry. - #: shape = [batch_size, ] seq_length, n_pv_systems_per_example - pv_yield: Array - - # PV azimuth and elevation angles i.e where the sun is. - #: shape = [batch_size, ] seq_length, n_pv_systems_per_example - pv_azimuth_angle: Array - pv_elevation_angle: Array - - #: PV identification. - #: shape = [batch_size, ] n_pv_systems_per_example - pv_system_id: Array - pv_system_row_number: Array #: In the range [0, len(pv_metadata)]. - - #: PV system geographical location (in OSGB coords). - #: shape = [batch_size, ] n_pv_systems_per_example - pv_system_x_coords: Array - pv_system_y_coords: Array - - # Numerical weather predictions (NWPs) - nwp: Array #: Shape: [batch_size,] channel, seq_length, width, height - nwp_x_coords: Array - nwp_y_coords: Array - - # METADATA - x_meters_center: Number #: In OSGB coordinations - y_meters_center: Number #: In OSGB coordinations - - # Datetimes (abbreviated to "dt") - # At 5-minutes past the hour {0, 5, ..., 55} - # *not* the {4, 9, ..., 59} timings of the satellite imagery. - # Datetimes become Unix epochs (UTC) represented as int64 just before being - # passed into the ML model. - # t0_dt is 'now', the most recent observation. - sat_datetime_index: Array - nwp_target_time: Array - hour_of_day_sin: Array #: Shape: [batch_size,] seq_length - hour_of_day_cos: Array - day_of_year_sin: Array - day_of_year_cos: Array - - -def to_numpy(example: Example) -> Example: - for key, value in example.items(): - if isinstance(value, xr.DataArray): - # TODO: Use to_numpy() or as_numpy(), introduced in xarray v0.19? - value = value.data - - if isinstance(value, (pd.Series, pd.DataFrame)): - value = value.values - elif isinstance(value, pd.DatetimeIndex): - value = value.values.astype('datetime64[s]').astype(np.int32) - elif isinstance(value, pd.Timestamp): - value = np.int32(value.timestamp()) - elif (isinstance(value, np.ndarray) and - np.issubdtype(value.dtype, np.datetime64)): - value = value.astype('datetime64[s]').astype(np.int32) - - example[key] = value - return example diff --git a/nowcasting_dataset/geospatial.py b/nowcasting_dataset/geospatial.py index 0e66b0e5..b83f4eb4 100644 --- a/nowcasting_dataset/geospatial.py +++ b/nowcasting_dataset/geospatial.py @@ -16,6 +16,7 @@ # WGS84 is short for "World Geodetic System 1984", used in GPS. Uses # latitude and longitude. 
WGS84 = 4326 +WGS84_CRS = f"EPSG:{WGS84}" class Transformers: diff --git a/nowcasting_dataset/square.py b/nowcasting_dataset/square.py index 0af57211..ddf9faf8 100644 --- a/nowcasting_dataset/square.py +++ b/nowcasting_dataset/square.py @@ -1,12 +1,14 @@ -from typing import NamedTuple +from typing import NamedTuple, Union from numbers import Number +from nowcasting_dataset.consts import Array + class BoundingBox(NamedTuple): - top: Number - bottom: Number - left: Number - right: Number + top: Union[Number, float] + bottom: Union[Number, float] + left: Union[Number, float] + right: Union[Number, float] class Square: @@ -27,3 +29,21 @@ def bounding_box_centered_on( bottom=y_meters_center - self._half_size_meters, left=x_meters_center - self._half_size_meters, right=x_meters_center + self._half_size_meters) + + +def get_bounding_box_mask(bounding_box: BoundingBox, x: Array, y: Array)-> Array: + """ + Get boundary box mask from x and y locations. I.e are the x,y coords in the boundaring box + Args: + bounding_box: Bounding box + x: x coordinates + y: y coordinates + + Returns: list of booleans if the x and y coordinates are in the bounding box + + """ + mask = ( + (x >= bounding_box.left) & (x <= bounding_box.right) & (y >= bounding_box.bottom) & ( + y <= bounding_box.top) + ) + return mask \ No newline at end of file diff --git a/nowcasting_dataset/time.py b/nowcasting_dataset/time.py index 7db16393..64cc0aae 100644 --- a/nowcasting_dataset/time.py +++ b/nowcasting_dataset/time.py @@ -2,12 +2,17 @@ import numpy as np from typing import Iterable, Tuple, List from nowcasting_dataset import geospatial, utils -from nowcasting_dataset.example import Example +from nowcasting_dataset.dataset.example import Example import warnings import pvlib +import logging + + +logger = logging.getLogger(__name__) FIVE_MINUTES = pd.Timedelta('5 minutes') +THIRTY_MINUTES = pd.Timedelta('30 minutes') def select_daylight_datetimes( @@ -57,7 +62,7 @@ def intersection_of_datetimeindexes( def get_start_datetimes( datetimes: pd.DatetimeIndex, total_seq_len: int, - max_gap: pd.Timedelta = FIVE_MINUTES) -> pd.DatetimeIndex: + max_gap: pd.Timedelta = THIRTY_MINUTES) -> pd.DatetimeIndex: """Returns a datetime index of valid start datetimes. Valid start datetimes are those where there is certain to be @@ -97,6 +102,8 @@ def get_start_datetimes( start_dt_index.append(datetimes[start_i:end_i]) start_i = next_start_i + assert len(start_dt_index) > 0 + return pd.DatetimeIndex(np.concatenate(start_dt_index)) @@ -104,17 +111,36 @@ def get_t0_datetimes( datetimes: pd.DatetimeIndex, total_seq_len: int, history_len: int, + minute_delta: int = 5, max_gap: pd.Timedelta = FIVE_MINUTES) -> pd.DatetimeIndex: + """ + Get datetimes for ML learning batches. T0 refers to the time 'now'. + Args: + datetimes: list of datetimes when data is available + total_seq_len: total sequence length of data for ml model + history_len: the number of historic timestemps + minute_delta: the amount of minutes in one time step + max_gap: The maximum allowed gap in the datetimes for it to be valid + + Returns: Datetimes that ml learning data can be built around. 
+ + """ + + logger.debug('Getting t0 datetimes') + start_datetimes = get_start_datetimes( datetimes=datetimes, total_seq_len=total_seq_len, max_gap=max_gap) - history_dur = timesteps_to_duration(history_len) + + logger.debug('Adding history during to t0 datetimes') + history_dur = timesteps_to_duration(history_len, minute_delta=minute_delta) t0_datetimes = start_datetimes + history_dur + return t0_datetimes -def timesteps_to_duration(n_timesteps: int) -> pd.Timedelta: +def timesteps_to_duration(n_timesteps: int, minute_delta: int = 5) -> pd.Timedelta: assert n_timesteps >= 0 - return pd.Timedelta(n_timesteps * 5, unit='minutes') + return pd.Timedelta(n_timesteps * minute_delta, unit='minutes') def datetime_features(index: pd.DatetimeIndex) -> pd.DataFrame: @@ -133,3 +159,49 @@ def datetime_features_in_example(index: pd.DatetimeIndex) -> Example: for col_name, series in dt_features.iteritems(): example[col_name] = series return example + + +def fill_30_minutes_timestamps_to_5_minutes(index: pd.DatetimeIndex) -> pd.DatetimeIndex: + """ + Fill a 30 minute index with 5 minute timestamps too. Note any gaps in 30 mins are not filled + """ + + # resample index to 5 mins + index_5 = pd.Series(0, index=index).resample('5T') + + # calculate forward fill and backward fill + index_5_ffill = index_5.ffill(limit=5) + index_5_bfill = index_5.bfill(limit=5) + + # Time forward fill and backward together. + # This means there will be NaNs if the original index is not in surrounding the values + # for example: + # index = [00:00:00, 01:00:00, 01:30:00, 02:00:00] + # + # index_5 ffill bfill ffill*bfill + # 00:00:00 0 0 0 + # 00:05:00 0 NaN NaN + # 00:10:00 0 NaN NaN + # 00:15:00 0 NaN NaN + # 00:20:00 0 NaN NaN + # 00:25:00 0 NaN NaN + # 00:30:00 NaN NaN NaN + # 00:35:00 NaN 0 NaN + # 00:40:00 NaN 0 NaN + # 00:45:00 NaN 0 NaN + # 00:50:00 NaN 0 NaN + # 00:55:00 NaN 0 NaN + # 01:00:00 0 0 0 + # 01:05:00 0 0 0 + # 01:10:00 0 0 0 + # 01:15:00 0 0 0 + # 01:20:00 0 0 0 + # 01:25:00 0 0 0 + # 01:30:00 0 0 0 + # ..... 
+ # 02:00:00 0 0 0 + index_with_gaps = index_5_ffill * index_5_bfill + + # drop nans and take index + return index_with_gaps.dropna().index + diff --git a/nowcasting_dataset/utils.py b/nowcasting_dataset/utils.py index 4edb49d4..e6c0a93b 100644 --- a/nowcasting_dataset/utils.py +++ b/nowcasting_dataset/utils.py @@ -1,9 +1,15 @@ +import logging import numpy as np import pandas as pd from nowcasting_dataset.consts import Array import fsspec.asyn +from typing import List from pathlib import Path import hashlib +from nowcasting_dataset.dataset.example import Example +from nowcasting_dataset.cloud.gcp import get_all_filenames_in_path + +logger = logging.getLogger(__name__) def set_fsspec_for_multiprocess() -> None: @@ -20,7 +26,7 @@ def is_monotonically_increasing(a: Array) -> bool: assert a is not None assert len(a) > 0 if isinstance(a, pd.DatetimeIndex): - a = a.astype(int) + a = a.view(int) a = np.asarray(a) return np.all(np.diff(a) > 0) @@ -92,3 +98,68 @@ def get_netcdf_filename(batch_idx: int, add_hash:bool = False) -> Path: def pad_nans(array, pad_width) -> np.ndarray: array = array.astype(np.float32) return np.pad(array, pad_width, constant_values=np.NaN) + + +def pad_data(data: Example, pad_size: int, one_dimensional_arrays: List[str], two_dimensional_arrays: List[str]) -> Example: + """ + Pad (if necessary) so returned arrays are always of size + + data has two types of arrays in it, one dimensional arrays and two dimensional arrays + the one dimensional arrays are padded in that dimension + the two dimensional arrays are padded in the second dimension + + Args: + data: typed dictionary of data objects + pad_size: the maount that should be padded + one_dimensional_arrays: list of data items that should be padded by one dimension + two_dimensional_arrays: list of data tiems that should be padded in the third dimension (and more) + + Returns: + + """ + # Pad (if necessary) so returned arrays are always of size + pad_shape = (0, pad_size) # (before, after) + + for name in one_dimensional_arrays: + data[name] = pad_nans(data[name], pad_width=pad_shape) + + for variable in two_dimensional_arrays: + data[variable] = pad_nans(data[variable], pad_width=((0, 0), pad_shape)) # (axis0, axis1) + + return data + + +def get_maximum_batch_id_from_gcs(remote_path: str): + """ + Get the last batch id from gcs. + Args: + remote_path: the remote path folder to look in. 
Warning currently only works for GCS + + Returns: the maximum batch id of data in the remote folder + + """ + + logger.debug(f'Looking for maximum batch id in {remote_path}') + + filenames = get_all_filenames_in_path(remote_path=remote_path) + + # just take filename + filenames = [filename.split('/')[-1] for filename in filenames] + + # remove suffix + filenames = [filename.split('.')[0] for filename in filenames] + + # change to integer + batch_indexes = [int(filename) for filename in filenames if len(filename) > 0] + + # if there is no files, return None + if len(batch_indexes) == 0: + logger.debug(f'Did not find any files in {remote_path}') + return None + + # get the maximum batch id + maximum_batch_id = max(batch_indexes) + logger.debug(f'Found maximum of batch it of {maximum_batch_id} in {remote_path}') + + return maximum_batch_id + diff --git a/requirements.txt b/requirements.txt index e33ab98e..fd7b986c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ numpy pandas +geopandas matplotlib zarr xarray @@ -22,4 +23,6 @@ moto neptune-client pydantic pytest-cov -plotly \ No newline at end of file +plotly +tqdm +git+https://github.com/SheffieldSolar/PV_Live-API diff --git a/scripts/get_raw_pv_gsp_data.py b/scripts/get_raw_pv_gsp_data.py new file mode 100644 index 00000000..753b70b3 --- /dev/null +++ b/scripts/get_raw_pv_gsp_data.py @@ -0,0 +1,58 @@ +############ +# Pull raw pv gsp data from Sheffield Solar +# +# 2021-09-01 +# Peter Dudfield +# +# The data is about 1MB for a month of data +############ +from datetime import datetime +import pytz +import yaml +import os +import numcodecs + +from nowcasting_dataset.data_sources.gsp.pvlive import load_pv_gsp_raw_data_from_pvlive +from pathlib import Path +from nowcasting_dataset.cloud.local import delete_all_files_in_temp_path +from nowcasting_dataset.cloud.gcp import gcp_upload_and_delete_local_files +import logging + +logging.basicConfig() +logging.getLogger().setLevel(logging.DEBUG) +logging.getLogger("urllib3").setLevel(logging.WARNING) + +start = datetime(2018, 1, 1, tzinfo=pytz.utc) +end = datetime(2021, 1, 1, tzinfo=pytz.utc) +gcp_path = "gs://solar-pv-nowcasting-data/PV/GSP/v1" + +config = {"start": start, "end": end, "gcp_path": gcp_path} + +# format local temp folder +LOCAL_TEMP_PATH = Path("~/temp/").expanduser() +delete_all_files_in_temp_path(path=LOCAL_TEMP_PATH) + +# get data +data_df = load_pv_gsp_raw_data_from_pvlive(start=start, end=end) + +# pivot to index as datetime_gmt, and columns as gsp_id +data_df = data_df.pivot(index='datetime_gmt', columns='gsp_id', values='generation_mw') +data_df.columns = [str(col) for col in data_df.columns] + +# change to xarray +data_xarray = data_df.to_xarray() + +# save config to file +with open(os.path.join(LOCAL_TEMP_PATH, "configuration.yaml"), "w+") as f: + yaml.dump(config, f, allow_unicode=True) + +# Make encoding +encoding = { + var: {'compressor': numcodecs.Blosc(cname="zstd", clevel=5)} for var in data_xarray.data_vars +} + +# save data to file +data_xarray.to_zarr(os.path.join(LOCAL_TEMP_PATH, "pv_gsp.zarr"), mode="w",encoding=encoding) + +# upload to gcp +gcp_upload_and_delete_local_files(dst_path=gcp_path, local_path=LOCAL_TEMP_PATH) diff --git a/scripts/prepare_ml_training_data.py b/scripts/prepare_ml_training_data.py index 978bdb0b..e5977d11 100755 --- a/scripts/prepare_ml_training_data.py +++ b/scripts/prepare_ml_training_data.py @@ -8,6 +8,9 @@ DST_TRAIN_PATH and DST_VALIDATION_PATH, and create the LOCAL_TEMP_PATH. 
Note that all files will be deleted from LOCAL_TEMP_PATH when this script starts up. + +Currently, calculating azimuth and elevation angles takes about 15 mins for 2548 PV systems, for about 1 year + """ from nowcasting_dataset.cloud.gcp import check_path_exists @@ -19,18 +22,15 @@ from nowcasting_dataset.config.load import load_yaml_configuration from nowcasting_dataset.config.save import save_configuration_to_cloud -from nowcasting_dataset.datamodule import NowcastingDataModule -from nowcasting_dataset.example import Example, DATETIME_FEATURE_NAMES +from nowcasting_dataset.dataset.datamodule import NowcastingDataModule +from nowcasting_dataset.dataset.batch import write_batch_locally from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES from nowcasting_dataset.data_sources.nwp_data_source import NWP_VARIABLE_NAMES +from nowcasting_dataset.utils import get_maximum_batch_id_from_gcs from pathlib import Path -import numpy as np -import xarray as xr import torch import os -from typing import List, Optional -from nowcasting_dataset.utils import get_netcdf_filename import neptune.new as neptune from neptune.new.integrations.python_logger import NeptuneHandler @@ -38,10 +38,12 @@ logging.basicConfig(format='%(asctime)s %(levelname)s %(pathname)s %(lineno)d %(message)s') _LOG = logging.getLogger("nowcasting_dataset") -_LOG.setLevel(logging.DEBUG) +_LOG.setLevel(logging.INFO) + +logging.getLogger("nowcasting_dataset.data_source").setLevel(logging.WARNING) # load configuration, this can be changed to a different filename as needed -filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), 'config', 'example.yaml') +filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), 'config', 'gcp.yaml') config = load_yaml_configuration(filename) # set the gcs bucket name @@ -57,6 +59,8 @@ # Numerical weather predictions NWP_BASE_PATH = BUCKET / config.input_data.npw_base_path +# GSP data +GSP_FILENAME = BUCKET / config.input_data.gsp_filename DST_NETCDF4_PATH = config.output_data.filepath DST_TRAIN_PATH = os.path.join(DST_NETCDF4_PATH, 'train') @@ -64,7 +68,7 @@ LOCAL_TEMP_PATH = Path('~/temp/').expanduser() UPLOAD_EVERY_N_BATCHES = 16 -CLOUD = "aws" # either gcp or aws +CLOUD = "gcp" # either gcp or aws # Necessary to avoid "RuntimeError: receieved 0 items of ancdata". See: # https://discuss.pytorch.org/t/runtimeerror-received-0-items-of-ancdata/4999/2 @@ -72,19 +76,32 @@ def get_data_module(): + num_workers = 4 + + # get the batch ids already made + maximum_batch_id_train = get_maximum_batch_id_from_gcs(f"gs://{DST_TRAIN_PATH}") + maximum_batch_id_validation = get_maximum_batch_id_from_gcs(f"gs://{DST_VALIDATION_PATH}") + + if maximum_batch_id_train is None: + maximum_batch_id_train = 0 + + if maximum_batch_id_validation is None: + maximum_batch_id_validation = 0 + data_module = NowcastingDataModule( batch_size=config.process.batch_size, - history_len=config.process.history_length, #: Number of timesteps of history, not including t0. - forecast_len=config.process.forecast_length, #: Number of timesteps of forecast. + history_minutes=config.process.history_minutes, #: Number of minutes of history, not including t0. + forecast_minutes=config.process.forecast_minutes, #: Number of minutes of forecast.
image_size_pixels=config.process.image_size_pixels, nwp_channels=NWP_VARIABLE_NAMES, sat_channels=SAT_VARIABLE_NAMES, - pv_power_filename=PV_DATA_FILENAME, + pv_power_filename=f"gs://{PV_DATA_FILENAME}", pv_metadata_filename=f"gs://{PV_METADATA_FILENAME}", sat_filename=f"gs://{SAT_FILENAME}", nwp_base_path=f"gs://{NWP_BASE_PATH}", + gsp_filename=f"gs://{GSP_FILENAME}", pin_memory=True, #: Passed to DataLoader. - num_workers=6, #: Passed to DataLoader. + num_workers=num_workers, #: Passed to DataLoader. prefetch_factor=8, #: Passed to DataLoader. n_samples_per_timestep=8, #: Passed to NowcastingDataset n_training_batches_per_epoch=25_008, # Add pre-fetch factor! @@ -92,8 +109,8 @@ def get_data_module(): collate_fn=lambda x: x, convert_to_numpy=False, #: Leave data as Pandas / Xarray for pre-preparing. normalise_sat=False, - skip_n_train_batches=0, - skip_n_validation_batches=0, + skip_n_train_batches=maximum_batch_id_train // num_workers, + skip_n_validation_batches=maximum_batch_id_validation // num_workers, ) _LOG.info("prepare_data()") data_module.prepare_data() @@ -102,116 +119,6 @@ def get_data_module(): return data_module -def coord_to_range(da: xr.DataArray, dim: str, prefix: Optional[str], dtype=np.int32) -> xr.DataArray: - # TODO: Actually, I think this is over-complicated? I think we can - # just strip off the 'coord' from the dimension. - coord = da[dim] - da[dim] = np.arange(len(coord), dtype=dtype) - if prefix is not None: - da[f"{prefix}_{dim}_coords"] = xr.DataArray(coord, coords=[da[dim]], dims=[dim]) - return da - - -def batch_to_dataset(batch: List[Example]) -> xr.Dataset: - """Concat all the individual fields in an Example into a single Dataset. - - Args: - batch: List of Example objects, which together constitute a single batch. - """ - datasets = [] - for i, example in enumerate(batch): - try: - individual_datasets = [] - example_dim = {"example": np.array([i], dtype=np.int32)} - for name in ["sat_data", "nwp"]: - ds = example[name].to_dataset(name=name) - short_name = name.replace("_data", "") - if name == "nwp": - ds = ds.rename({"target_time": "time"}) - for dim in ["time", "x", "y"]: - ds = coord_to_range(ds, dim, prefix=short_name) - ds = ds.rename( - { - "variable": f"{short_name}_variable", - "x": f"{short_name}_x", - "y": f"{short_name}_y", - } - ) - individual_datasets.append(ds) - - # Datetime features - for name in DATETIME_FEATURE_NAMES: - ds = example[name].rename(name).to_xarray().to_dataset().rename({"index": "time"}) - ds = coord_to_range(ds, "time", prefix=None) - individual_datasets.append(ds) - - # PV - pv_yield = xr.DataArray(example["pv_yield"], dims=["time", "pv_system"]) - pv_yield = pv_yield.to_dataset(name="pv_yield") - n_pv_systems = len(example["pv_system_id"]) - # This will expand all dataarrays to have an 'example' dim. 
- # 0D - for name in ["x_meters_center", "y_meters_center"]: - try: - pv_yield[name] = xr.DataArray([example[name]], coords=example_dim, dims=["example"]) - except Exception as e: - _LOG.error(f'Could not make pv_yield data for {name} with example_dim={example_dim}') - if name not in example.keys(): - _LOG.error(f'{name} not in data keys: {example.keys()}') - _LOG.error(e) - raise Exception - - # 1D - for name in ["pv_system_id", "pv_system_row_number", "pv_system_x_coords", "pv_system_y_coords"]: - pv_yield[name] = xr.DataArray( - example[name][None, :], - coords=example_dim | {"pv_system": np.arange(n_pv_systems, dtype=np.int32)}, - dims=["example", "pv_system"], - ) - - individual_datasets.append(pv_yield) - - # Merge - merged_ds = xr.merge(individual_datasets) - datasets.append(merged_ds) - except Exception as e: - print(e) - _LOG.error(e) - raise Exception - - return xr.concat(datasets, dim="example") - - -def fix_dtypes(concat_ds): - ds_dtypes = { - "example": np.int32, - "sat_x_coords": np.int32, - "sat_y_coords": np.int32, - "nwp": np.float32, - "nwp_x_coords": np.float32, - "nwp_y_coords": np.float32, - "pv_system_id": np.float32, - "pv_system_row_number": np.float32, - "pv_system_x_coords": np.float32, - "pv_system_y_coords": np.float32, - } - - for name, dtype in ds_dtypes.items(): - concat_ds[name] = concat_ds[name].astype(dtype) - - assert concat_ds["sat_data"].dtype == np.int16 - return concat_ds - - -def write_batch_locally(batch: List[Example], batch_i: int): - dataset = batch_to_dataset(batch) - dataset = fix_dtypes(dataset) - encoding = {name: {"compression": "lzf"} for name in dataset.data_vars} - filename = get_netcdf_filename(batch_i) - local_filename = LOCAL_TEMP_PATH / filename - dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) - - def iterate_over_dataloader_and_write_to_disk(dataloader: torch.utils.data.DataLoader, dst_path: str): _LOG.info("Getting first batch") for batch_i, batch in enumerate(dataloader): diff --git a/tests/config/test.yaml b/tests/config/test.yaml new file mode 100644 index 00000000..2e763415 --- /dev/null +++ b/tests/config/test.yaml @@ -0,0 +1,33 @@ +general: + description: example configuration + name: example +input_data: + bucket: solar-pv-nowcasting-data + npw_base_path: tests/data/nwp_data/test.zarr + satelite_filename: tests/data/sat_data.zarr + solar_pv_data_filename: tests/data/pv_data/test.nc + solar_pv_metadata_filename: tests/data/pv_metadata/UK_PV_metadata.csv + solar_pv_path: tests/data/pv_data + gsp_filename: tests/data/gsp/test.zarr +output_data: + filepath: solar-pv-nowcasting-data/prepared_ML_training_data/v5/ +process: + batch_size: 32 + forecast_minutes: 60 + history_minutes: 30 + image_size_pixels: 64 + nwp_channels: + - t + - dswrf + - prate + - r + - sde + - si10 + - vis + - lcc + - mcc + - hcc + prcesion: 16 + sat_channels: + - HRV + val_check_interval: 1000 diff --git a/tests/data/gsp/test.zarr/.zattrs b/tests/data/gsp/test.zarr/.zattrs new file mode 100644 index 00000000..9e26dfee --- /dev/null +++ b/tests/data/gsp/test.zarr/.zattrs @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/tests/data/gsp/test.zarr/.zgroup b/tests/data/gsp/test.zarr/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/tests/data/gsp/test.zarr/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/tests/data/gsp/test.zarr/.zmetadata b/tests/data/gsp/test.zarr/.zmetadata new file mode 100644 index 00000000..a6cf4ed1 --- /dev/null +++ 
b/tests/data/gsp/test.zarr/.zmetadata @@ -0,0 +1,536 @@ +{ + "metadata": { + ".zattrs": {}, + ".zgroup": { + "zarr_format": 2 + }, + "1/.zarray": { + "chunks": [ + 145 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dtype": " 0 + + +def test_gsp_pv_data_source_get_batch(): + local_path = os.path.dirname(nowcasting_dataset.__file__) + '/..' + + gsp = GSPDataSource(filename=f"{local_path}/tests/data/gsp/test.zarr", + start_dt=datetime(2019, 1, 1), + end_dt=datetime(2019, 1, 2), + history_minutes=30, + forecast_minutes=60, + sample_period_minutes=30, + convert_to_numpy=True, + image_size_pixels=64, + meters_per_pixel=2000) + + batch_size = 10 + + x_locations, y_locations = gsp.get_locations_for_batch(t0_datetimes=gsp.gsp_power.index[0:batch_size]) + + batch = gsp.get_batch(t0_datetimes=gsp.gsp_power.index[batch_size:2*batch_size], + x_locations=x_locations[0:batch_size], + y_locations=y_locations[0:batch_size]) + + assert len(batch) == batch_size + assert len(batch[0]['gsp_yield']) == 4 + assert len(batch[0]['gsp_id']) == len(batch[0]['gsp_x_coords']) + assert len(batch[1]['gsp_x_coords']) == len(batch[1]['gsp_y_coords']) + assert len(batch[2]['gsp_x_coords']) > 0 + + +def test_get_gsp_metadata_from_eso(): + """ + Test that the GSP metadata can be fetched from ESO. This should take ~1 second. + """ + metadata = get_gsp_metadata_from_eso() + + assert metadata['gsp_id'].is_unique + + assert isinstance(metadata, pd.DataFrame) + assert len(metadata) > 100 + assert "gnode_name" in metadata.columns + assert "gnode_lat" in metadata.columns + assert "gnode_lon" in metadata.columns + + +def test_get_pv_gsp_shape(): + """ + Test that the GSP shape file can be fetched from ESO. + """ + + gsp_shapes = get_gsp_shape_from_eso() + + assert isinstance(gsp_shapes, gpd.GeoDataFrame) + assert "RegionID" in gsp_shapes.columns + assert "RegionName" in gsp_shapes.columns + assert "geometry" in gsp_shapes.columns + + +def test_load_gsp_raw_data_from_pvlive_one_gsp_one_day(): + """ + Test that data for one GSP system can be loaded, just for one day. + """ + + start = datetime(2019, 1, 1, tzinfo=pytz.utc) + end = datetime(2019, 1, 2, tzinfo=pytz.utc) + + gsp_pv_df = load_pv_gsp_raw_data_from_pvlive(start=start, end=end, number_of_gsp=1) + + assert isinstance(gsp_pv_df, pd.DataFrame) + assert len(gsp_pv_df) == (48 + 1) + assert "datetime_gmt" in gsp_pv_df.columns + assert "generation_mw" in gsp_pv_df.columns + + +def test_load_gsp_raw_data_from_pvlive_one_gsp(): + """ + Test that data for one GSP system can be loaded over two months. + """ + + start = datetime(2019, 1, 1, tzinfo=pytz.utc) + end = datetime(2019, 3, 1, tzinfo=pytz.utc) + + gsp_pv_df = load_pv_gsp_raw_data_from_pvlive(start=start, end=end, number_of_gsp=1) + + assert isinstance(gsp_pv_df, pd.DataFrame) + assert len(gsp_pv_df) == (48 * 59 + 1) + # 31 days in January plus 28 days in February is 59 days, plus one for the first timestamp in March + assert "datetime_gmt" in gsp_pv_df.columns + assert "generation_mw" in gsp_pv_df.columns + + +def test_load_gsp_raw_data_from_pvlive_many_gsp(): + """ + Test that data for several GSP systems can be loaded at once. + """ + + start = datetime(2019, 1, 1, tzinfo=pytz.utc) + end = datetime(2019, 1, 2, tzinfo=pytz.utc) + + gsp_pv_df = load_pv_gsp_raw_data_from_pvlive(start=start, end=end, number_of_gsp=10) + + assert isinstance(gsp_pv_df, pd.DataFrame) + assert len(gsp_pv_df) == (48 + 1) * 10 + assert "datetime_gmt" in gsp_pv_df.columns + assert "generation_mw" in gsp_pv_df.columns
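The GSP tests above read a small zarr fixture at tests/data/gsp/test.zarr; the truncated .zmetadata fragment suggests one zarr array per GSP id, chunked along time. As a rough sketch only, and not the repository's actual fixture-building code, a long-format frame of the kind returned by load_pv_gsp_raw_data_from_pvlive (columns datetime_gmt, gsp_id, generation_mw) could be pivoted and written to zarr as below; the helper name make_gsp_zarr_fixture and the output path are illustrative assumptions.

# Hypothetical sketch: turn a long-format GSP dataframe into a zarr fixture.
# Assumes columns 'datetime_gmt', 'gsp_id', 'generation_mw' as used in the tests above.
import pandas as pd
import xarray as xr


def make_gsp_zarr_fixture(gsp_pv_df: pd.DataFrame, path: str) -> None:
    # One column per GSP id, one row per half-hourly timestamp.
    pivoted = gsp_pv_df.pivot(index="datetime_gmt", columns="gsp_id", values="generation_mw")
    # Store each GSP as its own 1-D variable so it can be chunked independently,
    # mirroring the per-id array entries (e.g. "1/.zarray") seen in the metadata fragment.
    ds = xr.Dataset(
        {str(gsp_id): ("datetime_gmt", pivoted[gsp_id].values) for gsp_id in pivoted.columns},
        coords={"datetime_gmt": pivoted.index.values},
    )
    ds.to_zarr(path, mode="w")

Under those assumptions, make_gsp_zarr_fixture(load_pv_gsp_raw_data_from_pvlive(start=start, end=end, number_of_gsp=10), "tests/data/gsp/test.zarr") would produce a fixture of the same general shape that GSPDataSource opens in the tests above.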
diff --git a/tests/data_sources/test_data_source.py b/tests/data_sources/test_data_source.py new file mode 100644 index 00000000..00405492 --- /dev/null +++ b/tests/data_sources/test_data_source.py @@ -0,0 +1,12 @@ +from nowcasting_dataset.data_sources.data_source import ImageDataSource + + +def test_image_data_source(): + + _ = ImageDataSource( + image_size_pixels=64, + meters_per_pixel=2000, + history_minutes=30, + forecast_minutes=60, + convert_to_numpy=True, + ) diff --git a/tests/data_sources/test_nwp_data_source.py b/tests/data_sources/test_nwp_data_source.py index 4848f633..356ba660 100644 --- a/tests/data_sources/test_nwp_data_source.py +++ b/tests/data_sources/test_nwp_data_source.py @@ -12,8 +12,8 @@ def test_nwp_data_source_init(): _ = NWPDataSource( filename=NWP_FILENAME, - history_len=6, - forecast_len=12, + history_minutes=30, + forecast_minutes=60, convert_to_numpy=True, n_timesteps_per_batch=8, ) @@ -28,8 +28,8 @@ def test_nwp_data_source_open(): nwp = NWPDataSource( filename=NWP_FILENAME, - history_len=6, - forecast_len=12, + history_minutes=30, + forecast_minutes=60, convert_to_numpy=True, n_timesteps_per_batch=8, ) @@ -46,8 +46,8 @@ def test_nwp_data_source_batch(): nwp = NWPDataSource( filename=NWP_FILENAME, - history_len=6, - forecast_len=12, + history_minutes=30, + forecast_minutes=60, convert_to_numpy=True, n_timesteps_per_batch=8, ) diff --git a/tests/data_sources/test_pv_data_source.py b/tests/data_sources/test_pv_data_source.py index 465929e1..a8b80e3e 100644 --- a/tests/data_sources/test_pv_data_source.py +++ b/tests/data_sources/test_pv_data_source.py @@ -1,6 +1,7 @@ import pandas as pd +import numpy as np -from nowcasting_dataset.data_sources.pv_data_source import PVDataSource, drop_pv_systems_which_produce_overnight +from nowcasting_dataset.data_sources.pv_data_source import PVDataSource, drop_pv_systems_which_produce_overnight, calculate_azimuth_and_elevation_all_pv_systems from datetime import datetime import nowcasting_dataset import os @@ -18,8 +19,8 @@ def test_get_example_and_batch(): PV_METADATA_FILENAME = f"{path}/../tests/data/pv_metadata/UK_PV_metadata.csv" pv_data_source = PVDataSource( - history_len=6, - forecast_len=12, + history_minutes=30, + forecast_minutes=60, convert_to_numpy=True, image_size_pixels=64, meters_per_pixel=2000, @@ -49,8 +50,8 @@ def test_get_example_and_batch_azimuth(): PV_METADATA_FILENAME = f"{path}/../tests/data/pv_metadata/UK_PV_metadata.csv" pv_data_source = PVDataSource( - history_len=6, - forecast_len=12, + history_minutes=30, + forecast_minutes=60, convert_to_numpy=True, image_size_pixels=64, meters_per_pixel=2000, @@ -75,3 +76,20 @@ def test_drop_pv_systems_which_produce_overnight(): pv_power = pd.DataFrame(index=pd.date_range('2010-01-01', '2010-01-02', freq='5 min')) _ = drop_pv_systems_which_produce_overnight(pv_power=pv_power) + + +def test_calculate_azimuth_and_elevation_all_pv_systems(): + datestamps = pd.date_range('2010-01-01', '2010-01-02', freq='5 min') + N = 2548 + pv_metadata = pd.DataFrame(index=range(0, N)) + + pv_metadata['latitude'] = np.random.random(N) + pv_metadata['longitude'] = np.random.random(N) + pv_metadata['name'] = np.random.random(N) + + azimuth, elevation = calculate_azimuth_and_elevation_all_pv_systems(datestamps=datestamps, pv_metadata=pv_metadata) + + assert len(azimuth) == len(datestamps) + assert len(azimuth.columns) == N + + # 49 * 2548 ≈ 125,000 calculations took ~26 seconds diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index a200654f..632667e9 100644 ---
a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -1,8 +1,24 @@ +import logging +import os from pathlib import Path -from nowcasting_dataset import datamodule -import pytest + import numpy as np import pandas as pd +import xarray as xr +import pytest + +import nowcasting_dataset +from nowcasting_dataset.dataset import datamodule +from nowcasting_dataset.config.load import load_yaml_configuration +from nowcasting_dataset.dataset.datamodule import NowcastingDataModule +from nowcasting_dataset.dataset.example import validate_example +from nowcasting_dataset.dataset.batch import batch_to_dataset +from nowcasting_dataset.dataset.example import Example + +logging.basicConfig(format='%(asctime)s %(levelname)s %(pathname)s %(lineno)d %(message)s') +_LOG = logging.getLogger("nowcasting_dataset") +_LOG.setLevel(logging.DEBUG) + @pytest.fixture @@ -22,7 +38,8 @@ def test_get_daylight_datetime_index( with pytest.raises(RuntimeError): nowcasting_datamodule._get_datetimes() nowcasting_datamodule.prepare_data() - datetimes = nowcasting_datamodule._get_datetimes() + datetimes = nowcasting_datamodule._get_datetimes(interpolate_for_30_minute_data=False, + adjust_for_sequence_length=False) assert isinstance(datetimes, pd.DatetimeIndex) if not use_cloud_data: correct_datetimes = pd.date_range( @@ -38,3 +55,111 @@ def test_setup( nowcasting_datamodule.setup() nowcasting_datamodule.prepare_data() nowcasting_datamodule.setup() + + +def test_data_module(): + + local_path = os.path.join(os.path.dirname(nowcasting_dataset.__file__), '../') + + # load configuration, this can be changed to a different filename as needed + filename = os.path.join(local_path, 'tests', 'config', 'test.yaml') + config = load_yaml_configuration(filename) + + data_module = NowcastingDataModule( + batch_size=config.process.batch_size, + history_minutes=30, #: Number of minutes of history, not including t0. + forecast_minutes=60, #: Number of minutes of forecast. + image_size_pixels=config.process.image_size_pixels, + nwp_channels=config.process.nwp_channels, + sat_channels=config.process.sat_channels, # reduced for test data + pv_power_filename=config.input_data.solar_pv_data_filename, + pv_metadata_filename=config.input_data.solar_pv_metadata_filename, + sat_filename=config.input_data.satelite_filename, + nwp_base_path=config.input_data.npw_base_path, + gsp_filename=config.input_data.gsp_filename, + pin_memory=True, #: Passed to DataLoader. + num_workers=0, #: Passed to DataLoader. + prefetch_factor=8, #: Passed to DataLoader. + n_samples_per_timestep=16, #: Passed to NowcastingDataset + n_training_batches_per_epoch=200, # Add pre-fetch factor! + n_validation_batches_per_epoch=200, + collate_fn=lambda x: x, + convert_to_numpy=False, #: Leave data as Pandas / Xarray for pre-preparing.
+ normalise_sat=False, + skip_n_train_batches=0, + skip_n_validation_batches=0, + train_validation_percentage_split=50, + pv_load_azimuth_and_elevation=True, + ) + + _LOG.info("prepare_data()") + data_module.prepare_data() + _LOG.info("setup()") + data_module.setup() + + data_generator = iter(data_module.train_dataset) + batch = next(data_generator) + + assert len(batch) == config.process.batch_size + + for key in list(Example.__annotations__.keys()): + assert key in batch[0].keys() + + seq_len_30_minutes = 4 # 30 minutes history (=1), 60 minutes in the future (=2), plus now, is 4 + seq_len_5_minutes = 19 # 30 minutes history (=6), 60 minutes in the future (=12), plus now, is 19 + + for x in batch: + validate_example(data=x, + n_nwp_channels=len(config.process.nwp_channels), + nwp_image_size=0,  # TODO why is this zero + n_sat_channels=len(config.process.sat_channels), + sat_image_size=config.process.image_size_pixels, + seq_len_30_minutes=seq_len_30_minutes, + seq_len_5_minutes=seq_len_5_minutes) + + +def test_batch_to_batch_to_dataset(): + + local_path = os.path.join(os.path.dirname(nowcasting_dataset.__file__), '../') + + # load configuration, this can be changed to a different filename as needed + filename = os.path.join(local_path, 'tests', 'config', 'test.yaml') + config = load_yaml_configuration(filename) + + data_module = NowcastingDataModule( + batch_size=config.process.batch_size, + history_minutes=30, #: Number of minutes of history, not including t0. + forecast_minutes=60, #: Number of minutes of forecast. + image_size_pixels=config.process.image_size_pixels, + nwp_channels=config.process.nwp_channels, + sat_channels=config.process.sat_channels, # reduced for test data + pv_power_filename=config.input_data.solar_pv_data_filename, + pv_metadata_filename=config.input_data.solar_pv_metadata_filename, + sat_filename=config.input_data.satelite_filename, + nwp_base_path=config.input_data.npw_base_path, + gsp_filename=config.input_data.gsp_filename, + pin_memory=True, #: Passed to DataLoader. + num_workers=0, #: Passed to DataLoader. + prefetch_factor=8, #: Passed to DataLoader. + n_samples_per_timestep=16, #: Passed to NowcastingDataset + n_training_batches_per_epoch=200, # Add pre-fetch factor! + n_validation_batches_per_epoch=200, + collate_fn=lambda x: x, + convert_to_numpy=False, #: Leave data as Pandas / Xarray for pre-preparing.
+ normalise_sat=False, + skip_n_train_batches=0, + skip_n_validation_batches=0, + train_validation_percentage_split=50, + pv_load_azimuth_and_elevation=False, + ) + + _LOG.info("prepare_data()") + data_module.prepare_data() + _LOG.info("setup()") + data_module.setup() + + data_generator = iter(data_module.train_dataset) + batch = next(data_generator) + + batch_xr = batch_to_dataset(batch=batch) + assert type(batch_xr) == xr.Dataset diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 1f725fd0..4f62b031 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,5 +1,5 @@ import numpy as np -from nowcasting_dataset.dataset import NowcastingDataset, NetCDFDataset +from nowcasting_dataset.dataset.datasets import NowcastingDataset import nowcasting_dataset.time as nd_time import pytest diff --git a/tests/test_netcdf_dataset.py b/tests/test_netcdf_dataset.py index c820cd62..f9a1c452 100644 --- a/tests/test_netcdf_dataset.py +++ b/tests/test_netcdf_dataset.py @@ -1,6 +1,6 @@ import os import torch -from nowcasting_dataset.dataset import NetCDFDataset, worker_init_fn +from nowcasting_dataset.dataset.datasets import NetCDFDataset, worker_init_fn import plotly.graph_objects as go import plotly import pandas as pd @@ -11,7 +11,7 @@ @pytest.mark.skip("CD does not have access to GCS") def test_get_dataloaders_gcp(): - DATA_PATH = "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v4/" + DATA_PATH = "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v5/" TEMP_PATH = "../nowcasting_dataset" train_dataset = NetCDFDataset(24_900, os.path.join(DATA_PATH, "train"), os.path.join(TEMP_PATH, "train")) @@ -35,6 +35,7 @@ def test_get_dataloaders_gcp(): # image z = data["sat_data"][0][0][:, :, 0] + _ = data['gsp_yield'][0][:,0] _ = pd.to_datetime(data["sat_datetime_index"][0, 0], unit="s") diff --git a/tests/test_time.py b/tests/test_time.py index 368ee626..9b7c67e5 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -1,23 +1,22 @@ import pytest from nowcasting_dataset import time as nd_time +from nowcasting_dataset.time import THIRTY_MINUTES, FIVE_MINUTES import pandas as pd import numpy as np +from datetime import timedelta def test_select_daylight_datetimes(): datetimes = pd.date_range("2020-01-01 00:00", "2020-01-02 00:00", freq="H") locations = [(0, 0), (20_000, 20_000)] - daylight_datetimes = nd_time.select_daylight_datetimes( - datetimes=datetimes, locations=locations) - correct_daylight_datetimes = pd.date_range( - "2020-01-01 09:00", "2020-01-01 16:00", freq="H") - np.testing.assert_array_equal( - daylight_datetimes, correct_daylight_datetimes) + daylight_datetimes = nd_time.select_daylight_datetimes(datetimes=datetimes, locations=locations) + correct_daylight_datetimes = pd.date_range("2020-01-01 09:00", "2020-01-01 16:00", freq="H") + np.testing.assert_array_equal(daylight_datetimes, correct_daylight_datetimes) def test_intersection_of_datetimeindexes(): # Test with just one - index = pd.date_range('2010-01-01', '2010-01-02', freq='H') + index = pd.date_range("2010-01-01", "2010-01-02", freq="H") intersection = nd_time.intersection_of_datetimeindexes([index]) np.testing.assert_array_equal(index, intersection) @@ -26,61 +25,90 @@ def test_intersection_of_datetimeindexes(): np.testing.assert_array_equal(index, intersection) # Test with three with no intersection: - index2 = pd.date_range('2020-01-01', '2010-01-02', freq='H') + index2 = pd.date_range("2020-01-01", "2010-01-02", freq="H") intersection = nd_time.intersection_of_datetimeindexes([index, index2]) assert 
len(intersection) == 0 # Test with three, with some intersection: - index3 = pd.date_range('2010-01-01 06:00', '2010-01-02 06:00', freq='H') - index4 = pd.date_range('2010-01-01 12:00', '2010-01-02 12:00', freq='H') - intersection = nd_time.intersection_of_datetimeindexes( - [index, index3, index4]) - np.testing.assert_array_equal( - intersection, - pd.date_range('2010-01-01 12:00', '2010-01-02', freq='H')) - - -@pytest.mark.parametrize( - "total_seq_len", - [2, 3, 12] -) + index3 = pd.date_range("2010-01-01 06:00", "2010-01-02 06:00", freq="H") + index4 = pd.date_range("2010-01-01 12:00", "2010-01-02 12:00", freq="H") + intersection = nd_time.intersection_of_datetimeindexes([index, index3, index4]) + np.testing.assert_array_equal(intersection, pd.date_range("2010-01-01 12:00", "2010-01-02", freq="H")) + + +@pytest.mark.parametrize("total_seq_len", [2, 3, 12]) def test_get_start_datetimes_1(total_seq_len): - dt_index1 = pd.date_range('2010-01-01', '2010-01-02', freq='5 min') - start_datetimes = nd_time.get_start_datetimes( - dt_index1, total_seq_len=total_seq_len) - np.testing.assert_array_equal(start_datetimes, dt_index1[:1-total_seq_len]) + dt_index1 = pd.date_range("2010-01-01", "2010-01-02", freq="5 min") + start_datetimes = nd_time.get_start_datetimes(dt_index1, total_seq_len=total_seq_len) + np.testing.assert_array_equal(start_datetimes, dt_index1[: 1 - total_seq_len]) -@pytest.mark.parametrize( - "total_seq_len", - [2, 3, 12] -) +@pytest.mark.parametrize("total_seq_len", [2, 3, 12]) def test_get_start_datetimes_2(total_seq_len): - dt_index1 = pd.date_range('2010-01-01', '2010-01-02', freq='5 min') - dt_index2 = pd.date_range('2010-02-01', '2010-02-02', freq='5 min') + dt_index1 = pd.date_range("2010-01-01", "2010-01-02", freq="5 min") + dt_index2 = pd.date_range("2010-02-01", "2010-02-02", freq="5 min") dt_index = dt_index1.union(dt_index2) - start_datetimes = nd_time.get_start_datetimes( - dt_index, total_seq_len=total_seq_len) - correct_start_datetimes = dt_index1[:1-total_seq_len].union( - dt_index2[:1-total_seq_len]) + start_datetimes = nd_time.get_start_datetimes(dt_index, total_seq_len=total_seq_len) + correct_start_datetimes = dt_index1[: 1 - total_seq_len].union(dt_index2[: 1 - total_seq_len]) np.testing.assert_array_equal(start_datetimes, correct_start_datetimes) def test_timesteps_to_duration(): assert nd_time.timesteps_to_duration(0) == pd.Timedelta(0) - assert nd_time.timesteps_to_duration(1) == pd.Timedelta('5T') - assert nd_time.timesteps_to_duration(12) == pd.Timedelta('1H') + assert nd_time.timesteps_to_duration(1) == pd.Timedelta("5T") + assert nd_time.timesteps_to_duration(12) == pd.Timedelta("1H") def test_datetime_features_in_example(): - index = pd.date_range('2020-01-01', '2020-01-06 23:00', freq='h') + index = pd.date_range("2020-01-01", "2020-01-06 23:00", freq="h") example = nd_time.datetime_features_in_example(index) - assert len(example['hour_of_day_sin']) == len(index) - for col_name in ['hour_of_day_sin', 'hour_of_day_cos']: + assert len(example["hour_of_day_sin"]) == len(index) + for col_name in ["hour_of_day_sin", "hour_of_day_cos"]: assert col_name in example - np.testing.assert_array_almost_equal( - example[col_name], - np.tile(example[col_name][:24], reps=6)) + np.testing.assert_array_almost_equal(example[col_name], np.tile(example[col_name][:24], reps=6)) + + assert "day_of_year_sin" in example + assert "day_of_year_cos" in example + + +@pytest.mark.parametrize("history_length", [2, 3, 12]) +@pytest.mark.parametrize("forecast_length", [2, 3, 12]) 
+def test_get_t0_datetimes(history_length, forecast_length): + index = pd.date_range("2020-01-01", "2020-01-06 23:00", freq="30T") + total_seq_len = history_length + forecast_length + 1 + + t0_datetimes = nd_time.get_t0_datetimes(datetimes=index, total_seq_len=total_seq_len, history_len=history_length, + max_gap=THIRTY_MINUTES, minute_delta=30) + + assert len(t0_datetimes) == len(index) - history_length - forecast_length + assert t0_datetimes[0] == index[0] + timedelta(minutes=30 * history_length) + assert t0_datetimes[-1] == index[-1] - timedelta(minutes=30 * forecast_length) + + +def test_get_t0_datetimes_night(): + history_length = 6 + forecast_length = 12 + index = pd.date_range("2020-06-15", "2020-06-15 22:15", freq="5T") + total_seq_len = history_length + forecast_length + 1 + + t0_datetimes = nd_time.get_t0_datetimes(datetimes=index, total_seq_len=total_seq_len, + history_len=history_length, + max_gap=FIVE_MINUTES) + + assert len(t0_datetimes) == len(index) - history_length - forecast_length + assert t0_datetimes[0] == index[0] + timedelta(minutes=5 * history_length) + assert t0_datetimes[-1] == index[-1] - timedelta(minutes=5 * forecast_length) + + +def test_fill_30_minutes_timestamps_to_5_minutes(): + index = pd.date_range("2020-01-01", "2020-01-02", freq="30T") + + # remove >4.30 to <7.30 o'clock + index = index[0:10].join(index[15:], how="outer") + + index_5 = nd_time.fill_30_minutes_timestamps_to_5_minutes(index) + + assert len(index_5) == 24 * 12 + 1 - (3 * 12 - 1) + # 24*12 is total number of 5s in a day, +1 for the next day. + # 3*12 - 1 is the amount of 5 mins between 4.30 and 7.30 (not inclusive) - assert 'day_of_year_sin' in example - assert 'day_of_year_cos' in example diff --git a/tests/test_utils.py b/tests/test_utils.py index 89eee1d8..98aec0ed 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ from nowcasting_dataset import utils +from nowcasting_dataset.dataset.example import Example import pandas as pd import pytest import numpy as np @@ -32,3 +33,21 @@ def test_sin_and_cos(): def test_get_netcdf_filename(): assert utils.get_netcdf_filename(10) == '10.nc' assert utils.get_netcdf_filename(10, add_hash=True) == '77eb6f_10.nc' + + +def test_pad_data(): + seq_length = 4 + n_gsp_system_ids = 17 + + data = Example() + data['gsp_yield'] = np.random.random((seq_length, n_gsp_system_ids)) + data['gsp_system_id'] = np.random.random((n_gsp_system_ids)) + + data = utils.pad_data(data=data, + pad_size=1, + one_dimensional_arrays=['gsp_system_id'], + two_dimensional_arrays=['gsp_yield']) + + assert data['gsp_yield'].shape == (seq_length, n_gsp_system_ids+1) + assert data['gsp_system_id'].shape == (n_gsp_system_ids + 1,) +
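test_pad_data above exercises nowcasting_dataset.utils.pad_data, which pads the GSP arrays out to a fixed number of systems per example. The real implementation is not part of this diff; the following is only a minimal sketch consistent with the shape assertions in the test (zero-padding the system dimension by pad_size), and the name pad_data_sketch is deliberately different from the actual utils.pad_data to mark it as an illustration.

# Minimal sketch, not the repository's implementation: pad the trailing
# "system" dimension of the named arrays with zeros so every example has
# the same width, matching the assertions in test_pad_data above.
from typing import Dict, List

import numpy as np


def pad_data_sketch(data: Dict[str, np.ndarray], pad_size: int,
                    one_dimensional_arrays: List[str],
                    two_dimensional_arrays: List[str]) -> Dict[str, np.ndarray]:
    for name in one_dimensional_arrays:
        # e.g. 'gsp_system_id': (n_systems,) -> (n_systems + pad_size,)
        data[name] = np.pad(data[name], pad_width=(0, pad_size), mode="constant")
    for name in two_dimensional_arrays:
        # e.g. 'gsp_yield': (seq_length, n_systems) -> (seq_length, n_systems + pad_size)
        data[name] = np.pad(data[name], pad_width=((0, 0), (0, pad_size)), mode="constant")
    return data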