Skip to content

Commit ddc84a2

Browse files
authored
fixed non contiguous error when using lstm_layers > 1 on gpu (#740)
1 parent adde52a commit ddc84a2

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

darts/models/forecasting/tft_model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -482,11 +482,15 @@ def forward(self, x) -> Dict[str, torch.Tensor]:
482482

483483
# LSTM
484484
# calculate initial state
485-
input_hidden = self.static_context_hidden_encoder_grn(static_embedding).expand(
486-
self.lstm_layers, -1, -1
485+
input_hidden = (
486+
self.static_context_hidden_encoder_grn(static_embedding)
487+
.expand(self.lstm_layers, -1, -1)
488+
.contiguous()
487489
)
488-
input_cell = self.static_context_cell_encoder_grn(static_embedding).expand(
489-
self.lstm_layers, -1, -1
490+
input_cell = (
491+
self.static_context_cell_encoder_grn(static_embedding)
492+
.expand(self.lstm_layers, -1, -1)
493+
.contiguous()
490494
)
491495

492496
# run local lstm encoder

0 commit comments

Comments
 (0)