11
11
12
12
try :
13
13
from darts .models .forecasting .nbeats import NBEATSModel
14
- from darts .models .forecasting .nhits import NHiTS
14
+ from darts .models .forecasting .nhits import NHiTSModel
15
15
16
16
TORCH_AVAILABLE = True
17
17
except ImportError :
@@ -39,7 +39,7 @@ def test_creation(self):
39
39
)
40
40
41
41
with self .assertRaises (ValueError ):
42
- NHiTS (
42
+ NHiTSModel (
43
43
input_chunk_length = 1 ,
44
44
output_chunk_length = 1 ,
45
45
num_stacks = 3 ,
@@ -50,7 +50,7 @@ def test_fit(self):
50
50
large_ts = tg .constant_timeseries (length = 100 , value = 1000 )
51
51
small_ts = tg .constant_timeseries (length = 100 , value = 10 )
52
52
53
- for model_cls in [NBEATSModel , NHiTS ]:
53
+ for model_cls in [NBEATSModel , NHiTSModel ]:
54
54
# Test basic fit and predict
55
55
model = model_cls (
56
56
input_chunk_length = 1 ,
@@ -88,7 +88,7 @@ def test_multivariate(self):
88
88
tg .linear_timeseries (length = 100 , start_value = 0 , end_value = 0.5 )
89
89
)
90
90
91
- for model_cls in [NBEATSModel , NHiTS ]:
91
+ for model_cls in [NBEATSModel , NHiTSModel ]:
92
92
model = model_cls (
93
93
input_chunk_length = 3 ,
94
94
output_chunk_length = 1 ,
@@ -131,7 +131,7 @@ def test_nhits_sampling_sizes(self):
131
131
with self .assertRaises (ValueError ):
132
132
133
133
# wrong number of coeffs for stacks and blocks
134
- NHiTS (
134
+ NHiTSModel (
135
135
input_chunk_length = 1 ,
136
136
output_chunk_length = 1 ,
137
137
num_stacks = 1 ,
@@ -140,7 +140,7 @@ def test_nhits_sampling_sizes(self):
140
140
n_freq_downsample = ((1 ,), (1 ,)),
141
141
)
142
142
with self .assertRaises (ValueError ):
143
- NHiTS (
143
+ NHiTSModel (
144
144
input_chunk_length = 1 ,
145
145
output_chunk_length = 1 ,
146
146
num_stacks = 2 ,
@@ -150,7 +150,7 @@ def test_nhits_sampling_sizes(self):
150
150
)
151
151
152
152
# it shouldn't fail with the right number of coeffs
153
- _ = NHiTS (
153
+ _ = NHiTSModel (
154
154
input_chunk_length = 1 ,
155
155
output_chunk_length = 1 ,
156
156
num_stacks = 2 ,
@@ -160,7 +160,7 @@ def test_nhits_sampling_sizes(self):
160
160
)
161
161
162
162
# default freqs should be such that last one is 1
163
- model = NHiTS (
163
+ model = NHiTSModel (
164
164
input_chunk_length = 1 ,
165
165
output_chunk_length = 1 ,
166
166
num_stacks = 2 ,
@@ -189,7 +189,7 @@ def test_logtensorboard(self):
189
189
def test_activation_fns (self ):
190
190
ts = tg .constant_timeseries (length = 50 , value = 10 )
191
191
192
- for model_cls in [NBEATSModel , NHiTS ]:
192
+ for model_cls in [NBEATSModel , NHiTSModel ]:
193
193
model = model_cls (
194
194
input_chunk_length = 1 ,
195
195
output_chunk_length = 1 ,
0 commit comments