Skip to content

Commit bfbeb7c

Browse files
dumjaxdennisbader
andauthored
Fix/warning msg rnn (#1674)
* warning * warning * adapt warning --------- Co-authored-by: Dennis Bader <[email protected]>
1 parent ad944fc commit bfbeb7c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

darts/models/forecasting/rnn_model.py

+6
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,12 @@ def __init__(
380380
"""
381381
# create copy of model parameters
382382
model_kwargs = {key: val for key, val in self.model_params.items()}
383+
384+
if model_kwargs.get("output_chunk_length") is not None:
385+
logger.warning(
386+
"ignoring user defined `output_chunk_length`. RNNModel uses a fixed `output_chunk_length=1`."
387+
)
388+
383389
model_kwargs["output_chunk_length"] = 1
384390

385391
super().__init__(**self._extract_torch_model_params(**model_kwargs))

0 commit comments

Comments
 (0)