
Issue/106 refactor dataset #108

Merged 12 commits on Sep 10, 2021
16 changes: 16 additions & 0 deletions nowcasting_dataset/consts.py
@@ -22,3 +22,19 @@

# Typing
Array = Union[xr.DataArray, np.ndarray]
PV_SYSTEM_ID: str = 'pv_system_id'
PV_SYSTEM_ROW_NUMBER = 'pv_system_row_number'
PV_SYSTEM_X_COORDS = 'pv_system_x_coords'
PV_SYSTEM_Y_COORDS = 'pv_system_y_coords'
PV_AZIMUTH_ANGLE = 'pv_azimuth_angle'
PV_ELEVATION_ANGLE = 'pv_elevation_angle'
PV_YIELD = 'pv_yield'
DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE = 128
GSP_SYSTEM_ID: str = "gsp_system_id"
GSP_YIELD = "gsp_yield"
GSP_SYSTEM_X_COORDS = "gsp_system_x_coords"
GSP_SYSTEM_Y_COORDS = "gsp_system_y_coords"
GSP_DATETIME_INDEX = "gsp_datetime_index"
DEFAULT_N_GSP_PER_EXAMPLE = 32
CENTROID_TYPE = "centroid_type"
DATETIME_FEATURE_NAMES = ("hour_of_day_sin", "hour_of_day_cos", "day_of_year_sin", "day_of_year_cos")
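
As a quick illustration of how these shared constants are meant to be consumed (a minimal sketch; the array shapes below are made up, and the dimension names mirror those used in batch.py):

```python
import numpy as np

from nowcasting_dataset.consts import (
    PV_SYSTEM_ID, PV_YIELD, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE,
    GSP_SYSTEM_ID, GSP_YIELD, DEFAULT_N_GSP_PER_EXAMPLE,
)

# Keying example data by these constants (rather than by string literals)
# keeps the data sources and the batching code in sync.
example = {
    PV_SYSTEM_ID: np.arange(DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE),
    PV_YIELD: np.zeros((12, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE)),  # (time, pv_system)
    GSP_SYSTEM_ID: np.arange(DEFAULT_N_GSP_PER_EXAMPLE),
    GSP_YIELD: np.zeros((4, DEFAULT_N_GSP_PER_EXAMPLE)),  # (time_30, gsp_system)
}
```
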
24 changes: 0 additions & 24 deletions nowcasting_dataset/data_sources/constants.py

This file was deleted.

6 changes: 3 additions & 3 deletions nowcasting_dataset/data_sources/data_source.py
@@ -1,15 +1,14 @@
from numbers import Number
import pandas as pd
import numpy as np
from nowcasting_dataset.example import Example, to_numpy
from nowcasting_dataset.dataset.example import Example, to_numpy
from nowcasting_dataset import square
import nowcasting_dataset.time as nd_time
from dataclasses import dataclass, InitVar
from typing import List, Tuple, Iterable
import xarray as xr
import itertools
import logging
from typing import Optional

logger = logging.getLogger(__name__)

@@ -199,8 +198,9 @@ def get_example(
f'x_meters_center={x_meters_center}\n'
f'y_meters_center={y_meters_center}\n'
f't0_dt={t0_dt}\n'
f'times are {selected_data.time}\n'
f'expected shape={self._shape_of_example}\n'
f'actual shape {selected_data.shape}')
f'actual shape {selected_data.shape}')

return self._put_data_into_example(selected_data)

2 changes: 1 addition & 1 deletion nowcasting_dataset/data_sources/datetime_data_source.py
@@ -1,5 +1,5 @@
from nowcasting_dataset.data_sources.data_source import DataSource
from nowcasting_dataset.example import Example
from nowcasting_dataset.dataset.example import Example
from nowcasting_dataset import time as nd_time
from dataclasses import dataclass
import pandas as pd
7 changes: 3 additions & 4 deletions nowcasting_dataset/data_sources/gsp/gsp_data_source.py
Original file line number Diff line number Diff line change
@@ -14,13 +14,12 @@
from nowcasting_dataset.utils import scale_to_0_to_1, pad_data
from nowcasting_dataset.square import get_bounding_box_mask
from nowcasting_dataset.geospatial import lat_lon_to_osgb
from nowcasting_dataset.example import Example
from nowcasting_dataset.dataset.example import Example
from nowcasting_dataset.data_sources.data_source import ImageDataSource
from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso

from nowcasting_dataset.data_sources.constants import GSP_YIELD, GSP_SYSTEM_ID, GSP_SYSTEM_X_COORDS, \
GSP_SYSTEM_Y_COORDS, DEFAULT_N_GSP_PER_EXAMPLE, CENTROID_TYPE

from nowcasting_dataset.consts import GSP_SYSTEM_ID, GSP_YIELD, GSP_SYSTEM_X_COORDS, GSP_SYSTEM_Y_COORDS, \
DEFAULT_N_GSP_PER_EXAMPLE, CENTROID_TYPE

logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion nowcasting_dataset/data_sources/nwp_data_source.py
@@ -1,5 +1,5 @@
from nowcasting_dataset.data_sources.data_source import ZarrDataSource
from nowcasting_dataset.example import Example, to_numpy
from nowcasting_dataset.dataset.example import Example, to_numpy
from nowcasting_dataset import utils
from typing import Iterable, Optional, List
import xarray as xr
9 changes: 5 additions & 4 deletions nowcasting_dataset/data_sources/pv_data_source.py
@@ -1,7 +1,7 @@
from nowcasting_dataset.data_sources.constants import PV_SYSTEM_ID, PV_SYSTEM_ROW_NUMBER, PV_SYSTEM_X_COORDS, \
PV_SYSTEM_Y_COORDS, PV_AZIMUTH_ANGLE, PV_ELEVATION_ANGLE, PV_YIELD, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE
from nowcasting_dataset.consts import PV_SYSTEM_ID, PV_SYSTEM_ROW_NUMBER, PV_SYSTEM_X_COORDS, PV_SYSTEM_Y_COORDS, \
PV_AZIMUTH_ANGLE, PV_ELEVATION_ANGLE, PV_YIELD, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE
from nowcasting_dataset.data_sources.data_source import ImageDataSource
from nowcasting_dataset.example import Example
from nowcasting_dataset.dataset.example import Example
from nowcasting_dataset import geospatial, utils
from nowcasting_dataset.square import get_bounding_box_mask
from dataclasses import dataclass
@@ -225,7 +225,8 @@ def get_example(
x_meters_center=x_meters_center,
y_meters_center=y_meters_center,
pv_system_x_coords=pv_system_x_coords,
pv_system_y_coords=pv_system_y_coords)
pv_system_y_coords=pv_system_y_coords,
pv_datetime_index=selected_pv_power.index)

if self.load_azimuth_and_elevation:
example[PV_AZIMUTH_ANGLE] = selected_pv_azimuth_angle
2 changes: 1 addition & 1 deletion nowcasting_dataset/data_sources/satellite_data_source.py
@@ -1,5 +1,5 @@
from nowcasting_dataset.data_sources.data_source import ZarrDataSource
from nowcasting_dataset.example import Example, to_numpy
from nowcasting_dataset.dataset.example import Example, to_numpy
from nowcasting_dataset import utils
from typing import Iterable, Optional, List
from numbers import Number
30 changes: 30 additions & 0 deletions nowcasting_dataset/dataset/README.md
@@ -0,0 +1,30 @@
# Datasets

This folder contains the following files:

## batch.py

Functions used to manipulate batch data, i.e. a `List[Example]`. A short usage sketch follows.
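
A short usage sketch (assuming `batch` is a `List[Example]` built elsewhere):

```python
from nowcasting_dataset.dataset.batch import batch_to_dataset, write_batch_locally

dataset = batch_to_dataset(batch)      # merge a List[Example] into one xr.Dataset
write_batch_locally(batch, batch_i=0)  # write batch 0 as NetCDF under ~/temp/
```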

## datamodule.py

Contains the class `NowcastingDataModule` (a `pl.LightningDataModule`), whose standard lifecycle is sketched below. It handles:
- the amalgamation of all the different data sources,
- making valid datetimes across all the sources,
- splitting into train and validation datasets.
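
A minimal lifecycle sketch (the constructor arguments are defined in datamodule.py and not shown in this PR, so they are omitted here; the hooks are the standard `pl.LightningDataModule` ones):

```python
from nowcasting_dataset.dataset.datamodule import NowcastingDataModule

data_module = NowcastingDataModule()  # the real constructor takes configuration arguments
data_module.prepare_data()            # prepare the raw data sources
data_module.setup()                   # build the train and validation datasets
train_loader = data_module.train_dataloader()
```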


## datasets.py

This file contains the following classes:

- `NetCDFDataset` (a `torch.utils.data.Dataset`): used for loading pre-made batches.
- `NowcastingDataset` (a `torch.utils.data.IterableDataset`): a dataset for making batches.
- `ContiguousNowcastingDataset` (a subclass of `NowcastingDataset`).

## example.py

The main thing in here is `Example`, a typed dictionary used to store one element of data for one step in the ML models. There is also a validation function. A sketch of the structure is shown below.
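
A minimal sketch of that typed dictionary, restricted to keys that appear elsewhere in this PR (the field types here are assumptions; see example.py for the full definition):

```python
from typing import TypedDict

import numpy as np
import xarray as xr

class Example(TypedDict, total=False):
    # Abridged: only keys visible elsewhere in this PR are listed.
    sat_data: xr.DataArray
    nwp: xr.DataArray
    pv_yield: np.ndarray
    pv_system_id: np.ndarray
    gsp_yield: np.ndarray
    gsp_system_id: np.ndarray
    x_meters_center: float
    y_meters_center: float
```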

155 changes: 155 additions & 0 deletions nowcasting_dataset/dataset/batch.py
@@ -0,0 +1,155 @@
from typing import List, Optional
import logging

import numpy as np
import xarray as xr
from pathlib import Path

from nowcasting_dataset.consts import GSP_SYSTEM_ID, GSP_YIELD, GSP_SYSTEM_X_COORDS, GSP_SYSTEM_Y_COORDS, \
DATETIME_FEATURE_NAMES

from nowcasting_dataset.dataset.example import Example
from nowcasting_dataset.utils import get_netcdf_filename

_LOG = logging.getLogger(__name__)

LOCAL_TEMP_PATH = Path('~/temp/').expanduser()


def write_batch_locally(batch: List[Example], batch_i: int):
"""
    Write a batch of data to a local NetCDF file.
Args:
batch: batch of data
batch_i: the number of the batch

"""
dataset = batch_to_dataset(batch)
dataset = fix_dtypes(dataset)
encoding = {name: {"compression": "lzf"} for name in dataset.data_vars}
filename = get_netcdf_filename(batch_i)
local_filename = LOCAL_TEMP_PATH / filename
dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)


def fix_dtypes(concat_ds):
"""
    Cast the dataset's variables to the dtypes expected in the saved NetCDF batches.
"""
ds_dtypes = {
"example": np.int32,
"sat_x_coords": np.int32,
"sat_y_coords": np.int32,
"nwp": np.float32,
"nwp_x_coords": np.float32,
"nwp_y_coords": np.float32,
"pv_system_id": np.float32,
"pv_system_row_number": np.float32,
"pv_system_x_coords": np.float32,
"pv_system_y_coords": np.float32,
}

for name, dtype in ds_dtypes.items():
concat_ds[name] = concat_ds[name].astype(dtype)

assert concat_ds["sat_data"].dtype == np.int16
return concat_ds


def batch_to_dataset(batch: List[Example]) -> xr.Dataset:
"""Concat all the individual fields in an Example into a single Dataset.

Args:
batch: List of Example objects, which together constitute a single batch.
"""
datasets = []
for i, example in enumerate(batch):
try:
individual_datasets = []
example_dim = {"example": np.array([i], dtype=np.int32)}
for name in ["sat_data", "nwp"]:
ds = example[name].to_dataset(name=name)
short_name = name.replace("_data", "")
if name == "nwp":
ds = ds.rename({"target_time": "time"})
for dim in ["time", "x", "y"]:
ds = coord_to_range(ds, dim, prefix=short_name)
ds = ds.rename(
{
"variable": f"{short_name}_variable",
"x": f"{short_name}_x",
"y": f"{short_name}_y",
}
)
individual_datasets.append(ds)

# Datetime features
for name in DATETIME_FEATURE_NAMES:
ds = example[name].rename(name).to_xarray().to_dataset().rename({"index": "time"})
ds = coord_to_range(ds, "time", prefix=None)
individual_datasets.append(ds)

# PV
            one_dataset = xr.DataArray(example["pv_yield"], dims=["time", "pv_system"])
            one_dataset = one_dataset.to_dataset(name="pv_yield")
n_pv_systems = len(example["pv_system_id"])

# GSP
n_gsp_systems = len(example[GSP_SYSTEM_ID])
            one_dataset['gsp_yield'] = xr.DataArray(example[GSP_YIELD], dims=["time_30", "gsp_system"])

# This will expand all dataarrays to have an 'example' dim.
# 0D
for name in ["x_meters_center", "y_meters_center"]:
try:
                    one_dataset[name] = xr.DataArray([example[name]], coords=example_dim, dims=["example"])
                except Exception as e:
                    _LOG.error(f'Could not add {name} to the dataset with example_dim={example_dim}')
                    if name not in example.keys():
                        _LOG.error(f'{name} not in data keys: {example.keys()}')
                    _LOG.error(e)
                    raise

# 1D
for name in ["pv_system_id", "pv_system_row_number", "pv_system_x_coords", "pv_system_y_coords"]:
                one_dataset[name] = xr.DataArray(
example[name][None, :],
coords={**example_dim, **{"pv_system": np.arange(n_pv_systems, dtype=np.int32)}},
dims=["example", "pv_system"],
)

# GSP
for name in [GSP_SYSTEM_ID, GSP_SYSTEM_X_COORDS, GSP_SYSTEM_Y_COORDS]:
try:
                    one_dataset[name] = xr.DataArray(
example[name][None, :],
coords={**example_dim, **{"gsp_system": np.arange(n_gsp_systems, dtype=np.int32)}},
dims=["example", "gsp_system"],
)
                except Exception as e:
                    _LOG.error(f'Could not add {name} to the dataset. {example[name].shape}')
                    _LOG.error(e)
                    raise

            individual_datasets.append(one_dataset)

# Merge
merged_ds = xr.merge(individual_datasets)
datasets.append(merged_ds)

        except Exception as e:
            _LOG.error(e)
            raise

return xr.concat(datasets, dim="example")


def coord_to_range(da: xr.DataArray, dim: str, prefix: Optional[str], dtype=np.int32) -> xr.DataArray:
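    """Replace the `dim` coordinate of `da` with a plain integer range.

    If `prefix` is given, the original coordinate values are preserved in a
    new `{prefix}_{dim}_coords` variable, so no information is lost.
    """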
# TODO: Actually, I think this is over-complicated? I think we can
# just strip off the 'coord' from the dimension.
coord = da[dim]
da[dim] = np.arange(len(coord), dtype=dtype)
if prefix is not None:
da[f"{prefix}_{dim}_coords"] = xr.DataArray(coord, coords=[da[dim]], dims=[dim])
return da
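
A small, self-contained illustration of what `coord_to_range` does (the timestamps are made up):

```python
import numpy as np
import xarray as xr

ds = xr.DataArray(
    np.zeros(3),
    coords={"time": np.array(["2021-01-01T00:00", "2021-01-01T00:05",
                              "2021-01-01T00:10"], dtype="datetime64[ns]")},
    dims=["time"],
).to_dataset(name="sat")

ds = coord_to_range(ds, "time", prefix="sat")
# ds["time"] is now the integer range [0, 1, 2] and the original timestamps
# live on in ds["sat_time_coords"]. Positional coords let xr.concat stack
# examples along a new 'example' dim even though each example covers
# different wall-clock times.
```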