diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py
index 0eef89aca4..49b7f7de38 100644
--- a/darts/models/forecasting/pl_forecasting_module.py
+++ b/darts/models/forecasting/pl_forecasting_module.py
@@ -21,7 +21,7 @@

 # Check whether we are running pytorch-lightning >= 1.6.0 or not:
 tokens = pl.__version__.split(".")
-pl_160_or_above = int(tokens[0]) >= 1 and int(tokens[1]) >= 6
+pl_160_or_above = int(tokens[0]) > 1 or int(tokens[0]) == 1 and int(tokens[1]) >= 6


 class PLForecastingModule(pl.LightningModule, ABC):
diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index b361d4f998..8528a87f72 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -20,6 +20,7 @@
 import datetime
 import inspect
 import os
+import re
 import shutil
 import sys
 from abc import ABC, abstractmethod
@@ -85,6 +86,10 @@

 logger = get_logger(__name__)

+# Check whether we are running pytorch-lightning >= 2.0.0 or not:
+tokens = pl.__version__.split(".")
+pl_200_or_above = int(tokens[0]) >= 2
+

 def _get_checkpoint_folder(work_dir, model_name):
     return os.path.join(work_dir, model_name, CHECKPOINTS_FOLDER)
@@ -427,25 +432,49 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> None:

         dtype = self.train_sample[0].dtype
         if np.issubdtype(dtype, np.float32):
             logger.info("Time series values are 32-bits; casting model to float32.")
-            precision = 32
+            precision = "32" if not pl_200_or_above else "32-true"
         elif np.issubdtype(dtype, np.float64):
             logger.info("Time series values are 64-bits; casting model to float64.")
-            precision = 64
+            precision = "64" if not pl_200_or_above else "64-true"
+        else:
+            raise_log(
+                ValueError(
+                    f"Invalid time series data type `{dtype}`. Cast your data to `np.float32` "
+                    f"or `np.float64`, e.g. with `TimeSeries.astype(np.float32)`."
+                ),
+                logger,
+            )
+        precision_int = int(re.findall(r"\d+", str(precision))[0])

         precision_user = (
             self.trainer_params.get("precision", None)
             if trainer is None
             else trainer.precision
         )
+        if precision_user is not None:
+            # currently, we only support float 64 and 32
+            valid_precisions = (
+                ["64", "32"] if not pl_200_or_above else ["64-true", "32-true"]
+            )
+            if str(precision_user) not in valid_precisions:
+                raise_log(
+                    ValueError(
+                        f"Invalid user-defined trainer_kwarg `precision={precision_user}`. "
+                        f"Use one of ({valid_precisions})"
+                    ),
+                    logger,
+                )
+            precision_user_int = int(re.findall(r"\d+", str(precision_user))[0])
+        else:
+            precision_user_int = None
         raise_if(
-            precision_user is not None and int(precision_user) != precision,
-            f"User-defined trainer_kwarg `precision={precision_user}` does not match dtype: `{dtype}` of the "
+            precision_user is not None and precision_user_int != precision_int,
+            f"User-defined trainer_kwarg `precision='{precision_user}'` does not match dtype: `{dtype}` of the "
             f"underlying TimeSeries. Set `precision` to `{precision}` or cast your data to `{precision_user}"
-            f"` with `TimeSeries.astype(np.float{precision_user})`.",
+            f"` with `TimeSeries.astype(np.float{precision_user_int})`.",
             logger,
         )

-        self.trainer_params["precision"] = precision

         # we need to save the initialized TorchForecastingModel as PyTorch-Lightning only saves module checkpoints
diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py
index 82af108536..08d7a49712 100644
--- a/darts/tests/models/forecasting/test_probabilistic_models.py
+++ b/darts/tests/models/forecasting/test_probabilistic_models.py
@@ -127,7 +127,7 @@
         {
             "input_chunk_length": 10,
             "output_chunk_length": 5,
-            "n_epochs": 5,
+            "n_epochs": 10,
             "random_state": 0,
             "likelihood": GaussianLikelihood(),
         },
diff --git a/darts/tests/models/forecasting/test_ptl_trainer.py b/darts/tests/models/forecasting/test_ptl_trainer.py
index b714ae3207..d4cb082178 100644
--- a/darts/tests/models/forecasting/test_ptl_trainer.py
+++ b/darts/tests/models/forecasting/test_ptl_trainer.py
@@ -99,10 +99,9 @@ def test_custom_trainer_setup(self):
         self.assertEqual(trainer.max_epochs, model.epochs_trained)

     def test_builtin_extended_trainer(self):
-        invalid_trainer_kwarg = {"precisionn": 32}
-
-        # error will be raised at training time
+        # wrong precision parameter name
         with self.assertRaises(TypeError):
+            invalid_trainer_kwarg = {"precisionn": "32-true"}
             model = RNNModel(
                 12,
                 "RNN",
@@ -113,20 +112,51 @@ def test_builtin_extended_trainer(self):
                 10,
                 10,
                 random_state=42,
                 pl_trainer_kwargs=invalid_trainer_kwarg,
             )
             model.fit(self.series, epochs=1)

-        valid_trainer_kwargs = {
-            "precision": 32,
-        }
+        # float 16 not supported
+        with self.assertRaises(ValueError):
+            invalid_trainer_kwarg = {"precision": "16-mixed"}
+            model = RNNModel(
+                12,
+                "RNN",
+                10,
+                10,
+                random_state=42,
+                pl_trainer_kwargs=invalid_trainer_kwarg,
+            )
+            model.fit(self.series.astype(np.float16), epochs=1)

-        # valid parameters shouldn't raise error
-        model = RNNModel(
-            12,
-            "RNN",
-            10,
-            10,
-            random_state=42,
-            pl_trainer_kwargs=valid_trainer_kwargs,
-        )
-        model.fit(self.series, epochs=1)
+        # precision value doesn't match `series` dtype
+        with self.assertRaises(ValueError):
+            invalid_trainer_kwarg = {"precision": "64-true"}
+            model = RNNModel(
+                12,
+                "RNN",
+                10,
+                10,
+                random_state=42,
+                pl_trainer_kwargs=invalid_trainer_kwarg,
+            )
+            model.fit(self.series.astype(np.float32), epochs=1)
+
+        for precision, precision_int in zip(["64-true", "32-true"], [64, 32]):
+            valid_trainer_kwargs = {
+                "precision": precision,
+            }
+
+            # valid parameters shouldn't raise error
+            model = RNNModel(
+                12,
+                "RNN",
+                10,
+                10,
+                random_state=42,
+                pl_trainer_kwargs=valid_trainer_kwargs,
+            )
+            ts_dtype = getattr(np, f"float{precision_int}")
+            model.fit(self.series.astype(ts_dtype), epochs=1)
+            preds = model.predict(n=3)
+            assert model.trainer.precision == precision
+            assert preds.dtype == ts_dtype

     def test_custom_callback(self):
         class CounterCallback(pl.callbacks.Callback):
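A minimal usage sketch of the precision handling exercised by the patch, assuming pytorch-lightning >= 2.0 (which expects the "32-true"/"64-true" strings; older versions take plain "32"/"64"). The series construction and hyperparameters below are illustrative only:

    import numpy as np
    from darts.models import RNNModel
    from darts.utils.timeseries_generation import sine_timeseries

    series = sine_timeseries(length=100).astype(np.float32)
    model = RNNModel(
        input_chunk_length=12,
        model="RNN",
        n_epochs=1,
        random_state=42,
        pl_trainer_kwargs={"precision": "32-true"},  # must match the float32 dtype of `series`
    )
    model.fit(series)
    preds = model.predict(n=3)  # predictions keep the float32 dtype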