Skip to content

Commit 979491b

Browse files
FEJTWOWKamil Wierciakhrzn
authored
Add new datasets to darts.dataset (#1298)
* Add new datasets to darts.dataset * Add new dataset to dataset catalog * Specify the class in darts.datasets.__init__.py to correctly load new dataset * Change uri (just for testing) and set date as index of df * Change .csv file and adjust hash * Update .csv file so the start day is on Sunday not Monday * Add correct hash * Change a bit dataset and update hash * Remove duplicates from dataset, adjust hash * Add original dataset (starting from 1997) and adjust hash * Add default_path instead of uri from branch * Add two new datasets, modify __init__ file * Add option to load datasets as a list of univariate time series * Add weather dataset * Fix datetime format, remove duplicates from dataset and adjust md5 hash * Fix freq * fix after comments * fix ILI description * Add source * Fix exchange rate dataset * Refactor load from disk function so it can change multivariate indexed TS into list of univariate TS * Add git LFS to traffic.csv dataset * Remove unnecessary if * remove debug print * Fix docu Co-authored-by: Kamil Wierciak <[email protected]> Co-authored-by: Julien Herzen <[email protected]>
1 parent d466815 commit 979491b

File tree

7 files changed

+79346
-48
lines changed

7 files changed

+79346
-48
lines changed

darts/datasets/__init__.py

+201-36
Original file line numberDiff line numberDiff line change
@@ -354,17 +354,18 @@ class ETTh1Dataset(DatasetLoaderCSV):
354354
"""
355355
The data of 1 Electricity Transformers at 1 stations, including load, oil temperature.
356356
The dataset ranges from 2016/07 to 2018/07 taken hourly.
357-
Source: [1][2]_
357+
Source: [1]_ [2]_
358358
359359
Field Descriptions:
360-
date: The recorded date
361-
HUFL: High UseFul Load
362-
HULL: High UseLess Load
363-
MUFL: Medium UseFul Load
364-
MULL: Medium UseLess Load
365-
LUFL: Low UseFul Load
366-
LULL: Low UseLess Load
367-
OT: Oil Temperature (Target)
360+
361+
* date: The recorded date
362+
* HUFL: High UseFul Load
363+
* HULL: High UseLess Load
364+
* MUFL: Medium UseFul Load
365+
* MULL: Medium UseLess Load
366+
* LUFL: Low UseFul Load
367+
* LULL: Low UseLess Load
368+
* OT: Oil Temperature (Target)
368369
369370
References
370371
----------
@@ -388,17 +389,18 @@ class ETTh2Dataset(DatasetLoaderCSV):
388389
"""
389390
The data of 1 Electricity Transformers at 1 stations, including load, oil temperature.
390391
The dataset ranges from 2016/07 to 2018/07 taken hourly.
391-
Source: [1][2]_
392+
Source: [1]_ [2]_
392393
393394
Field Descriptions:
394-
date: The recorded date
395-
HUFL: High UseFul Load
396-
HULL: High UseLess Load
397-
MUFL: Medium UseFul Load
398-
MULL: Medium UseLess Load
399-
LUFL: Low UseFul Load
400-
LULL: Low UseLess Load
401-
OT: Oil Temperature (Target)
395+
396+
* date: The recorded date
397+
* HUFL: High UseFul Load
398+
* HULL: High UseLess Load
399+
* MUFL: Medium UseFul Load
400+
* MULL: Medium UseLess Load
401+
* LUFL: Low UseFul Load
402+
* LULL: Low UseLess Load
403+
* OT: Oil Temperature (Target)
402404
403405
References
404406
----------
@@ -422,17 +424,18 @@ class ETTm1Dataset(DatasetLoaderCSV):
422424
"""
423425
The data of 1 Electricity Transformers at 1 stations, including load, oil temperature.
424426
The dataset ranges from 2016/07 to 2018/07 recorded every 15 minutes.
425-
Source: [1][2]_
427+
Source: [1]_ [2]_
426428
427429
Field Descriptions:
428-
date: The recorded date
429-
HUFL: High UseFul Load
430-
HULL: High UseLess Load
431-
MUFL: Medium UseFul Load
432-
MULL: Medium UseLess Load
433-
LUFL: Low UseFul Load
434-
LULL: Low UseLess Load
435-
OT: Oil Temperature (Target)
430+
431+
* date: The recorded date
432+
* HUFL: High UseFul Load
433+
* HULL: High UseLess Load
434+
* MUFL: Medium UseFul Load
435+
* MULL: Medium UseLess Load
436+
* LUFL: Low UseFul Load
437+
* LULL: Low UseLess Load
438+
* OT: Oil Temperature (Target)
436439
437440
References
438441
----------
@@ -456,17 +459,18 @@ class ETTm2Dataset(DatasetLoaderCSV):
456459
"""
457460
The data of 1 Electricity Transformers at 1 stations, including load, oil temperature.
458461
The dataset ranges from 2016/07 to 2018/07 recorded every 15 minutes.
459-
Source: [1][2]_
462+
Source: [1]_ [2]_
460463
461464
Field Descriptions:
462-
date: The recorded date
463-
HUFL: High UseFul Load
464-
HULL: High UseLess Load
465-
MUFL: Medium UseFul Load
466-
MULL: Medium UseLess Load
467-
LUFL: Low UseFul Load
468-
LULL: Low UseLess Load
469-
OT: Oil Temperature (Target)
465+
466+
* date: The recorded date
467+
* HUFL: High UseFul Load
468+
* HULL: High UseLess Load
469+
* MUFL: Medium UseFul Load
470+
* MULL: Medium UseLess Load
471+
* LUFL: Low UseFul Load
472+
* LULL: Low UseLess Load
473+
* OT: Oil Temperature (Target)
470474
471475
References
472476
----------
@@ -648,3 +652,164 @@ def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
648652
ts = TimeSeries.from_dataframe(tmp, "date", ["locationID"])
649653
ts_list.append(ts)
650654
return ts_list
655+
656+
657+
class ILINetDataset(DatasetLoaderCSV):
658+
"""
659+
ILI describes the number of patients seen with influenzalike illness and the total number of patients. It includes
660+
weekly data from the Centers for Disease Control and Prevention of the United States from 1997 to 2022.
661+
Source: [1]_ [2]_ [3]_ [4]_
662+
663+
Components Descriptions:
664+
665+
* % WEIGHTED ILI: Combined state-specific data of patients visit to healthcare providers for ILI reported each week weighted by state population
666+
* % UNWEIGHTED ILI: Combined state-specific data of patients visit to healthcare providers for ILI reported each week unweighted by state population
667+
* AGE 0-4: Number of patients between 0 and 4 years of age
668+
* AGE 25-49: Number of patients between 25 and 49 years of age
669+
* AGE 25-64: Number of patients between 25 and 64 years of age
670+
* AGE 5-24: Number of patients between 5 and 24 years of age
671+
* AGE 50-64: Number of patients between 50 and 64 years of age
672+
* AGE 65: Number of patients above (>=65) 65 years of age
673+
* ILITOTAL: Total number of ILI patients. For this system, ILI is defined as fever (temperature of 100°F [37.8°C] or greater) and a cough and/or a sore throat
674+
* NUM. OF PROVIDERS: Number of outpatient healthcare providers
675+
* TOTAL PATIENTS: Total number of patients
676+
677+
678+
679+
References
680+
----------
681+
.. [1] https://gis.cdc.gov/grasp/fluview/fluportaldashboard.html
682+
.. [2] https://www.cdc.gov/flu/weekly/overview.htm#Outpatient
683+
.. [3] https://arxiv.org/pdf/2205.13504.pdf
684+
.. [4] https://gis.cdc.gov/grasp/fluview/FluViewPhase2QuickReferenceGuide.pdf
685+
"""
686+
687+
def __init__(self, multivariate: bool = True):
688+
super().__init__(
689+
metadata=DatasetLoaderMetadata(
690+
"ILINet.csv",
691+
uri=_DEFAULT_PATH + "/ILINet.csv",
692+
hash="c9cbd6cc0a92b21cd95bec2706212d8d",
693+
header_time="DATE",
694+
format_time="%Y-%m-%d",
695+
freq="W",
696+
multivariate=multivariate,
697+
)
698+
)
699+
700+
def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
701+
"""
702+
Load the ILINetDataset dataset as a list of univariate timeseries.
703+
"""
704+
return [TimeSeries.from_series(series[label]) for label in series]
705+
706+
707+
class ExchangeRateDataset(DatasetLoaderCSV):
708+
"""
709+
The collection of the daily exchange rates of eight foreign countries, including Australia, British, Canada, Switzerland, China, Japan, New Zealand,
710+
and Singapore, ranging from 1990 to 2016. Unfortunately, there were some inconsistencies concerning the dates, so the resulting TimeSeries is integer-indexed.
711+
Source: [1]_
712+
713+
References
714+
----------
715+
.. [1] https://github.com/laiguokun/multivariate-time-series-data
716+
"""
717+
718+
def __init__(self, multivariate: bool = True):
719+
"""
720+
Parameters
721+
----------
722+
multivariate: bool
723+
Whether to return a single multivariate timeseries - if False returns a list of univariate TimeSeries. Default is True.
724+
"""
725+
super().__init__(
726+
metadata=DatasetLoaderMetadata(
727+
"exchange_rate.csv",
728+
uri=_DEFAULT_PATH + "/exchange_rate.csv",
729+
hash="6e35621a9eb6a9dd5465cf52a22b1339",
730+
header_time=None,
731+
multivariate=multivariate,
732+
)
733+
)
734+
735+
def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
736+
"""
737+
Load the ExchangeRateDataset dataset as a list of univariate timeseries, one for each country.
738+
"""
739+
return [TimeSeries.from_series(series[label]) for label in series]
740+
741+
742+
class TrafficDataset(DatasetLoaderCSV):
743+
"""
744+
The data in this repo is a collection of 48 months (2015-2016) hourly data from the California Department of Transportation. The data describes
745+
the road occupancy rates (between 0 and 1) measured by 862 different sensors on San Francisco Bay area freeways. The raw data is in http://pems.dot.ca.gov.
746+
Source: [1]_
747+
748+
References
749+
----------
750+
.. [1] https://github.com/laiguokun/multivariate-time-series-data
751+
"""
752+
753+
def __init__(self, multivariate: bool = True):
754+
"""
755+
Parameters
756+
----------
757+
multivariate: bool
758+
Whether to return a single multivariate timeseries - if False returns a list of univariate TimeSeries. Default is True.
759+
"""
760+
super().__init__(
761+
metadata=DatasetLoaderMetadata(
762+
"traffic.csv",
763+
uri=_DEFAULT_PATH + "/traffic.csv",
764+
hash="a2105f364ef70aec06c757304833f72a",
765+
header_time="Date",
766+
format_time="%Y-%m-%d %H:%M:%S",
767+
freq="1H",
768+
multivariate=multivariate,
769+
)
770+
)
771+
772+
def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
773+
"""
774+
Load the TrafficDataset dataset as a list of univariate timeseries, one for each ID.
775+
"""
776+
return [TimeSeries.from_series(series[label]) for label in series]
777+
778+
779+
class WeatherDataset(DatasetLoaderCSV):
780+
"""
781+
Weather includes 21 indicators of weather, such as air
782+
temperature, and humidity. The data was recorded every
783+
10 min for 2020 in Germany.
784+
Source: [1]_ [2]_
785+
786+
References
787+
----------
788+
.. [1] https://www.bgc-jena.mpg.de/wetter/
789+
.. [2] https://arxiv.org/pdf/2205.13504.pdf
790+
"""
791+
792+
def __init__(self, multivariate: bool = True):
793+
"""
794+
Parameters
795+
----------
796+
multivariate: bool
797+
Whether to return a single multivariate timeseries - if False returns a list of univariate TimeSeries. Default is True.
798+
"""
799+
super().__init__(
800+
metadata=DatasetLoaderMetadata(
801+
"weather.csv",
802+
uri=_DEFAULT_PATH + "/weather.csv",
803+
hash="a2942a05638ba311bc7935bcc087a30f",
804+
header_time="Date Time",
805+
format_time="%d.%m.%Y %H:%M:%S",
806+
freq="10min",
807+
multivariate=multivariate,
808+
)
809+
)
810+
811+
def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
812+
"""
813+
Load the WeatherDataset dataset as a list of univariate timeseries, one for weather indicator.
814+
"""
815+
return [TimeSeries.from_series(series[label]) for label in series]

darts/datasets/dataset_loaders.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -194,24 +194,24 @@ def _load_from_disk(
194194
) -> Union[TimeSeries, List[TimeSeries]]:
195195

196196
df = pd.read_csv(path_to_file)
197-
198197
if metadata.header_time is not None:
199198
df = self._format_time_column(df)
200199
series = TimeSeries.from_dataframe(
201200
df=df, time_col=metadata.header_time, freq=metadata.freq
202201
)
203-
if (
204-
self._metadata.multivariate is not None
205-
and self._metadata.multivariate is False
206-
):
207-
try:
208-
series = self._to_multi_series(series.pd_dataframe())
209-
except Exception as e:
210-
raise DatasetLoadingException(
211-
"Could not convert to multi-series. Reason:" + e.__repr__()
212-
) from None
213202
else:
214203
df.sort_index(inplace=True)
215-
216204
series = TimeSeries.from_dataframe(df)
205+
206+
if (
207+
self._metadata.multivariate is not None
208+
and self._metadata.multivariate is False
209+
):
210+
try:
211+
series = self._to_multi_series(series.pd_dataframe())
212+
except Exception as e:
213+
raise DatasetLoadingException(
214+
"Could not convert to multi-series. Reason:" + e.__repr__()
215+
) from None
216+
217217
return series

datasets/.gitattributes

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
traffic.csv filter=lfs diff=lfs merge=lfs -text

0 commit comments

Comments
 (0)