Skip to content

Commit 074be2a

Browse files
authored
fix/TFTModel_flask (#745)
1 parent 334b0d8 commit 074be2a

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

darts/models/forecasting/tft_model.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def forward(self, x) -> Dict[str, torch.Tensor]:
348348
input dimensions: (n_samples, n_time_steps, n_variables)
349349
"""
350350

351-
dim_samples, dim_time, dim_variable, dim_loss = 0, 1, 2, 3
351+
dim_samples, dim_time, dim_variable = 0, 1, 2
352352
past_target, past_covariates, historic_future_covariates, future_covariates = x
353353

354354
batch_size = past_target.shape[dim_samples]
@@ -450,12 +450,13 @@ def forward(self, x) -> Dict[str, torch.Tensor]:
450450
device=past_target.device,
451451
)
452452

453-
# this is only to interpret the output
454-
static_covariate_var = torch.zeros(
455-
(past_target.shape[0], 0),
456-
dtype=past_target.dtype,
457-
device=past_target.device,
458-
)
453+
# # TODO: implement below when static covariates are supported
454+
# # this is only to interpret the output
455+
# static_covariate_var = torch.zeros(
456+
# (past_target.shape[0], 0),
457+
# dtype=past_target.dtype,
458+
# device=past_target.device,
459+
# )
459460

460461
if future_covariates is None and static_covariates is None:
461462
raise NotImplementedError("make zero tensor if future covariates is None")

0 commit comments

Comments
 (0)