Skip to content

Commit 09300d9

Browse files
Feat/Example notebook for the regression models (#2039)
* feat: example notebook for the regression models and new dataset (energy consumption and weather in Zurich, between 2015 and 2022) * fix: tests, some datasets width were missing * feat: udpated changelog * fix: to keep the API uniform, Zurich energy consumption and weather was split into two datasets. Energy consumption was added to the darts repo * fix: changed the way datasets are loaded, added an illustration for multi_models=True * fix: tweaked notebook * feat: grouped dataset and their width into a single variable to improve readibility * Apply suggestions from code review Co-authored-by: Dennis Bader <[email protected]> * fix: simplified API to load the EnergyConsumptionZurich dataset, updated notebook accordingly * fix: remove the obsolete dataset from the tests * blabla * update dataset * update notebook p1 * update regression model notebook * notebook last fixes * fix: typo * add regression model example test to merge workflow --------- Co-authored-by: Dennis Bader <[email protected]>
1 parent da049e5 commit 09300d9

11 files changed

+1269
-36
lines changed

.github/workflows/merge.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ jobs:
8787
runs-on: ubuntu-latest
8888
strategy:
8989
matrix:
90-
example-name: [00-quickstart.ipynb, 01-multi-time-series-and-covariates.ipynb, 02-data-processing.ipynb, 03-FFT-examples.ipynb, 04-RNN-examples.ipynb, 05-TCN-examples.ipynb, 06-Transformer-examples.ipynb, 07-NBEATS-examples.ipynb, 08-DeepAR-examples.ipynb, 09-DeepTCN-examples.ipynb, 10-Kalman-filter-examples.ipynb, 11-GP-filter-examples.ipynb, 12-Dynamic-Time-Warping-example.ipynb, 13-TFT-examples.ipynb, 15-static-covariates.ipynb, 16-hierarchical-reconciliation.ipynb, 18-TiDE-examples.ipynb, 19-EnsembleModel-examples.ipynb]
90+
example-name: [00-quickstart.ipynb, 01-multi-time-series-and-covariates.ipynb, 02-data-processing.ipynb, 03-FFT-examples.ipynb, 04-RNN-examples.ipynb, 05-TCN-examples.ipynb, 06-Transformer-examples.ipynb, 07-NBEATS-examples.ipynb, 08-DeepAR-examples.ipynb, 09-DeepTCN-examples.ipynb, 10-Kalman-filter-examples.ipynb, 11-GP-filter-examples.ipynb, 12-Dynamic-Time-Warping-example.ipynb, 13-TFT-examples.ipynb, 15-static-covariates.ipynb, 16-hierarchical-reconciliation.ipynb, 18-TiDE-examples.ipynb, 19-EnsembleModel-examples.ipynb, 20-RegressionModel-examples.ipynb]
9191
steps:
9292
- name: "1. Clone repository"
9393
uses: actions/checkout@v2

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1616
- Added callback `darts.utils.callbacks.TFMProgressBar` to customize at which model stages to display the progress bar. [#2020](https://github.com/unit8co/darts/pull/2020) by [Dennis Bader](https://github.com/dennisbader).
1717
- Improvements to documentation:
1818
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
19+
- New example notebook for the `RegressionModels` explaining features such as (component-specific) lags, `output_chunk_length` in relation with `multi_models`, multivariate support, and more. [#2039](https://github.com/unit8co/darts/pull/2039) by [Antoine Madrona](https://github.com/madtoinou).
1920
- Improvements to Regression Models:
2021
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).
2122
- Other improvements:
2223
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
2324
- Added optional keyword arguments dict `kwargs` to `ExponentialSmoothing` that will be passed to the constructor of the underlying `statsmodels.tsa.holtwinters.ExponentialSmoothing` model. [#2059](https://github.com/unit8co/darts/pull/2059) by [Antoine Madrona](https://github.com/madtoinou).
25+
- Added new dataset `ElectricityConsumptionZurichDataset`: The dataset contains the electricity consumption of households in Zurich, Switzerland from 2015-2022 on different grid levels. We also added weather measurements for Zurich which can be used as covariates for modelling. [#2039](https://github.com/unit8co/darts/pull/2039) by [Antoine Madrona](https://github.com/madtoinou) and [Dennis Bader](https://github.com/dennisbader).
2426

2527
**Fixed**
2628
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).

darts/datasets/__init__.py

+110-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
A few popular time series datasets
66
"""
77

8+
import os
89
from pathlib import Path
9-
from typing import List
10+
from typing import List, Literal, Optional
1011

1112
import numpy as np
1213
import pandas as pd
@@ -813,3 +814,111 @@ def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
813814
Load the WeatherDataset dataset as a list of univariate timeseries, one for weather indicator.
814815
"""
815816
return [TimeSeries.from_series(series[label]) for label in series]
817+
818+
819+
class ElectricityConsumptionZurichDataset(DatasetLoaderCSV):
820+
"""
821+
Electricity Consumption of households & SMEs (low voltage) and businesses & services (medium voltage) in the
822+
city of Zurich [1]_, with values recorded every 15 minutes.
823+
824+
The electricity consumption is combined with weather measurements recorded by three different
825+
stations in the city of Zurich with a hourly frequency [2]_. The missing time stamps are filled with NaN.
826+
The original weather data is recorded every hour. Before adding the features to the electricity consumption,
827+
the data is resampled to 15 minutes frequency, and missing values are interpolated.
828+
829+
To simplify the dataset, the measurements from the Zch_Schimmelstrasse and Zch_Rosengartenstrasse weather
830+
stations are discarded to keep only the data recorded in the Zch_Stampfenbachstrasse station.
831+
832+
Both dataset sources are updated continuously, but this dataset only retrains values between 2015 and 2022.
833+
The time index was converted from CET time zone to UTC.
834+
835+
Components Descriptions:
836+
837+
* Value_NE5 : Households & SMEs electricity consumption (low voltage, grid level 7) in kWh
838+
* Value_NE7 : Business and services electricity consumption (medium voltage, grid level 5) in kWh
839+
* Hr [%Hr] : Relative humidity
840+
* RainDur [min] : Duration of precipitation (divided by 4 for conversion from hourly to quarter-hourly records)
841+
* T [°C] : Temperature
842+
* WD [°] : Wind direction
843+
* WVv [m/s] : Wind vector speed
844+
* p [hPa] : Air pressure
845+
* WVs [m/s] : Wind scalar speed
846+
* StrGlo [W/m2] : Global solar irradiation
847+
848+
Note: before 2018, the scalar speeds were calculated from the 30 minutes vector data.
849+
850+
References
851+
----------
852+
.. [1] https://data.stadt-zuerich.ch/dataset/ewz_stromabgabe_netzebenen_stadt_zuerich
853+
.. [2] https://data.stadt-zuerich.ch/dataset/ugz_meteodaten_stundenmittelwerte
854+
"""
855+
856+
def __init__(self):
857+
def pre_process_dataset(dataset_path):
858+
"""Restrict the time axis and add the weather data"""
859+
df = pd.read_csv(dataset_path, index_col=0)
860+
# convert time index
861+
df.index = (
862+
pd.DatetimeIndex(df.index, tz="CET").tz_convert("UTC").tz_localize(None)
863+
)
864+
# extract pre-determined period
865+
df = df.loc[
866+
(pd.Timestamp("2015-01-01") <= df.index)
867+
& (df.index <= pd.Timestamp("2022-12-31"))
868+
]
869+
# download and preprocess the weather information
870+
df_weather = self._download_weather_data()
871+
# add weather data as additional features
872+
df = pd.concat([df, df_weather], axis=1)
873+
# interpolate weather data
874+
df = df.interpolate()
875+
# raining duration is given in minutes -> we divide by 4 from hourly to quarter-hourly records
876+
df["RainDur [min]"] = df["RainDur [min]"] / 4
877+
878+
# round Electricity cols to 4 decimals, other columns to 2 decimals
879+
cols_precise = ["Value_NE5", "Value_NE7"]
880+
df = df.round(
881+
decimals={col: (4 if col in cols_precise else 2) for col in df.columns}
882+
)
883+
884+
# export the dataset
885+
df.index.name = "Timestamp"
886+
df.to_csv(self._get_path_dataset())
887+
888+
# hash value for dataset with weather data
889+
super().__init__(
890+
metadata=DatasetLoaderMetadata(
891+
"zurich_electricity_consumption.csv",
892+
uri=(
893+
"https://data.stadt-zuerich.ch/dataset/"
894+
"ewz_stromabgabe_netzebenen_stadt_zuerich/"
895+
"download/ewz_stromabgabe_netzebenen_stadt_zuerich.csv"
896+
),
897+
hash="c2fea1a0974611ff1c276abcc1d34619",
898+
header_time="Timestamp",
899+
freq="15min",
900+
pre_process_csv_fn=pre_process_dataset,
901+
)
902+
)
903+
904+
@staticmethod
905+
def _download_weather_data():
906+
"""Concatenate the yearly csv files into a single dataframe and reshape it"""
907+
# download the csv from the url
908+
base_url = "https://data.stadt-zuerich.ch/dataset/ugz_meteodaten_stundenmittelwerte/download/"
909+
filenames = [f"ugz_ogd_meteo_h1_{year}.csv" for year in range(2015, 2023)]
910+
df = pd.concat([pd.read_csv(base_url + fname) for fname in filenames])
911+
# retain only one weather station
912+
df = df.loc[df["Standort"] == "Zch_Stampfenbachstrasse"]
913+
# pivot the df to get all measurements as columns
914+
df["param_name"] = df["Parameter"] + " [" + df["Einheit"] + "]"
915+
df = df.pivot(index="Datum", columns="param_name", values="Wert")
916+
# convert time index to from CET to UTC and extract the required time range
917+
df.index = (
918+
pd.DatetimeIndex(df.index, tz="CET").tz_convert("UTC").tz_localize(None)
919+
)
920+
df = df.loc[
921+
(pd.Timestamp("2015-01-01") <= df.index)
922+
& (df.index <= pd.Timestamp("2022-12-31"))
923+
]
924+
return df

darts/datasets/dataset_loaders.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ class DatasetLoaderMetadata:
3131
format_time: Optional[str] = None
3232
# used to indicate the freq when we already know it
3333
freq: Optional[str] = None
34-
# a custom function to handling non-csv based datasets
34+
# a custom function handling non-csv based datasets
3535
pre_process_zipped_csv_fn: Optional[Callable] = None
36+
# a custom function handling csv based datasets
37+
pre_process_csv_fn: Optional[Callable] = None
3638
# multivariate
3739
multivariate: Optional[bool] = None
3840

@@ -49,7 +51,9 @@ class DatasetLoader(ABC):
4951

5052
_DEFAULT_DIRECTORY = Path(os.path.join(Path.home(), Path(".darts/datasets/")))
5153

52-
def __init__(self, metadata: DatasetLoaderMetadata, root_path: Path = None):
54+
def __init__(
55+
self, metadata: DatasetLoaderMetadata, root_path: Optional[Path] = None
56+
):
5357
self._metadata: DatasetLoaderMetadata = metadata
5458
if root_path is None:
5559
self._root_path: Path = DatasetLoader._DEFAULT_DIRECTORY
@@ -131,7 +135,13 @@ def _download_dataset(self):
131135
"Could not download the dataset. Reason:" + e.__repr__()
132136
) from None
133137

138+
if self._metadata.pre_process_csv_fn is not None:
139+
self._metadata.pre_process_csv_fn(self._get_path_dataset())
140+
134141
def _download_zip_dataset(self):
142+
if self._metadata.pre_process_csv_fn:
143+
logger.warning("Loading a ZIP file does not use the pre_process_csv_fn")
144+
135145
os.makedirs(self._root_path, exist_ok=True)
136146
try:
137147
request = requests.get(self._metadata.uri)
@@ -186,7 +196,9 @@ def _format_time_column(self, df):
186196

187197

188198
class DatasetLoaderCSV(DatasetLoader):
189-
def __init__(self, metadata: DatasetLoaderMetadata, root_path: Path = None):
199+
def __init__(
200+
self, metadata: DatasetLoaderMetadata, root_path: Optional[Path] = None
201+
):
190202
super().__init__(metadata, root_path)
191203

192204
def _load_from_disk(

darts/tests/datasets/test_dataset_loaders.py

+31-31
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AirPassengersDataset,
1111
AusBeerDataset,
1212
AustralianTourismDataset,
13+
ElectricityConsumptionZurichDataset,
1314
ElectricityDataset,
1415
EnergyDataset,
1516
ETTh1Dataset,
@@ -40,37 +41,36 @@
4041
DatasetLoadingException,
4142
)
4243

43-
datasets = [
44-
AirPassengersDataset,
45-
AusBeerDataset,
46-
AustralianTourismDataset,
47-
EnergyDataset,
48-
HeartRateDataset,
49-
IceCreamHeaterDataset,
50-
MonthlyMilkDataset,
51-
SunspotsDataset,
52-
TaylorDataset,
53-
TemperatureDataset,
54-
USGasolineDataset,
55-
WineDataset,
56-
WoolyDataset,
57-
GasRateCO2Dataset,
58-
MonthlyMilkIncompleteDataset,
59-
ETTh1Dataset,
60-
ETTh2Dataset,
61-
ETTm1Dataset,
62-
ETTm2Dataset,
63-
ElectricityDataset,
64-
UberTLCDataset,
65-
ILINetDataset,
66-
ExchangeRateDataset,
67-
TrafficDataset,
68-
WeatherDataset,
69-
]
70-
7144
_DEFAULT_PATH_TEST = _DEFAULT_PATH + "/tests"
7245

73-
width_datasets = [1, 1, 96, 28, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 7, 7, 7, 7, 370, 262]
46+
datasets_with_width = [
47+
(AirPassengersDataset, 1),
48+
(AusBeerDataset, 1),
49+
(AustralianTourismDataset, 96),
50+
(EnergyDataset, 28),
51+
(HeartRateDataset, 1),
52+
(IceCreamHeaterDataset, 2),
53+
(MonthlyMilkDataset, 1),
54+
(SunspotsDataset, 1),
55+
(TaylorDataset, 1),
56+
(TemperatureDataset, 1),
57+
(USGasolineDataset, 1),
58+
(WineDataset, 1),
59+
(WoolyDataset, 1),
60+
(GasRateCO2Dataset, 2),
61+
(MonthlyMilkIncompleteDataset, 1),
62+
(ETTh1Dataset, 7),
63+
(ETTh2Dataset, 7),
64+
(ETTm1Dataset, 7),
65+
(ETTm2Dataset, 7),
66+
(ElectricityDataset, 370),
67+
(UberTLCDataset, 262),
68+
(ILINetDataset, 11),
69+
(ExchangeRateDataset, 8),
70+
(TrafficDataset, 862),
71+
(WeatherDataset, 21),
72+
(ElectricityConsumptionZurichDataset, 10),
73+
]
7474

7575
wrong_hash_dataset = DatasetLoaderCSV(
7676
metadata=DatasetLoaderMetadata(
@@ -135,9 +135,9 @@ def tmp_dir_dataset():
135135

136136
class TestDatasetLoader:
137137
@pytest.mark.slow
138-
@pytest.mark.parametrize("dataset_config", zip(width_datasets, datasets))
138+
@pytest.mark.parametrize("dataset_config", datasets_with_width)
139139
def test_ok_dataset(self, dataset_config, tmp_dir_dataset):
140-
width, dataset_cls = dataset_config
140+
dataset_cls, width = dataset_config
141141
dataset = dataset_cls()
142142
assert dataset._DEFAULT_DIRECTORY == tmp_dir_dataset
143143
ts: TimeSeries = dataset.load()

docs/source/examples.rst

+10
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ with Darts using the Optuna library for hyperparameter optimization.
7676

7777
examples/17-hyperparameter-optimization.ipynb
7878

79+
Regression Models
80+
=================
81+
82+
Regression models example notebook:
83+
84+
.. toctree::
85+
:maxdepth: 1
86+
87+
examples/20-RegressionModel-examples.ipynb
88+
7989

8090
Fast Fourier Transform
8191
======================

examples/20-RegressionModel-examples.ipynb

+1,100
Large diffs are not rendered by default.
54.8 KB
Loading
Loading
57.2 KB
Loading
48.7 KB
Loading

0 commit comments

Comments
 (0)