Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new datasets to darts.dataset #1298

Merged
merged 25 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions darts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,3 +648,168 @@ def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
ts = TimeSeries.from_dataframe(tmp, "date", ["locationID"])
ts_list.append(ts)
return ts_list


class ILINetDataset(DatasetLoaderCSV):
    """
    Weekly influenza-like illness (ILI) data from the Centers for Disease Control and Prevention of the
    United States, covering 1997 to 2022. ILI describes the ratio of patients seen with influenza-like illness
    and the number of patients.

    Field Descriptions:
    DATE: The recorded date
    % WEIGHTED ILI: Combined state-specific data of patient visits to healthcare providers for ILI reported each
    week, weighted by state population
    % UNWEIGHTED ILI: Combined state-specific data of patient visits to healthcare providers for ILI reported each
    week, not weighted by state population
    AGE 0-4: Number of patients between 0 and 4 years of age
    AGE 25-49: Number of patients between 25 and 49 years of age
    AGE 25-64: Number of patients between 25 and 64 years of age
    AGE 5-24: Number of patients between 5 and 24 years of age
    AGE 50-64: Number of patients between 50 and 64 years of age
    AGE 65: Number of patients aged 65 years and above
    ILITOTAL: Total number of ILI patients. For this system, ILI is defined as fever (temperature of 100°F [37.8°C]
    or greater) and a cough and/or a sore throat
    NUM. OF PROVIDERS: Number of outpatient healthcare providers
    TOTAL PATIENTS: Total number of patients

    References
    ----------
    .. [1] https://gis.cdc.gov/grasp/fluview/fluportaldashboard.html
    .. [2] https://www.cdc.gov/flu/weekly/overview.htm#Outpatient
    .. [3] https://arxiv.org/pdf/2205.13504.pdf
    """

    def __init__(self):
        # Describe where the CSV lives and how its time column is parsed, then
        # hand the description over to the generic CSV loader.
        # NOTE(review): the URI points at a feature branch rather than a stable
        # ref -- confirm it is still reachable once that branch is merged/deleted.
        ilinet_metadata = DatasetLoaderMetadata(
            "ILINet.csv",
            uri="https://raw.githubusercontent.com/unit8co/darts/Improvement/Add_new_datasets_617/datasets/ILINet.csv",
            hash="c9cbd6cc0a92b21cd95bec2706212d8d",
            header_time="DATE",
            format_time="%Y-%m-%d",
            freq="W",
        )
        super().__init__(metadata=ilinet_metadata)


class ExchangeRateDataset(DatasetLoaderCSV):
    """
    Daily exchange rates of eight foreign countries (Australia, Britain, Canada, Switzerland, China, Japan,
    New Zealand and Singapore), ranging from 1990 to 2010.

    References
    ----------
    .. [1] https://github.com/laiguokun/multivariate-time-series-data
    """

    def __init__(self, multivariate: bool = True):
        """
        Parameters
        ----------
        multivariate: bool
            Whether to return a single multivariate timeseries - if False returns a list of univariate TimeSeries.
            Default is True.
        """
        # NOTE(review): the URI points at a feature branch rather than a stable
        # ref -- confirm it is still reachable once that branch is merged/deleted.
        super().__init__(
            metadata=DatasetLoaderMetadata(
                "exchange_rate.csv",
                uri="https://raw.githubusercontent.com/unit8co/darts/Improvement/Add_new_datasets_617/datasets/exchange_rate.csv",
                hash="9219e9a03eb54c6e40d7eb1c9b3b6f7c",
                header_time="Date",
                format_time="%Y-%m-%d",
                freq="D",
                multivariate=multivariate,
            )
        )

    def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
        """
        Load the ExchangeRateDataset dataset as a list of univariate timeseries, one for each country.

        Parameters
        ----------
        series
            The dataset as a DataFrame with one column per country.
            NOTE(review): presumably indexed by time (as required by ``TimeSeries.from_series``) -- confirm
            against the calling loader.

        Returns
        -------
        List[TimeSeries]
            One univariate TimeSeries per column of ``series``, in column order.
        """
        # Selecting a column from a DataFrame yields a pandas Series, so each
        # column is wrapped explicitly to honour the declared TimeSeries return type.
        return [
            TimeSeries.from_series(series[label])
            for label in _build_tqdm_iterator(
                series, verbose=False, total=len(series.columns)
            )
        ]


