Skip to content

Commit

Permalink
Update torch_forecasting_model.py (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeadie authored Feb 14, 2024
1 parent a184762 commit 213ba81
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,7 +2048,6 @@ def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
logger=logger,
)

dim_component = self.past_covariate_series.n_components
(
past_target,
past_covariates,
Expand All @@ -2057,13 +2056,9 @@ def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
# I think these have to do with future covariates (which isn't supported in Dlinear)
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in self.train_sample]

n_past_covs = (
past_covariates.shape[dim_component] if past_covariates is not None else 0
)

input_past = torch.cat(
[ds for ds in [past_target, past_covariates] if ds is not None],
dim=dim_component,
dim=2, # Shape is (1, lookback_size, no. of variates (in either target or series))
)

input_sample = [input_past.float(), static_covariates.float() if static_covariates is not None else None]
Expand Down

0 comments on commit 213ba81

Please sign in to comment.