From cb724d15e84b718543580b35f24443419354331c Mon Sep 17 00:00:00 2001 From: FourierMourier <91980559+FourierMourier@users.noreply.github.com> Date: Sat, 13 Jan 2024 14:58:42 +0300 Subject: [PATCH] Fix: removed input re-normalization by rin inside `io_processor` (#2160) * prevented input re-normalization by rin using .clone() inside `io_processor` * Update CHANGELOG.md --------- Co-authored-by: Dennis Bader --- CHANGELOG.md | 2 +- darts/models/forecasting/pl_forecasting_module.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95f52f1a76..45fa2eb617 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** **Fixed** - +- Fixed a bug when using a `TorchForecastingModel` with `use_reversible_instance_norm=True` and predicting with `n > output_chunk_length`. The input normalized multiple times. [#2160](https://github.com/unit8co/darts/pull/2160) by [FourierMourier](https://github.com/FourierMourier). ### For developers of the library: ## [0.27.1](https://github.com/unit8co/darts/tree/0.27.1) (2023-12-10) diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 7ade7eaac9..ab98ee59c2 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -50,7 +50,8 @@ def forward_wrapper(self, *args, **kwargs): # x is input batch tuple which by definition has the past features in the first element starting with the # first n target features - x: Tuple = args[0][0] + # assuming `args[0][0]` is torch.Tensor we could clone it to prevent target re-normalization + x: Tuple = args[0][0].clone() # apply reversible instance normalization x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets]) # run the forward pass