Skip to content

Commit

Permalink
[RFC] Add missing names to pl_bolts/datasets/__init__.py (#493)
Browse files Browse the repository at this point in the history
* Add missing names to __init__.py

* Apply isort

* Apply new importing to datamodules

* Apply new importing

* Add missing names to __init__.py

* Apply isort

* Apply new importing to datamodules

* Apply new importing

* num_workers

* isort

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
akihironitta and Borda authored Jan 19, 2021
1 parent 6b2136b commit f8affe9
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.datasets import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional, Sequence, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
from pl_bolts.datasets import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down Expand Up @@ -136,7 +136,7 @@ def prepare_data(self):
To generate the meta.bin do the following:
from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets import UnlabeledImagenet
path = '/path/to/folder/with/ILSVRC2012_devkit_t12.tar.gz/'
UnlabeledImagenet.generate_meta_bins(path)
"""
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.datasets import KittiDataset
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/ssl_imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets.concat_dataset import ConcatDataset
from pl_bolts.datasets import ConcatDataset
from pl_bolts.transforms.dataset_normalizations import stl10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
24 changes: 21 additions & 3 deletions pl_bolts/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
from pl_bolts.datasets.base_dataset import LightDataset
from pl_bolts.datasets.cifar10_dataset import CIFAR10, TrialCIFAR10
from pl_bolts.datasets.concat_dataset import ConcatDataset
from pl_bolts.datasets.dummy_dataset import (
DummyDataset,
DummyDetectionDataset,
RandomDataset,
RandomDictDataset,
RandomDictStringDataset,
)
from pl_bolts.datasets.imagenet_dataset import extract_archive, parse_devkit_archive, UnlabeledImagenet
from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin

__all__ = [
"RandomDictStringDataset",
"RandomDictDataset",
"RandomDataset",
"LightDataset",
"CIFAR10",
"TrialCIFAR10",
"ConcatDataset",
"DummyDataset",
"DummyDetectionDataset",
"RandomDataset",
"RandomDictDataset",
"RandomDictStringDataset",
"extract_archive",
"parse_devkit_archive",
"UnlabeledImagenet",
"KittiDataset",
"BinaryMNIST",
"CIFAR10Mixed",
"SSLDatasetMixin",
]
2 changes: 1 addition & 1 deletion pl_bolts/datasets/cifar10_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import Tensor

from pl_bolts.datasets.base_dataset import LightDataset
from pl_bolts.datasets import LightDataset
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

Expand Down
3 changes: 1 addition & 2 deletions pl_bolts/models/self_supervised/amdim/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from torch.utils.data import random_split

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed
from pl_bolts.datasets import CIFAR10Mixed, UnlabeledImagenet
from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
2 changes: 1 addition & 1 deletion tests/models/self_supervised/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_byol(tmpdir, datadir):
def test_amdim(tmpdir, datadir):
seed_everything()

model = AMDIM(data_dir=datadir, batch_size=2, online_ft=True, encoder='resnet18')
model = AMDIM(data_dir=datadir, batch_size=2, online_ft=True, encoder='resnet18', num_workers=0)
trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)
loss = trainer.progress_bar_dict['loss']
Expand Down

0 comments on commit f8affe9

Please sign in to comment.