class TrafficDataset(DatasetLoaderCSV):
    """
    Hourly road occupancy rates (between 0 and 1) measured by different sensors on San Francisco Bay area
    freeways, published by the California Department of Transportation. The source repository [1]_ describes
    the data as a collection of 48 months (2015-2016) of hourly data. The raw data is available at
    http://pems.dot.ca.gov.

    References
    ----------
    .. [1] https://github.com/laiguokun/multivariate-time-series-data
    """

    def __init__(self, multivariate: bool = True):
        """
        Parameters
        ----------
        multivariate: bool
            Whether to return a single multivariate timeseries - if False returns a list of univariate TimeSeries.
            Default is True.
        """
        # NOTE(review): the URI points at a feature branch rather than a stable
        # ref -- confirm it is still reachable once that branch is merged/deleted.
        super().__init__(
            metadata=DatasetLoaderMetadata(
                "traffic.csv",
                uri="https://raw.githubusercontent.com/unit8co/darts/Improvement/Add_new_datasets_617/datasets/traffic.csv",
                hash="a2105f364ef70aec06c757304833f72a",
                header_time="Date",
                format_time="%Y-%m-%d %H:%M:%S",
                freq="1H",
                multivariate=multivariate,
            )
        )

    def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
        """
        Load the TrafficDataset dataset as a list of univariate timeseries, one for each ID.

        Parameters
        ----------
        series
            The dataset as a DataFrame with one column per ID.
            NOTE(review): presumably indexed by time (as required by ``TimeSeries.from_series``) -- confirm
            against the calling loader.

        Returns
        -------
        List[TimeSeries]
            One univariate TimeSeries per column of ``series``, in column order.
        """
        # Selecting a column from a DataFrame yields a pandas Series, so each
        # column is wrapped explicitly to honour the declared TimeSeries return type.
        return [
            TimeSeries.from_series(series[label])
            for label in _build_tqdm_iterator(
                series, verbose=False, total=len(series.columns)
            )
        ]


class WeatherDataset(DatasetLoaderCSV):
    """
    Weather data comprising 21 weather indicators, such as air temperature and humidity.
    The data is recorded every 10 minutes for 2020 in Germany.

    References
    ----------
    .. [1] https://www.bgc-jena.mpg.de/wetter/
    .. [2] https://arxiv.org/pdf/2205.13504.pdf
    """

    def __init__(self, multivariate: bool = True):
        """
        Parameters
        ----------
        multivariate: bool
            Whether to return a single multivariate timeseries - if False returns a list of univariate TimeSeries.
            Default is True.
        """
        # NOTE(review): the URI points at a feature branch rather than a stable
        # ref -- confirm it is still reachable once that branch is merged/deleted.
        super().__init__(
            metadata=DatasetLoaderMetadata(
                "weather.csv",
                uri="https://raw.githubusercontent.com/unit8co/darts/Improvement/Add_new_datasets_617/datasets/weather.csv",
                hash="a2942a05638ba311bc7935bcc087a30f",
                header_time="Date Time",
                format_time="%d.%m.%Y %H:%M:%S",
                freq="10min",
                multivariate=multivariate,
            )
        )

    def _to_multi_series(self, series: pd.DataFrame) -> List[TimeSeries]:
        """
        Load the WeatherDataset dataset as a list of univariate timeseries, one for each weather indicator.

        Parameters
        ----------
        series
            The dataset as a DataFrame with one column per weather indicator.
            NOTE(review): presumably indexed by time (as required by ``TimeSeries.from_series``) -- confirm
            against the calling loader.

        Returns
        -------
        List[TimeSeries]
            One univariate TimeSeries per column of ``series``, in column order.
        """
        # Selecting a column from a DataFrame yields a pandas Series, so each
        # column is wrapped explicitly to honour the declared TimeSeries return type.
        return [
            TimeSeries.from_series(series[label])
            for label in _build_tqdm_iterator(
                series, verbose=False, total=len(series.columns)
            )
        ]
Loading