Skip to content

Commit

Permalink
Fix: removed input re-normalization by rin inside io_processor (uni…
Browse files Browse the repository at this point in the history
…t8co#2160)

* prevented input re-normalization by rin using .clone() inside `io_processor`

* Update CHANGELOG.md

---------

Co-authored-by: Dennis Bader <[email protected]>
  • Loading branch information
FourierMourier and dennisbader authored Jan 13, 2024
1 parent ea79679 commit cb724d1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**

**Fixed**

- Fixed a bug when using a `TorchForecastingModel` with `use_reversible_instance_norm=True` and predicting with `n > output_chunk_length`. The input normalized multiple times. [#2160](https://github.com/unit8co/darts/pull/2160) by [FourierMourier](https://github.com/FourierMourier).
### For developers of the library:

## [0.27.1](https://github.com/unit8co/darts/tree/0.27.1) (2023-12-10)
Expand Down
3 changes: 2 additions & 1 deletion darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def forward_wrapper(self, *args, **kwargs):

# x is input batch tuple which by definition has the past features in the first element starting with the
# first n target features
x: Tuple = args[0][0]
# assuming `args[0][0]` is torch.Tensor we could clone it to prevent target re-normalization
x: Tuple = args[0][0].clone()
# apply reversible instance normalization
x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets])
# run the forward pass
Expand Down

0 comments on commit cb724d1

Please sign in to comment.