Skip to content

Commit 07c6a4d

Browse files
authored
HNiTS rename to match model naming convection (#1000)
* renamed nhits to match naming convention of other models * renamed nhits to match naming convention of other models * renamed nhits to match naming convention of other models
1 parent c8a82d7 commit 07c6a4d

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ Model | Univariate | Multivariate | Probabilistic | Multiple-series training | P
141141
`RNNModel` (incl. LSTM and GRU); equivalent to DeepAR in its probabilistic version | ✅ | ✅ | ✅ | ✅ | | ✅ | [DeepAR paper](https://arxiv.org/abs/1704.04110)
142142
`BlockRNNModel` (incl. LSTM and GRU) | ✅ | ✅ | ✅ | ✅ | ✅ | |
143143
`NBEATSModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-BEATS paper](https://arxiv.org/abs/1905.10437)
144-
`NHiTS` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-HiTS paper](https://arxiv.org/abs/2201.12886)
144+
`NHiTSModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-HiTS paper](https://arxiv.org/abs/2201.12886)
145145
`TCNModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [TCN paper](https://arxiv.org/abs/1803.01271), [DeepTCN paper](https://arxiv.org/abs/1906.04397), [blog post](https://medium.com/unit8-machine-learning-publication/temporal-convolutional-networks-and-forecasting-5ce1b6e97ce4)
146146
`TransformerModel` | ✅ | ✅ | ✅ | ✅ | ✅ | |
147147
`TFTModel` (Temporal Fusion Transformer) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [TFT paper](https://arxiv.org/pdf/1912.09363.pdf), [PyTorch Forecasting](https://pytorch-forecasting.readthedocs.io/en/latest/models.html)

darts/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
try:
4444
from darts.models.forecasting.block_rnn_model import BlockRNNModel
4545
from darts.models.forecasting.nbeats import NBEATSModel
46-
from darts.models.forecasting.nhits import NHiTS
46+
from darts.models.forecasting.nhits import NHiTSModel
4747
from darts.models.forecasting.rnn_model import RNNModel
4848
from darts.models.forecasting.tcn_model import TCNModel
4949
from darts.models.forecasting.tft_model import TFTModel

darts/models/forecasting/nhits.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def forward(self, x_in: Tuple):
456456
return y
457457

458458

459-
class NHiTS(PastCovariatesTorchModel):
459+
class NHiTSModel(PastCovariatesTorchModel):
460460
def __init__(
461461
self,
462462
input_chunk_length: int,

darts/tests/models/forecasting/test_nbeats_nhits.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
try:
1313
from darts.models.forecasting.nbeats import NBEATSModel
14-
from darts.models.forecasting.nhits import NHiTS
14+
from darts.models.forecasting.nhits import NHiTSModel
1515

1616
TORCH_AVAILABLE = True
1717
except ImportError:
@@ -39,7 +39,7 @@ def test_creation(self):
3939
)
4040

4141
with self.assertRaises(ValueError):
42-
NHiTS(
42+
NHiTSModel(
4343
input_chunk_length=1,
4444
output_chunk_length=1,
4545
num_stacks=3,
@@ -50,7 +50,7 @@ def test_fit(self):
5050
large_ts = tg.constant_timeseries(length=100, value=1000)
5151
small_ts = tg.constant_timeseries(length=100, value=10)
5252

53-
for model_cls in [NBEATSModel, NHiTS]:
53+
for model_cls in [NBEATSModel, NHiTSModel]:
5454
# Test basic fit and predict
5555
model = model_cls(
5656
input_chunk_length=1,
@@ -88,7 +88,7 @@ def test_multivariate(self):
8888
tg.linear_timeseries(length=100, start_value=0, end_value=0.5)
8989
)
9090

91-
for model_cls in [NBEATSModel, NHiTS]:
91+
for model_cls in [NBEATSModel, NHiTSModel]:
9292
model = model_cls(
9393
input_chunk_length=3,
9494
output_chunk_length=1,
@@ -131,7 +131,7 @@ def test_nhits_sampling_sizes(self):
131131
with self.assertRaises(ValueError):
132132

133133
# wrong number of coeffs for stacks and blocks
134-
NHiTS(
134+
NHiTSModel(
135135
input_chunk_length=1,
136136
output_chunk_length=1,
137137
num_stacks=1,
@@ -140,7 +140,7 @@ def test_nhits_sampling_sizes(self):
140140
n_freq_downsample=((1,), (1,)),
141141
)
142142
with self.assertRaises(ValueError):
143-
NHiTS(
143+
NHiTSModel(
144144
input_chunk_length=1,
145145
output_chunk_length=1,
146146
num_stacks=2,
@@ -150,7 +150,7 @@ def test_nhits_sampling_sizes(self):
150150
)
151151

152152
# it shouldn't fail with the right number of coeffs
153-
_ = NHiTS(
153+
_ = NHiTSModel(
154154
input_chunk_length=1,
155155
output_chunk_length=1,
156156
num_stacks=2,
@@ -160,7 +160,7 @@ def test_nhits_sampling_sizes(self):
160160
)
161161

162162
# default freqs should be such that last one is 1
163-
model = NHiTS(
163+
model = NHiTSModel(
164164
input_chunk_length=1,
165165
output_chunk_length=1,
166166
num_stacks=2,
@@ -189,7 +189,7 @@ def test_logtensorboard(self):
189189
def test_activation_fns(self):
190190
ts = tg.constant_timeseries(length=50, value=10)
191191

192-
for model_cls in [NBEATSModel, NHiTS]:
192+
for model_cls in [NBEATSModel, NHiTSModel]:
193193
model = model_cls(
194194
input_chunk_length=1,
195195
output_chunk_length=1,

0 commit comments

Comments
 (0)