diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py index ca93196792..29f195168d 100644 --- a/darts/models/forecasting/tft_model.py +++ b/darts/models/forecasting/tft_model.py @@ -482,11 +482,15 @@ def forward(self, x) -> Dict[str, torch.Tensor]: # LSTM # calculate initial state - input_hidden = self.static_context_hidden_encoder_grn(static_embedding).expand( - self.lstm_layers, -1, -1 + input_hidden = ( + self.static_context_hidden_encoder_grn(static_embedding) + .expand(self.lstm_layers, -1, -1) + .contiguous() ) - input_cell = self.static_context_cell_encoder_grn(static_embedding).expand( - self.lstm_layers, -1, -1 + input_cell = ( + self.static_context_cell_encoder_grn(static_embedding) + .expand(self.lstm_layers, -1, -1) + .contiguous() ) # run local lstm encoder