diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 066aa3cd0a..b9cd811335 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,6 +1,6 @@ # type: ignore[override] import os -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index f7b9f9963f..43ff3ebb6a 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,6 +1,6 @@ # type: ignore[override] import os -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from pytorch_lightning import LightningDataModule