@@ -1743,6 +1743,7 @@ def load_weights_from_checkpoint(
         best: bool = True,
         strict: bool = True,
         load_encoders: bool = True,
+        skip_checks: bool = False,
         **kwargs,
     ):
         """
@@ -1758,6 +1759,9 @@ def load_weights_from_checkpoint(
         For manually saved model, consider using :meth:`load() <TorchForecastingModel.load()>` or
         :meth:`load_weights() <TorchForecastingModel.load_weights()>` instead.

+        Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
+        and perform sanity checks on the model parameters.
+
         Parameters
         ----------
         model_name
@@ -1777,6 +1781,9 @@ def load_weights_from_checkpoint(
         load_encoders
             If set, will load the encoders from the model to enable direct call of fit() or predict().
             Default: ``True``.
+        skip_checks
+            If set, will disable the loading of the encoders and the sanity checks on model parameters
+            (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
         **kwargs
             Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
             different device than the one from which it was saved.
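As a rough usage sketch of the documented behaviour (the model class, hyper-parameter values and checkpoint name below are hypothetical, and a checkpoint is assumed to exist under the default work directory from a previous run trained with `save_checkpoints=True`):

```python
from darts.models import NBEATSModel  # hypothetical model choice for illustration

# a fresh instance whose hyper-parameters must match those of the checkpointed model
model = NBEATSModel(input_chunk_length=24, output_chunk_length=12)

# default: load the encoders and sanity-check the parameters against the darts .pt file
model.load_weights_from_checkpoint(model_name="my_model", best=True)

# weights only: no .pt file required, but the encoders must be disabled explicitly
model.load_weights_from_checkpoint(
    model_name="my_model", best=True, load_encoders=False, skip_checks=True
)
```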
@@ -1790,6 +1797,13 @@ def load_weights_from_checkpoint(
             logger,
         )

+        raise_if(
+            skip_checks and load_encoders,
+            "`skip_checks` and `load_encoders` are mutually exclusive parameters and cannot be both "
+            "set to `True`.",
+            logger,
+        )
+
         # use the name of the model being loaded with the saved weights
         if model_name is None:
             model_name = self.model_name
@@ -1816,39 +1830,6 @@ def load_weights_from_checkpoint(
1816
1830
1817
1831
ckpt_path = os .path .join (checkpoint_dir , file_name )
1818
1832
ckpt = torch .load (ckpt_path , ** kwargs )
1819
- ckpt_hyper_params = ckpt ["hyper_parameters" ]
1820
-
1821
- # verify that the arguments passed to the constructor match those of the checkpoint
1822
- # add_encoders is checked in _load_encoders()
1823
- skipped_params = list (
1824
- inspect .signature (TorchForecastingModel .__init__ ).parameters .keys ()
1825
- ) + [
1826
- "loss_fn" ,
1827
- "torch_metrics" ,
1828
- "optimizer_cls" ,
1829
- "optimizer_kwargs" ,
1830
- "lr_scheduler_cls" ,
1831
- "lr_scheduler_kwargs" ,
1832
- ]
1833
- for param_key , param_value in self .model_params .items ():
1834
- # TODO: there are discrepancies between the param names, for ex num_layer/n_rnn_layers
1835
- if (
1836
- param_key in ckpt_hyper_params .keys ()
1837
- and param_key not in skipped_params
1838
- ):
1839
- # some parameters must be converted
1840
- if isinstance (ckpt_hyper_params [param_key ], list ) and not isinstance (
1841
- param_value , list
1842
- ):
1843
- param_value = [param_value ] * len (ckpt_hyper_params [param_key ])
1844
-
1845
- raise_if (
1846
- param_value != ckpt_hyper_params [param_key ],
1847
- f"The values of the hyper parameter { param_key } should be identical between "
1848
- f"the instantiated model ({ param_value } ) and the loaded checkpoint "
1849
- f"({ ckpt_hyper_params [param_key ]} ). Please adjust the model accordingly." ,
1850
- logger ,
1851
- )
1852
1833
1853
1834
# indicate to the user than checkpoints generated with darts <= 0.23.1 are not supported
1854
1835
raise_if_not (
@@ -1867,17 +1848,32 @@ def load_weights_from_checkpoint(
         ]
         self.train_sample = tuple(mock_train_sample)

-        # updating model attributes before self._init_model() which creates a new ckpt
-        tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
-        with open(tfm_save_file_path, "rb") as tfm_save_file:
-            tfm_save: TorchForecastingModel = torch.load(
-                tfm_save_file, map_location=kwargs.get("map_location", None)
-            )
+        if not skip_checks:
+            # path to the tfm checkpoint (darts model, .pt extension)
+            tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
+            if not os.path.exists(tfm_save_file_path):
+                raise_log(
+                    FileNotFoundError(
+                        f"Could not find {tfm_save_file_path}, necessary to load the encoders "
+                        f"and run sanity checks on the model parameters."
+                    ),
+                    logger,
+                )
+
+            # updating model attributes before self._init_model() which creates a new tfm ckpt
+            with open(tfm_save_file_path, "rb") as tfm_save_file:
+                tfm_save: TorchForecastingModel = torch.load(
+                    tfm_save_file, map_location=kwargs.get("map_location", None)
+                )
+
             # encoders are necessary for direct inference
             self.encoders, self.add_encoders = self._load_encoders(
                 tfm_save, load_encoders
             )

+            # meaningful error message if parameters are incompatible with the ckpt weights
+            self._check_ckpt_parameters(tfm_save)
+
         # instantiate the model without having to call `fit_from_dataset`
         self.model = self._init_model()
         # cast model precision to correct type
@@ -1887,10 +1883,15 @@ def load_weights_from_checkpoint(
         # update the fit_called attribute to allow for direct inference
         self._fit_called = True

-    def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
+    def load_weights(
+        self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs
+    ):
         """
         Loads the weights from a manually saved model (saved with :meth:`save() <TorchForecastingModel.save()>`).

+        Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
+        and perform sanity checks on the model parameters.
+
         Parameters
         ----------
         path
@@ -1899,6 +1900,9 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
         load_encoders
             If set, will load the encoders from the model to enable direct call of fit() or predict().
             Default: ``True``.
+        skip_checks
+            If set, will disable the loading of the encoders and the sanity checks on model parameters
+            (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
         **kwargs
             Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
             different device than the one from which it was saved.
@@ -1916,6 +1920,7 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
         self.load_weights_from_checkpoint(
             file_name=path_ptl_ckpt,
             load_encoders=load_encoders,
+            skip_checks=skip_checks,
             **kwargs,
         )

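For completeness, a minimal sketch of the equivalent flow through `load_weights()` on a manually saved model (the model class, training series and file name are hypothetical):

```python
import numpy as np
from darts import TimeSeries
from darts.models import NBEATSModel  # hypothetical model choice for illustration

series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 200)))

# train and save a model manually
trained = NBEATSModel(input_chunk_length=24, output_chunk_length=12, n_epochs=1)
trained.fit(series)
trained.save("my_model.pt")

# restore the weights into a fresh, identically configured instance
fresh = NBEATSModel(input_chunk_length=24, output_chunk_length=12, n_epochs=1)
fresh.load_weights("my_model.pt")  # default: loads encoders and checks the parameters

# weights only: skip_checks requires disabling the encoders as well
fresh.load_weights("my_model.pt", load_encoders=False, skip_checks=True)
```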
@@ -2058,6 +2063,75 @@ def _load_encoders(

         return new_encoders, new_add_encoders

+    def _check_ckpt_parameters(self, tfm_save):
+        """
+        Check that the positional parameters used to instantiate the new model loading the weights match those
+        of the saved model, to return meaningful messages in case of discrepancies.
+        """
+        # parameters unrelated to the weights shape
+        skipped_params = list(
+            inspect.signature(TorchForecastingModel.__init__).parameters.keys()
+        ) + [
+            "loss_fn",
+            "torch_metrics",
+            "optimizer_cls",
+            "optimizer_kwargs",
+            "lr_scheduler_cls",
+            "lr_scheduler_kwargs",
+        ]
+        # model_params can be missing some kwargs
+        params_to_check = set(tfm_save.model_params.keys()).union(
+            self.model_params.keys()
+        ) - set(skipped_params)
+
+        incorrect_params = []
+        missing_params = []
+        for param_key in params_to_check:
+            # param was not used at loading model creation
+            if param_key not in self.model_params.keys():
+                missing_params.append((param_key, tfm_save.model_params[param_key]))
+            # new param was used at loading model creation
+            elif param_key not in tfm_save.model_params.keys():
+                incorrect_params.append(
+                    (
+                        param_key,
+                        None,
+                        self.model_params[param_key],
+                    )
+                )
+            # param was different at loading model creation
+            elif self.model_params[param_key] != tfm_save.model_params[param_key]:
+                # NOTE: for TFTModel, default is None but converted to `QuantileRegression()`
+                incorrect_params.append(
+                    (
+                        param_key,
+                        tfm_save.model_params[param_key],
+                        self.model_params[param_key],
+                    )
+                )
+
+        # at least one discrepancy was detected
+        if len(missing_params) + len(incorrect_params) > 0:
+            msg = [
+                "The values of the hyper-parameters in the model and loaded checkpoint should be identical."
+            ]
+
+            # warning messages formatted to facilitate copy-pasting
+            if len(missing_params) > 0:
+                msg += ["missing :"]
+                msg += [
+                    f"   - {param}={exp_val}" for (param, exp_val) in missing_params
+                ]
+
+            if len(incorrect_params) > 0:
+                msg += ["incorrect :"]
+                msg += [
+                    f"   - found {param}={cur_val}, should be {param}={exp_val}"
+                    for (param, exp_val, cur_val) in incorrect_params
+                ]
+
+            raise_log(ValueError("\n".join(msg)), logger)
+
     def __getstate__(self):
         # do not pickle the PyTorch LightningModule, and Trainer
         return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE}
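The error raised by `_check_ckpt_parameters` is assembled from the two lists above; the following standalone snippet simply replays that formatting with made-up discrepancies to show the shape of the message (it is not part of the diff):

```python
# hypothetical discrepancies between the instantiated model and the checkpoint
missing_params = [("dropout", 0.1)]                  # (param, expected value)
incorrect_params = [("output_chunk_length", 12, 6)]  # (param, expected, current)

msg = [
    "The values of the hyper-parameters in the model and loaded checkpoint should be identical."
]
if len(missing_params) > 0:
    msg += ["missing :"]
    msg += [f"   - {param}={exp_val}" for (param, exp_val) in missing_params]
if len(incorrect_params) > 0:
    msg += ["incorrect :"]
    msg += [
        f"   - found {param}={cur_val}, should be {param}={exp_val}"
        for (param, exp_val, cur_val) in incorrect_params
    ]
print("\n".join(msg))
# The values of the hyper-parameters in the model and loaded checkpoint should be identical.
# missing :
#    - dropout=0.1
# incorrect :
#    - found output_chunk_length=6, should be output_chunk_length=12
```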