
Run on CPU not on gpu #722

Closed
ElVictorious opened this issue Jan 4, 2022 · 1 comment · Fixed by #740
Labels: bug (Something isn't working), triage (Issue waiting for triaging)

Comments

@ElVictorious

Hi,

First, congrats on the amazing job with this repo.

I found that the Temporal Fusion Transformer does not work on GPU, although it works fine on CPU. I tried three different machines, including Colab, and hit the same error every time.

Thanks in advance for your help.

model = TFTModel(
    input_chunk_length=INLEN,
    output_chunk_length=N_FC,
    hidden_size=HIDDEN,
    lstm_layers=LSTMLAYERS,
    num_attention_heads=ATTHEADS,
    dropout=DROPOUT,
    batch_size=BATCH,
    n_epochs=EPOCHS,
    likelihood=QuantileRegression(quantiles=QUANTILES),
    # loss_fn=MSELoss(),
    random_state=RAND,
    force_reset=True,
)

model.fit(
    ts_ttrain,
    future_covariates=tcov,
    verbose=True,
)

RuntimeError Traceback (most recent call last)
<ipython-input> in <module>()
15 model.fit( ts_ttrain,
16 future_covariates=tcov,
---> 17 verbose=True)

9 frames
/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
63 with fork_rng():
64 manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65 return decorated(self, *args, **kwargs)
66 return decorator

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in fit(self, series, past_covariates, future_covariates, val_series, val_past_covariates, val_future_covariates, verbose, epochs, max_samples_per_ts, num_loader_workers)
477 logger.info('Train dataset contains {} samples.'.format(len(train_dataset)))
478
--> 479 self.fit_from_dataset(train_dataset, val_dataset, verbose, epochs, num_loader_workers)
480
481 @property

/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
63 with fork_rng():
64 manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65 return decorated(self, *args, **kwargs)
66 return decorator

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in fit_from_dataset(self, train_dataset, val_dataset, verbose, epochs, num_loader_workers)
591
592 # Train model
--> 593 self._train(train_loader, val_loader, tb_writer, verbose, train_num_epochs)
594
595 # Close tensorboard writer

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in _train(self, train_loader, val_loader, tb_writer, verbose, epochs)
886 self.model.train()
887 train_batch = self._batch_to_device(train_batch)
--> 888 output = self._produce_train_output(train_batch[:-1])
889 target = train_batch[-1] # By convention target is always the last element returned by datasets
890 loss = self._compute_loss(output, target)

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/tft_model.py in _produce_train_output(self, input_batch)
800
801 def _produce_train_output(self, input_batch: Tuple):
--> 802 return self.model(input_batch)
803
804 def predict(self, n, *args, **kwargs):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/tft_model.py in forward(self, x)
447
448 # run local lstm encoder
--> 449 encoder_out, (hidden, cell) = self.lstm_encoder(input=embeddings_varying_encoder, hx=(input_hidden, input_cell))
450
451 # run local lstm decoder

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
690 if batch_sizes is None:
691 result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
--> 692 self.dropout, self.training, self.bidirectional, self.batch_first)
693 else:
694 result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,

RuntimeError: rnn: hx is not contiguous
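
For reference, below is a self-contained variant of the snippet above that can be run end to end. The capitalized constants (INLEN, N_FC, HIDDEN, ...) are not defined in the report, so the values and the toy series here are illustrative placeholders rather than the reporter's actual setup; per this report, such a configuration fails on a GPU runtime when lstm_layers > 1 but trains fine on CPU.

from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression
from darts.utils.timeseries_generation import linear_timeseries

# Illustrative placeholder data and hyperparameters -- not the reporter's actual values.
ts_ttrain = linear_timeseries(length=400)
tcov = linear_timeseries(length=436)  # future covariates covering (and extending past) the target

model = TFTModel(
    input_chunk_length=24,
    output_chunk_length=12,
    hidden_size=16,
    lstm_layers=2,          # the failure only shows up with lstm_layers > 1 on GPU (see comment below)
    num_attention_heads=4,
    dropout=0.1,
    batch_size=32,
    n_epochs=1,
    likelihood=QuantileRegression(quantiles=[0.01, 0.1, 0.5, 0.9, 0.99]),
    random_state=42,
    force_reset=True,
)

model.fit(ts_ttrain, future_covariates=tcov, verbose=True)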

@ElVictorious added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Jan 4, 2022
@dennisbader
Collaborator

I could reproduce the error on Colab.
It happens with lstm_layers > 1. Can you try with lstm_layers=1?

I assume the non-contiguity comes from expanding the hidden state tensors for the LSTM layers.
I will test whether we can fix this with .contiguous().
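
To illustrate the suspected mechanism, here is a minimal plain-PyTorch sketch (not Darts' actual internal code): expand() produces a non-contiguous view of the hidden state, which the GPU (cuDNN) LSTM path rejects with exactly this error, while .contiguous() materializes a copy that is accepted.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 4)  # (batch, seq, features)

# Expanding a per-layer hidden state across the batch dimension returns a view
# with stride 0, i.e. a non-contiguous tensor:
h0 = torch.zeros(2, 1, 8).expand(2, 3, 8)  # (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 1, 8).expand(2, 3, 8)
print(h0.is_contiguous())  # False

# On a CUDA device, passing such views as hx raises
# "RuntimeError: rnn: hx is not contiguous"; making them contiguous avoids it:
out, (hn, cn) = lstm(x, (h0.contiguous(), c0.contiguous()))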

h3ik0th added a commit to h3ik0th/TFT_darts that referenced this issue Jan 25, 2022
On GPUs (not CPUs), Darts can throw an error if the LSTM layers of the TFT model are set higher than 1. This notebook limits LSTM layers to 1, whereas notebook 2g5 used 2 layers.
The Darts team worked on the issue in January 2022: unit8co/darts#722