@@ -181,22 +181,23 @@ def __init__(
181
181
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
182
182
Default: ``torch.nn.MSELoss()``.
183
183
likelihood
184
- The likelihood model to be used for probabilistic forecasts.
184
+ One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
185
+ probabilistic forecasts. Default: ``None``.
185
186
optimizer_cls
186
- The PyTorch optimizer class to be used (default : ``torch.optim.Adam``) .
187
+ The PyTorch optimizer class to be used. Default : ``torch.optim.Adam``.
187
188
optimizer_kwargs
188
189
Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}``
189
190
for specifying a learning rate). Otherwise the default values of the selected ``optimizer_cls``
190
- will be used.
191
+ will be used. Default: ``None``.
191
192
lr_scheduler_cls
192
193
Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds
193
- to using a constant learning rate.
194
+ to using a constant learning rate. Default: ``None``.
194
195
lr_scheduler_kwargs
195
- Optionally, some keyword arguments for the PyTorch learning rate scheduler.
196
+ Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
196
197
batch_size
197
- Number of time series (input and output sequences) used in each training pass.
198
+ Number of time series (input and output sequences) used in each training pass. Default: ``32``.
198
199
n_epochs
199
- Number of epochs over which to train the model.
200
+ Number of epochs over which to train the model. Default: ``100``.
200
201
model_name
201
202
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
202
203
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
@@ -205,13 +206,13 @@ def __init__(
205
206
``"2021-06-14_09:53:32_torch_model_run_44607"``.
206
207
work_dir
207
208
Path of the working directory, where to save checkpoints and Tensorboard summaries.
208
- (default : current working directory) .
209
+ Default : current working directory.
209
210
log_tensorboard
210
211
If set, use Tensorboard to log the different parameters. The logs will be located in:
211
- ``"{work_dir}/darts_logs/{model_name}/logs/"``.
212
+ ``"{work_dir}/darts_logs/{model_name}/logs/"``. Default: ``False``.
212
213
nr_epochs_val_period
213
214
Number of epochs to wait before evaluating the validation loss (if a validation
214
- ``TimeSeries`` is passed to the :func:`fit()` method).
215
+ ``TimeSeries`` is passed to the :func:`fit()` method). Default: ``1``.
215
216
torch_device_str
216
217
Optionally, a string indicating the torch device to use. By default, ``torch_device_str`` is ``None``
217
218
which will run on CPU. Set it to ``"cuda"`` to use all available GPUs or ``"cuda:i"`` to only use
@@ -232,21 +233,21 @@ def __init__(
232
233
https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html#select-gpu-devices
233
234
force_reset
234
235
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
235
- be discarded).
236
+ be discarded). Default: ``False``.
236
237
save_checkpoints
237
238
Whether or not to automatically save the untrained model and checkpoints from training.
238
239
To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
239
240
:class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
240
241
:class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
241
- :func:`save_model()` and loaded using :func:`load_model()`.
242
+ :func:`save_model()` and loaded using :func:`load_model()`. Default: ``False``.
242
243
add_encoders
243
244
A large number of past and future covariates can be automatically generated with `add_encoders`.
244
245
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
245
246
will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
246
247
transform the generated covariates. This happens all under one hood and only needs to be specified at
247
248
model creation.
248
249
Read :meth:`SequentialEncoder <darts.utils.data.encoders.SequentialEncoder>` to find out more about
249
- ``add_encoders``. An example showing some of ``add_encoders`` features:
250
+ ``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
250
251
251
252
.. highlight:: python
252
253
.. code-block:: python
@@ -262,14 +263,15 @@ def __init__(
262
263
random_state
263
264
Control the randomness of the weights initialization. Check this
264
265
`link <https://scikit-learn.org/stable/glossary.html#term-random_state>`_ for more details.
266
+ Default: ``None``.
265
267
pl_trainer_kwargs
266
268
By default :class:`TorchForecastingModel` creates a PyTorch Lightning Trainer with several useful presets
267
269
that performs the training, validation and prediction processes. These presets include automatic
268
270
checkpointing, tensorboard logging, setting the torch device and more.
269
271
With ``pl_trainer_kwargs`` you can add additional kwargs to instantiate the PyTorch Lightning trainer
270
272
object. Check the `PL Trainer documentation
271
273
<https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`_ for more information about the
272
- supported kwargs.
274
+ supported kwargs. Default: ``None``.
273
275
With parameter ``"callbacks"`` you can add custom or PyTorch-Lightning built-in callbacks to Darts'
274
276
:class:`TorchForecastingModel`. Below is an example for adding EarlyStopping to the training process.
275
277
The model will stop training early if the validation loss `val_loss` does not improve beyond
@@ -298,7 +300,7 @@ def __init__(
298
300
parameter ``trainer`` in :func:`fit()` and :func:`predict()`.
299
301
show_warnings
300
302
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
301
- your forecasting use case.
303
+ your forecasting use case. Default: ``False``.
302
304
"""
303
305
super ().__init__ (** self ._extract_torch_model_params (** self .model_params ))
304
306
0 commit comments