Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HNiTS rename to match model naming convection #1000

Merged
merged 5 commits into from
Jun 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ Model | Univariate | Multivariate | Probabilistic | Multiple-series training | P
`RNNModel` (incl. LSTM and GRU); equivalent to DeepAR in its probabilistic version | ✅ | ✅ | ✅ | ✅ | | ✅ | [DeepAR paper](https://arxiv.org/abs/1704.04110)
`BlockRNNModel` (incl. LSTM and GRU) | ✅ | ✅ | ✅ | ✅ | ✅ | |
`NBEATSModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-BEATS paper](https://arxiv.org/abs/1905.10437)
`NHiTS` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-HiTS paper](https://arxiv.org/abs/2201.12886)
`NHiTSModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [N-HiTS paper](https://arxiv.org/abs/2201.12886)
`TCNModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [TCN paper](https://arxiv.org/abs/1803.01271), [DeepTCN paper](https://arxiv.org/abs/1906.04397), [blog post](https://medium.com/unit8-machine-learning-publication/temporal-convolutional-networks-and-forecasting-5ce1b6e97ce4)
`TransformerModel` | ✅ | ✅ | ✅ | ✅ | ✅ | |
`TFTModel` (Temporal Fusion Transformer) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [TFT paper](https://arxiv.org/pdf/1912.09363.pdf), [PyTorch Forecasting](https://pytorch-forecasting.readthedocs.io/en/latest/models.html)
Expand Down
2 changes: 1 addition & 1 deletion darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
try:
from darts.models.forecasting.block_rnn_model import BlockRNNModel
from darts.models.forecasting.nbeats import NBEATSModel
from darts.models.forecasting.nhits import NHiTS
from darts.models.forecasting.nhits import NHiTSModel
from darts.models.forecasting.rnn_model import RNNModel
from darts.models.forecasting.tcn_model import TCNModel
from darts.models.forecasting.tft_model import TFTModel
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def forward(self, x_in: Tuple):
return y


class NHiTS(PastCovariatesTorchModel):
class NHiTSModel(PastCovariatesTorchModel):
def __init__(
self,
input_chunk_length: int,
Expand Down
18 changes: 9 additions & 9 deletions darts/tests/models/forecasting/test_nbeats_nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

try:
from darts.models.forecasting.nbeats import NBEATSModel
from darts.models.forecasting.nhits import NHiTS
from darts.models.forecasting.nhits import NHiTSModel

TORCH_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_creation(self):
)

with self.assertRaises(ValueError):
NHiTS(
NHiTSModel(
input_chunk_length=1,
output_chunk_length=1,
num_stacks=3,
Expand All @@ -50,7 +50,7 @@ def test_fit(self):
large_ts = tg.constant_timeseries(length=100, value=1000)
small_ts = tg.constant_timeseries(length=100, value=10)

for model_cls in [NBEATSModel, NHiTS]:
for model_cls in [NBEATSModel, NHiTSModel]:
# Test basic fit and predict
model = model_cls(
input_chunk_length=1,
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_multivariate(self):
tg.linear_timeseries(length=100, start_value=0, end_value=0.5)
)

for model_cls in [NBEATSModel, NHiTS]:
for model_cls in [NBEATSModel, NHiTSModel]:
model = model_cls(
input_chunk_length=3,
output_chunk_length=1,
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_nhits_sampling_sizes(self):
with self.assertRaises(ValueError):

# wrong number of coeffs for stacks and blocks
NHiTS(
NHiTSModel(
input_chunk_length=1,
output_chunk_length=1,
num_stacks=1,
Expand All @@ -140,7 +140,7 @@ def test_nhits_sampling_sizes(self):
n_freq_downsample=((1,), (1,)),
)
with self.assertRaises(ValueError):
NHiTS(
NHiTSModel(
input_chunk_length=1,
output_chunk_length=1,
num_stacks=2,
Expand All @@ -150,7 +150,7 @@ def test_nhits_sampling_sizes(self):
)

# it shouldn't fail with the right number of coeffs
_ = NHiTS(
_ = NHiTSModel(
input_chunk_length=1,
output_chunk_length=1,
num_stacks=2,
Expand All @@ -160,7 +160,7 @@ def test_nhits_sampling_sizes(self):
)

# default freqs should be such that last one is 1
model = NHiTS(
model = NHiTSModel(
input_chunk_length=1,
output_chunk_length=1,
num_stacks=2,
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_logtensorboard(self):
def test_activation_fns(self):
ts = tg.constant_timeseries(length=50, value=10)

for model_cls in [NBEATSModel, NHiTS]:
for model_cls in [NBEATSModel, NHiTSModel]:
model = model_cls(
input_chunk_length=1,
output_chunk_length=1,
Expand Down