From 368fdbc1704039aa227c15f5f19eb933d1e5a5e1 Mon Sep 17 00:00:00 2001 From: gdevos010 Date: Thu, 9 Jun 2022 16:12:25 -0700 Subject: [PATCH 1/3] renamed nhits to match naming convention of other models --- darts/models/forecasting/nhits.py | 2 +- .../models/forecasting/test_nbeats_nhits.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/darts/models/forecasting/nhits.py b/darts/models/forecasting/nhits.py index 93d45f9940..382245168c 100644 --- a/darts/models/forecasting/nhits.py +++ b/darts/models/forecasting/nhits.py @@ -456,7 +456,7 @@ def forward(self, x_in: Tuple): return y -class NHiTS(PastCovariatesTorchModel): +class NHiTSModel(PastCovariatesTorchModel): def __init__( self, input_chunk_length: int, diff --git a/darts/tests/models/forecasting/test_nbeats_nhits.py b/darts/tests/models/forecasting/test_nbeats_nhits.py index 36d4c5159a..b14e23b4c7 100644 --- a/darts/tests/models/forecasting/test_nbeats_nhits.py +++ b/darts/tests/models/forecasting/test_nbeats_nhits.py @@ -11,7 +11,7 @@ try: from darts.models.forecasting.nbeats import NBEATSModel - from darts.models.forecasting.nhits import NHiTS + from darts.models.forecasting.nhits import NHiTSModel TORCH_AVAILABLE = True except ImportError: @@ -39,7 +39,7 @@ def test_creation(self): ) with self.assertRaises(ValueError): - NHiTS( + NHiTSModel( input_chunk_length=1, output_chunk_length=1, num_stacks=3, @@ -50,7 +50,7 @@ def test_fit(self): large_ts = tg.constant_timeseries(length=100, value=1000) small_ts = tg.constant_timeseries(length=100, value=10) - for model_cls in [NBEATSModel, NHiTS]: + for model_cls in [NBEATSModel, NHiTSModel]: # Test basic fit and predict model = model_cls( input_chunk_length=1, @@ -88,7 +88,7 @@ def test_multivariate(self): tg.linear_timeseries(length=100, start_value=0, end_value=0.5) ) - for model_cls in [NBEATSModel, NHiTS]: + for model_cls in [NBEATSModel, NHiTSModel]: model = model_cls( input_chunk_length=3, output_chunk_length=1, @@ -131,7 +131,7 @@ def test_nhits_sampling_sizes(self): with self.assertRaises(ValueError): # wrong number of coeffs for stacks and blocks - NHiTS( + NHiTSModel( input_chunk_length=1, output_chunk_length=1, num_stacks=1, @@ -140,7 +140,7 @@ def test_nhits_sampling_sizes(self): n_freq_downsample=((1,), (1,)), ) with self.assertRaises(ValueError): - NHiTS( + NHiTSModel( input_chunk_length=1, output_chunk_length=1, num_stacks=2, @@ -150,7 +150,7 @@ def test_nhits_sampling_sizes(self): ) # it shouldn't fail with the right number of coeffs - _ = NHiTS( + _ = NHiTSModel( input_chunk_length=1, output_chunk_length=1, num_stacks=2, @@ -160,7 +160,7 @@ def test_nhits_sampling_sizes(self): ) # default freqs should be such that last one is 1 - model = NHiTS( + model = NHiTSModel( input_chunk_length=1, output_chunk_length=1, num_stacks=2, @@ -189,7 +189,7 @@ def test_logtensorboard(self): def test_activation_fns(self): ts = tg.constant_timeseries(length=50, value=10) - for model_cls in [NBEATSModel, NHiTS]: + for model_cls in [NBEATSModel, NHiTSModel]: model = model_cls( input_chunk_length=1, output_chunk_length=1, From dda3e2158f14995407c91e12775523ad348bcf9b Mon Sep 17 00:00:00 2001 From: gdevos010 Date: Thu, 9 Jun 2022 16:14:09 -0700 Subject: [PATCH 2/3] renamed nhits to match naming convention of other models --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d4ba858972..3e894b33a3 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ Model | Univariate | Multivariate | Probabilistic | Multiple-series training | P `RNNModel` (incl. LSTM and GRU); equivalent to DeepAR in its probabilistic version | ✅ | ✅ | ✅ | ✅ | | ✅ | [DeepAR paper](https://arxiv.org/abs/1704.04110) `BlockRNNModel` (incl. LSTM and GRU) | ✅ | ✅ | ✅ | ✅ | ✅ | | `NBEATSModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-BEATS paper](https://arxiv.org/abs/1905.10437) -`NHiTS` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-HiTS paper](https://arxiv.org/abs/2201.12886) +`NHiTSModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-HiTS paper](https://arxiv.org/abs/2201.12886) `TCNModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [TCN paper](https://arxiv.org/abs/1803.01271), [DeepTCN paper](https://arxiv.org/abs/1906.04397), [blog post](https://medium.com/unit8-machine-learning-publication/temporal-convolutional-networks-and-forecasting-5ce1b6e97ce4) `TransformerModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | `TFTModel` (Temporal Fusion Transformer) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [TFT paper](https://arxiv.org/pdf/1912.09363.pdf), [PyTorch Forecasting](https://pytorch-forecasting.readthedocs.io/en/latest/models.html) From 9c1efe6a3bd4fd46769200bdcf0d63e06dfc9f4f Mon Sep 17 00:00:00 2001 From: gdevos010 Date: Thu, 9 Jun 2022 17:51:57 -0700 Subject: [PATCH 3/3] renamed nhits to match naming convention of other models --- darts/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/models/__init__.py b/darts/models/__init__.py index c24e095db7..387602e4f5 100644 --- a/darts/models/__init__.py +++ b/darts/models/__init__.py @@ -43,7 +43,7 @@ try: from darts.models.forecasting.block_rnn_model import BlockRNNModel from darts.models.forecasting.nbeats import NBEATSModel - from darts.models.forecasting.nhits import NHiTS + from darts.models.forecasting.nhits import NHiTSModel from darts.models.forecasting.rnn_model import RNNModel from darts.models.forecasting.tcn_model import TCNModel from darts.models.forecasting.tft_model import TFTModel