
Commit 8b2df75

Fix/ptl version 200 (#1651)
1 parent 95f4d4f commit 8b2df75

4 files changed: +83 -24 lines changed

darts/models/forecasting/pl_forecasting_module.py (+1 -1)

@@ -21,7 +21,7 @@
 
 # Check whether we are running pytorch-lightning >= 1.6.0 or not:
 tokens = pl.__version__.split(".")
-pl_160_or_above = int(tokens[0]) >= 1 and int(tokens[1]) >= 6
+pl_160_or_above = int(tokens[0]) > 1 or int(tokens[0]) == 1 and int(tokens[1]) >= 6
 
 
 class PLForecastingModule(pl.LightningModule, ABC):
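
The old expression required both major >= 1 and minor >= 6, so it wrongly evaluated to False for 2.0.0 (major 2, minor 0). A minimal sketch of the two checks side by side (the helper functions are illustrative, not part of darts):

# Illustrative sketch: why the old check breaks on pytorch-lightning 2.x.
def old_check(version: str) -> bool:
    tokens = version.split(".")
    return int(tokens[0]) >= 1 and int(tokens[1]) >= 6  # requires minor >= 6 even for 2.x

def new_check(version: str) -> bool:
    tokens = version.split(".")
    # `and` binds tighter than `or`: major > 1, or (major == 1 and minor >= 6)
    return int(tokens[0]) > 1 or int(tokens[0]) == 1 and int(tokens[1]) >= 6

for v in ["1.5.0", "1.6.0", "2.0.0"]:
    print(v, old_check(v), new_check(v))
# 1.5.0 False False
# 1.6.0 True True
# 2.0.0 False True   <- the old check misclassified PL 2.0.0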

darts/models/forecasting/torch_forecasting_model.py (+35 -6)

@@ -20,6 +20,7 @@
 import datetime
 import inspect
 import os
+import re
 import shutil
 import sys
 from abc import ABC, abstractmethod
@@ -85,6 +86,10 @@
 
 logger = get_logger(__name__)
 
+# Check whether we are running pytorch-lightning >= 2.0.0 or not:
+tokens = pl.__version__.split(".")
+pl_200_or_above = int(tokens[0]) >= 2
+
 
 def _get_checkpoint_folder(work_dir, model_name):
     return os.path.join(work_dir, model_name, CHECKPOINTS_FOLDER)
@@ -427,25 +432,49 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> None:
         dtype = self.train_sample[0].dtype
         if np.issubdtype(dtype, np.float32):
             logger.info("Time series values are 32-bits; casting model to float32.")
-            precision = 32
+            precision = "32" if not pl_200_or_above else "32-true"
         elif np.issubdtype(dtype, np.float64):
             logger.info("Time series values are 64-bits; casting model to float64.")
-            precision = 64
+            precision = "64" if not pl_200_or_above else "64-true"
+        else:
+            raise_log(
+                ValueError(
+                    f"Invalid time series data type `{dtype}`. Cast your data to `np.float32` "
+                    f"or `np.float64`, e.g. with `TimeSeries.astype(np.float32)`."
+                ),
+                logger,
+            )
+        precision_int = int(re.findall(r"\d+", str(precision))[0])
 
         precision_user = (
             self.trainer_params.get("precision", None)
             if trainer is None
             else trainer.precision
         )
+        if precision_user is not None:
+            # currently, we only support float 64 and 32
+            valid_precisions = (
+                ["64", "32"] if not pl_200_or_above else ["64-true", "32-true"]
+            )
+            if str(precision_user) not in valid_precisions:
+                raise_log(
+                    ValueError(
+                        f"Invalid user-defined trainer_kwarg `precision={precision_user}`. "
+                        f"Use one of ({valid_precisions})"
+                    ),
+                    logger,
+                )
+            precision_user_int = int(re.findall(r"\d+", str(precision_user))[0])
+        else:
+            precision_user_int = None
 
         raise_if(
-            precision_user is not None and int(precision_user) != precision,
-            f"User-defined trainer_kwarg `precision={precision_user}` does not match dtype: `{dtype}` of the "
+            precision_user is not None and precision_user_int != precision_int,
+            f"User-defined trainer_kwarg `precision='{precision_user}'` does not match dtype: `{dtype}` of the "
             f"underlying TimeSeries. Set `precision` to `{precision}` or cast your data to `{precision_user}"
-            f"` with `TimeSeries.astype(np.float{precision_user})`.",
+            f"` with `TimeSeries.astype(np.float{precision_user_int})`.",
             logger,
         )
-
         self.trainer_params["precision"] = precision
 
         # we need to save the initialized TorchForecastingModel as PyTorch-Lightning only saves module checkpoints
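
PL 2.0 expects string precision flags ("32-true", "64-true") where earlier versions accepted plain 32/64, so the new code compares the user's flag and the dtype-derived flag as integers extracted with a regex. A small sketch of that normalization in isolation, mirroring the `re.findall(r"\d+", ...)` pattern above:

# Sketch: normalizing old- and new-style precision flags to integers.
import re

for p in [32, "64", "32-true", "64-true"]:
    p_int = int(re.findall(r"\d+", str(p))[0])  # first digit run: 32 or 64
    print(repr(p), "->", p_int)
# 32 -> 32, '64' -> 64, '32-true' -> 32, '64-true' -> 64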

darts/tests/models/forecasting/test_probabilistic_models.py (+1 -1)

@@ -127,7 +127,7 @@
 {
     "input_chunk_length": 10,
     "output_chunk_length": 5,
-    "n_epochs": 5,
+    "n_epochs": 10,
     "random_state": 0,
     "likelihood": GaussianLikelihood(),
 },
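
For context, kwargs dicts like this parameterize the probabilistic model tests; a hedged sketch of how such a dict could be consumed (`BlockRNNModel` and the synthetic series are assumptions for illustration, not necessarily the model under test):

# Hypothetical consumer of the kwargs above (BlockRNNModel is an assumption;
# the test suite may apply the dict to other models as well).
import numpy as np
from darts import TimeSeries
from darts.models import BlockRNNModel
from darts.utils.likelihood_models import GaussianLikelihood

kwargs = {
    "input_chunk_length": 10,
    "output_chunk_length": 5,
    "n_epochs": 10,  # raised from 5 to 10 in this commit
    "random_state": 0,
    "likelihood": GaussianLikelihood(),
}
series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 100)).astype(np.float32))
model = BlockRNNModel(**kwargs)
model.fit(series)
forecast = model.predict(n=5, num_samples=100)  # sampled probabilistic forecast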

darts/tests/models/forecasting/test_ptl_trainer.py (+46 -16)

@@ -99,10 +99,9 @@ def test_custom_trainer_setup(self):
         self.assertEqual(trainer.max_epochs, model.epochs_trained)
 
     def test_builtin_extended_trainer(self):
-        invalid_trainer_kwarg = {"precisionn": 32}
-
-        # error will be raised at training time
+        # wrong precision parameter name
         with self.assertRaises(TypeError):
+            invalid_trainer_kwarg = {"precisionn": "32-true"}
             model = RNNModel(
                 12,
                 "RNN",
@@ -113,20 +112,51 @@ def test_builtin_extended_trainer(self):
             )
             model.fit(self.series, epochs=1)
 
-        valid_trainer_kwargs = {
-            "precision": 32,
-        }
+        # float16 not supported
+        with self.assertRaises(ValueError):
+            invalid_trainer_kwarg = {"precision": "16-mixed"}
+            model = RNNModel(
+                12,
+                "RNN",
+                10,
+                10,
+                random_state=42,
+                pl_trainer_kwargs=invalid_trainer_kwarg,
+            )
+            model.fit(self.series.astype(np.float16), epochs=1)
 
-        # valid parameters shouldn't raise error
-        model = RNNModel(
-            12,
-            "RNN",
-            10,
-            10,
-            random_state=42,
-            pl_trainer_kwargs=valid_trainer_kwargs,
-        )
-        model.fit(self.series, epochs=1)
+        # precision value doesn't match `series` dtype
+        with self.assertRaises(ValueError):
+            invalid_trainer_kwarg = {"precision": "64-true"}
+            model = RNNModel(
+                12,
+                "RNN",
+                10,
+                10,
+                random_state=42,
+                pl_trainer_kwargs=invalid_trainer_kwarg,
+            )
+            model.fit(self.series.astype(np.float32), epochs=1)
+
+        for precision, precision_int in zip(["64-true", "32-true"], [64, 32]):
+            valid_trainer_kwargs = {
+                "precision": precision,
+            }
+
+            # valid parameters shouldn't raise error
+            model = RNNModel(
+                12,
+                "RNN",
+                10,
+                10,
+                random_state=42,
+                pl_trainer_kwargs=valid_trainer_kwargs,
+            )
+            ts_dtype = getattr(np, f"float{precision_int}")
+            model.fit(self.series.astype(ts_dtype), epochs=1)
+            preds = model.predict(n=3)
+            assert model.trainer.precision == precision
+            assert preds.dtype == ts_dtype
 
     def test_custom_callback(self):
         class CounterCallback(pl.callbacks.Callback):
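
The same contract applies in user code: the `precision` trainer kwarg must use the PL 2.x string form and agree with the series dtype, or `_init_model` raises. A hedged usage sketch (the synthetic series is an assumption; the model arguments mirror the test above):

# Sketch: passing a PL 2.x precision flag through pl_trainer_kwargs.
import numpy as np
from darts import TimeSeries
from darts.models import RNNModel

series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 200)).astype(np.float64))
model = RNNModel(
    12, "RNN", 10, 10,
    random_state=42,
    pl_trainer_kwargs={"precision": "64-true"},  # use "64" on pytorch-lightning < 2.0
)
model.fit(series, epochs=1)
preds = model.predict(n=3)
assert preds.dtype == np.float64  # forecasts come back in the series dtype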
