Feat/pytorch lightning #702

Merged: 59 commits, Feb 15, 2022
Commits
f7afe81
first draft proposal for pytorch lightning integration
dennisbader Dec 17, 2021
c026bf7
further cleanup
dennisbader Dec 20, 2021
dbef703
removed unnecessary file
dennisbader Dec 21, 2021
706fd35
fix for multiple-TS
dennisbader Dec 22, 2021
3ee8873
moved prediction timeseries generation back to TorchForecastingModel
dennisbader Dec 22, 2021
93b4bf2
support for custom trainer in fit()
dennisbader Dec 22, 2021
8c1b231
removed unused methods from TorchForecastingModel
dennisbader Dec 24, 2021
333cd5f
checkpoint loading now correctly resumes training
dennisbader Jan 11, 2022
ed42108
Merge branch 'master' into feat/pytorch_lightning
dennisbader Jan 12, 2022
facba8c
Merge branch 'master' into feat/pytorch_lightning
dennisbader Jan 15, 2022
df6b8d5
rewrote TorchForecastingModel
dennisbader Jan 15, 2022
9939422
rewrote TFTModel
dennisbader Jan 15, 2022
7d3f24d
rewrote rnn models
dennisbader Jan 15, 2022
e3bb9c2
rewrote nbeats models
dennisbader Jan 15, 2022
4cb7ec5
rewrote tcn model
dennisbader Jan 15, 2022
0ec8245
rewrote transformer model
dennisbader Jan 15, 2022
8408273
removed unused import
dennisbader Jan 15, 2022
2b36b84
resolve failing tests part 1
dennisbader Jan 17, 2022
d4db950
resolve failing tests part 2
dennisbader Jan 18, 2022
d3350d3
adapted the way how model parameters are saved
dennisbader Jan 23, 2022
1c16b72
moved TFTModel predict method into TorchForecastingModel subclass
dennisbader Jan 23, 2022
51b38b1
further simplification of model calls
dennisbader Jan 23, 2022
2905024
integrated ProbabilisticTorchForecastingModel into PLForecastingModule
dennisbader Jan 23, 2022
21ee164
integrated _produce_predict_output into PLForecastingModule
dennisbader Jan 23, 2022
8c3e94c
reintegrated original random state handling
dennisbader Jan 23, 2022
75f7194
removed unused pl random state wrapper function
dennisbader Jan 23, 2022
7102c6d
use OrderedDict for safety in model parameter extraction
dennisbader Jan 23, 2022
f48b61c
made TFM and PLFM parameter extraction generic
dennisbader Jan 29, 2022
7e7eccf
added types for variables in TFM init
dennisbader Jan 29, 2022
93ec255
made predictions deterministic for same fit predict process for non-l…
dennisbader Jan 29, 2022
83e67d6
Merge branch 'master' into feat/pytorch_lightning
dennisbader Jan 29, 2022
0774c6e
fix flake8 issues
dennisbader Jan 29, 2022
e196c97
fix flake8 issues part 2
dennisbader Jan 29, 2022
4c983fb
added pytorch-lightning to torch requirements
dennisbader Feb 2, 2022
57d11cf
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 2, 2022
6277292
fixed loading models with wrong precision
dennisbader Feb 5, 2022
84cbd61
fixed is_probabilistic()
dennisbader Feb 5, 2022
4e0c8b5
fixed failing tests for epoch count tracker
dennisbader Feb 6, 2022
1ddd5ec
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 6, 2022
4c3e268
removed input/output_chunk_length from TorchForecastingModel __init__
dennisbader Feb 9, 2022
2eb240f
unit tests save models to temp dir
dennisbader Feb 9, 2022
2dd7841
added documentation for ModelMeta
dennisbader Feb 9, 2022
984ea88
apply suggestions from PR review part 1
dennisbader Feb 9, 2022
771fbdf
deprecated `torch_device_str`
dennisbader Feb 9, 2022
634524f
updated optimizer docs
dennisbader Feb 9, 2022
d6274a8
updated retrain warning
dennisbader Feb 9, 2022
fbf05d2
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 13, 2022
0ba7f65
made PLMixedCovariatesModule more generic
dennisbader Feb 13, 2022
970db09
added docs
dennisbader Feb 13, 2022
11e7681
added PTL trainer unit tests
dennisbader Feb 13, 2022
7a084ac
update model docs
dennisbader Feb 13, 2022
c8b4bff
fixed broken url in TFM and covariates userguide
dennisbader Feb 13, 2022
f0e4e30
removed input/output chunk length from PL modules
dennisbader Feb 13, 2022
3b7f846
relaxed pytorch-lightning requirement
dennisbader Feb 13, 2022
7290775
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 13, 2022
43803b9
isort
dennisbader Feb 13, 2022
fa5a6db
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 15, 2022
73e3ff4
isort part 2
dennisbader Feb 15, 2022
0d6c105
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 15, 2022
Merge branch 'master' into feat/pytorch_lightning
dennisbader committed Feb 13, 2022
commit 7290775eff5be667c3b37811e489670060baaeb6
8 changes: 4 additions & 4 deletions darts/models/forecasting/block_rnn_model.py
@@ -3,10 +3,10 @@
 -------------------------------
 """

-import torch.nn as nn
-import torch
+from typing import List, Optional, Tuple, Union

-from typing import List, Optional, Union, Tuple
+import torch
+import torch.nn as nn

 from darts.logging import raise_if_not, get_logger
 from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
@@ -76,7 +76,7 @@ def __init__(
         Tensor containing the prediction at the last time step of the sequence.
         """

-        super(_BlockRNNModule, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         # required for all modules -> saves hparams for checkpoints
         self.save_hyperparameters()
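The `self.save_hyperparameters()` call added after every `super().__init__(**kwargs)` is PyTorch Lightning's way of recording a module's constructor arguments so that checkpoints can re-instantiate it later. The idea can be sketched in plain Python; this is a toy re-implementation for illustration only, not Lightning's actual `HyperparametersMixin`:

```python
import inspect


class SavesHparams:
    """Toy sketch of Lightning's save_hyperparameters() idea."""

    def save_hyperparameters(self):
        # Walk one frame up to the calling __init__ and record its named
        # arguments (skipping self) as the module's hyperparameters.
        frame = inspect.currentframe().f_back
        args, _, _, values = inspect.getargvalues(frame)
        self.hparams = {name: values[name] for name in args if name != "self"}


class ToyModule(SavesHparams):
    def __init__(self, input_chunk_length=12, hidden_dim=16):
        self.save_hyperparameters()


m = ToyModule(input_chunk_length=24)
print(m.hparams)  # {'input_chunk_length': 24, 'hidden_dim': 16}
```

With the arguments stored this way, restoring a checkpoint only requires the saved dict plus the class, which is why the PR adds the call to every PL module.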
28 changes: 8 additions & 20 deletions darts/models/forecasting/forecasting_model.py
@@ -12,31 +12,16 @@
 to obtain forecasts for a desired number of time stamps into the future.
 """
 import copy
-from collections import OrderedDict
-from typing import Optional, Tuple, Union, Any, Callable, Dict, List, Sequence
-from itertools import product
 import inspect
 import time
 from abc import ABC, ABCMeta, abstractmethod
+from collections import OrderedDict
+from itertools import product
 from random import sample
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import pandas as pd
-import time
-
-from darts.timeseries import TimeSeries
-from darts.logging import get_logger, raise_log, raise_if_not, raise_if
-from darts.utils import (
-    _build_tqdm_iterator,
-    _with_sanity_checks,
-    _historical_forecasts_general_checks,
-    _parallel_apply,
-)
-from darts.utils.timeseries_generation import (
-    _generate_new_dates,
-    _build_forecast_series,
-)
-import inspect

 from darts import metrics
 from darts.logging import get_logger, raise_if, raise_if_not, raise_log
@@ -47,7 +32,10 @@
     _parallel_apply,
     _with_sanity_checks,
 )
-from darts.utils.timeseries_generation import _generate_index
+from darts.utils.timeseries_generation import (
+    _build_forecast_series,
+    _generate_new_dates,
+)

 logger = get_logger(__name__)
@@ -101,7 +89,7 @@ def __call__(cls, *args, **kwargs):
         cls._model_call = all_params

         # 6) call model
-        return super(ModelMeta, cls).__call__(**all_params)
+        return super().__call__(**all_params)


 class ForecastingModel(ABC, metaclass=ModelMeta):
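The `ModelMeta` metaclass whose `__call__` the diff touches intercepts model construction so every parameter (including defaults) is recorded on the class before the instance is created. A stripped-down sketch of that pattern; the names and details here are illustrative, not the actual darts implementation:

```python
import inspect


class ModelMeta(type):
    """Metaclass that records all constructor parameters on the class."""

    def __call__(cls, *args, **kwargs):
        # Bind positional and keyword args against __init__'s signature
        # so that default values are captured as well.
        sig = inspect.signature(cls.__init__)
        bound = sig.bind(None, *args, **kwargs)  # None stands in for self
        bound.apply_defaults()
        all_params = dict(list(bound.arguments.items())[1:])  # drop self
        cls._model_call = all_params
        # Instantiate the model with the fully resolved parameters.
        return super().__call__(**all_params)


class ToyModel(metaclass=ModelMeta):
    def __init__(self, lr=1e-3, batch_size=32):
        self.lr = lr
        self.batch_size = batch_size


model = ToyModel(0.01)
print(ToyModel._model_call)  # {'lr': 0.01, 'batch_size': 32}
```

Capturing the full parameter set this way is what lets a forecasting model be re-created untrained (e.g. for `historical_forecasts` with `retrain=True`) without the caller re-passing every argument.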
6 changes: 2 additions & 4 deletions darts/models/forecasting/nbeats.py
@@ -3,14 +3,12 @@
 -------
 """

-from typing import NewType, Union, List, Tuple
 from enum import Enum
-from typing import List, NewType, Optional, Tuple, Union
+from typing import List, NewType, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
-from numpy.random import RandomState

 from darts.logging import get_logger, raise_log, raise_if_not
 from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
@@ -358,7 +356,7 @@ def __init__(
         Tensor containing the output of the NBEATS module.

         """
-        super(_NBEATSModule, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         # required for all modules -> saves hparams for checkpoints
         self.save_hyperparameters()
9 changes: 5 additions & 4 deletions darts/models/forecasting/rnn_model.py
@@ -3,14 +3,15 @@
 -------------------------
 """

-import torch.nn as nn
+from typing import Optional, Sequence, Tuple, Union
+
 import torch
-from typing import Sequence, Optional, Union, Tuple
-from darts.timeseries import TimeSeries
+import torch.nn as nn

 from darts.logging import raise_if_not, get_logger
 from darts.models.forecasting.pl_forecasting_module import PLDualCovariatesModule
 from darts.models.forecasting.torch_forecasting_model import DualCovariatesTorchModel
+from darts.timeseries import TimeSeries
 from darts.utils.data import DualCovariatesShiftedDataset, TrainingDataset

 logger = get_logger(__name__)
@@ -70,7 +71,7 @@ def __init__(
         """

         # RNNModule doesn't really need input and output_chunk_length for PLModule
-        super(_RNNModule, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         # required for all modules -> saves hparams for checkpoints
         self.save_hyperparameters()
11 changes: 5 additions & 6 deletions darts/models/forecasting/tcn_model.py
@@ -4,18 +4,17 @@
 """

 import math
-from typing import Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Optional, Sequence, Tuple
-from darts.timeseries import TimeSeries
-from darts.utils.data import PastCovariatesShiftedDataset

-from darts.logging import raise_if_not, get_logger
+from darts.logging import get_logger, raise_if_not
 from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
 from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
+from darts.timeseries import TimeSeries
+from darts.utils.data import PastCovariatesShiftedDataset

 logger = get_logger(__name__)
@@ -182,7 +181,7 @@ def __init__(
         leading up to the first prediction, all in chronological order.
         """

-        super(_TCNModule, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         # required for all modules -> saves hparams for checkpoints
         self.save_hyperparameters()
21 changes: 3 additions & 18 deletions darts/models/forecasting/tft_model.py
@@ -6,42 +6,27 @@
 from typing import Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
-
 import torch
-from numpy.random import RandomState
 from torch import nn
 from torch.nn import LSTM as _LSTM

 from darts import TimeSeries
-
-from darts.logging import get_logger, raise_if_not, raise_if
-from darts.utils.likelihood_models import QuantileRegression, Likelihood
-from darts.utils.data import (
-    TrainingDataset,
-    MixedCovariatesSequentialDataset,
-    MixedCovariatesTrainingDataset,
-    MixedCovariatesInferenceDataset,
-)
+from darts.logging import get_logger, raise_if, raise_if_not
 from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule
+from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel
 from darts.models.forecasting.tft_submodels import (
     _GateAddNorm,
     _GatedResidualNetwork,
     _InterpretableMultiHeadAttention,
     _VariableSelectionNetwork,
 )
-from darts.models.forecasting.torch_forecasting_model import (
-    MixedCovariatesTorchModel,
-    TorchParametricProbabilisticForecastingModel,
-)
-from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel
 from darts.utils.data import (
     MixedCovariatesInferenceDataset,
     MixedCovariatesSequentialDataset,
     MixedCovariatesTrainingDataset,
     TrainingDataset,
 )
 from darts.utils.likelihood_models import Likelihood, QuantileRegression
 from darts.utils.torch import random_method

 logger = get_logger(__name__)
@@ -101,7 +86,7 @@ def __init__(
         all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.
         """

-        super(_TFTModule, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         # required for all modules -> saves hparams for checkpoints
         self.save_hyperparameters()
53 changes: 26 additions & 27 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -17,46 +17,33 @@
 forecasting models.
 """

+import datetime
+import inspect
 import os
 import shutil
-import inspect
-import datetime
-import numpy as np
-from glob import glob
-
 from abc import ABC, abstractmethod
+from glob import glob
 from typing import Optional, Dict, Tuple, Union, Sequence, List

+import numpy as np
+import pytorch_lightning as pl
 import torch
+from pytorch_lightning import loggers as pl_loggers
 from torch import Tensor
 from torch.utils.data import DataLoader

-import pytorch_lightning as pl
-from pytorch_lightning import loggers as pl_loggers
-
-from darts.logging import get_logger, raise_if, raise_if_not, raise_log
-from darts.models.forecasting.forecasting_model import GlobalForecastingModel
-from darts.timeseries import TimeSeries
-from darts.utils.data.encoders import SequentialEncoder
-from darts.utils.likelihood_models import Likelihood
-from darts.utils.torch import random_method
-from darts.models.forecasting.forecasting_model import GlobalForecastingModel
-from darts.models.forecasting.pl_forecasting_module import PLForecastingModule
 from darts.logging import (
-    raise_if_not,
     get_logger,
-    raise_log,
-    raise_if,
     raise_deprecation_warning,
+    raise_if,
+    raise_if_not,
+    raise_log,
     suppress_lightning_warnings,
 )
-from darts.utils.data.training_dataset import (
-    TrainingDataset,
-    PastCovariatesTrainingDataset,
-    FutureCovariatesTrainingDataset,
-    DualCovariatesTrainingDataset,
-    MixedCovariatesTrainingDataset,
-    SplitCovariatesTrainingDataset,
-)
+from darts.models.forecasting.forecasting_model import GlobalForecastingModel
+from darts.models.forecasting.pl_forecasting_module import PLForecastingModule
+from darts.timeseries import TimeSeries
 from darts.utils.data.inference_dataset import (
     DualCovariatesInferenceDataset,
     FutureCovariatesInferenceDataset,
@@ -72,6 +59,17 @@
     PastCovariatesSequentialDataset,
     SplitCovariatesSequentialDataset,
 )
+from darts.utils.data.training_dataset import (
+    DualCovariatesTrainingDataset,
+    FutureCovariatesTrainingDataset,
+    MixedCovariatesTrainingDataset,
+    PastCovariatesTrainingDataset,
+    SplitCovariatesTrainingDataset,
+    TrainingDataset,
+)
+from darts.utils.data.encoders import SequentialEncoder
+from darts.utils.likelihood_models import Likelihood
+from darts.utils.torch import random_method

 DEFAULT_DARTS_FOLDER = "darts_logs"
 CHECKPOINTS_FOLDER = "checkpoints"
@@ -401,7 +399,8 @@ def _extract_torch_model_params(**kwargs):
         get_params.remove("self")
         return {kwarg: kwargs.get(kwarg) for kwarg in get_params if kwarg in kwargs}

-    def _extract_pl_module_params(self, **kwargs):
+    @staticmethod
+    def _extract_pl_module_params(**kwargs):
         """Extract params from model creation to set up PLForecastingModule (the actual torch.nn.Module)"""
         get_params = list(
             inspect.signature(PLForecastingModule.__init__).parameters.keys()
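Both extraction helpers (`_extract_torch_model_params` and `_extract_pl_module_params`) rely on the same trick: inspect the target `__init__` signature and keep only the kwargs it accepts, so one flat set of constructor arguments can be split between the `TorchForecastingModel` wrapper and the `PLForecastingModule`. A self-contained sketch; `TargetModule` and `extract_params_for` are placeholder names for illustration:

```python
import inspect


class TargetModule:
    """Stands in for PLForecastingModule: accepts only these kwargs."""

    def __init__(self, loss_fn=None, optimizer_cls=None):
        self.loss_fn = loss_fn
        self.optimizer_cls = optimizer_cls


def extract_params_for(target_cls, **kwargs):
    # Keep only the kwargs that appear in target_cls.__init__'s signature.
    accepted = list(inspect.signature(target_cls.__init__).parameters.keys())
    accepted.remove("self")
    return {k: kwargs[k] for k in accepted if k in kwargs}


mixed_kwargs = {"loss_fn": "mse", "batch_size": 32, "optimizer_cls": "adam"}
print(extract_params_for(TargetModule, **mixed_kwargs))
# {'loss_fn': 'mse', 'optimizer_cls': 'adam'}
```

Making the helper a `@staticmethod`, as the diff does, is consistent with this design: the extraction depends only on the target signature and the kwargs, never on instance state.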
5 changes: 2 additions & 3 deletions darts/models/forecasting/transformer_model.py
@@ -4,11 +4,10 @@
 """

 import math
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
-from typing import Optional, Tuple

 from darts.logging import get_logger
 from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
@@ -123,7 +122,7 @@ def __init__(
         Tensor containing the prediction at the last time step of the sequence.
         """

-        super(_TransformerModule, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         # required for all modules -> saves hparams for checkpoints
         self.save_hyperparameters()
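The recurring `super(_XModule, self).__init__(**kwargs)` to `super().__init__(**kwargs)` change across these modules is a pure modernization: since Python 3, the zero-argument form resolves the class and instance from the enclosing scope and behaves identically, while being robust to class renames. For example:

```python
class Base:
    def __init__(self, **kwargs):
        self.config = kwargs


class OldStyle(Base):
    def __init__(self, **kwargs):
        # Python 2 style: class and instance named explicitly.
        super(OldStyle, self).__init__(**kwargs)


class NewStyle(Base):
    def __init__(self, **kwargs):
        # Python 3 style: equivalent, with no repeated class name.
        super().__init__(**kwargs)


print(OldStyle(a=1).config == NewStyle(a=1).config)  # True
```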
@@ -1,6 +1,7 @@
-import numpy as np
 from copy import deepcopy
-from unittest.mock import patch, ANY
+from unittest.mock import ANY, patch
+
+import numpy as np

 from darts.dataprocessing.transformers import Scaler
 from darts.datasets import AirPassengersDataset
7 changes: 4 additions & 3 deletions darts/tests/models/forecasting/test_ptl_trainer.py
@@ -1,10 +1,11 @@
-import numpy as np
-import tempfile
 import shutil
+import tempfile

+import numpy as np
 import pytorch_lightning as pl

-from darts.tests.base_test_class import DartsBaseTestClass
 from darts.logging import get_logger
+from darts.tests.base_test_class import DartsBaseTestClass
 from darts.utils.timeseries_generation import linear_timeseries

 logger = get_logger(__name__)
2 changes: 1 addition & 1 deletion requirements/torch.txt
@@ -1,2 +1,2 @@
-torch>=1.8.0
 pytorch-lightning>=1.5.0
+torch>=1.8.0