Skip to content

Commit

Permalink
binary mnist dm fix (#377)
Browse files Browse the repository at this point in the history
* download mnist

* typo

* update docs example

* remove prepare_data in example

* imports

* imports

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
annikabrundyn and Borda authored Nov 19, 2020
1 parent 21f13f3 commit fa5a944
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class BinaryMNISTDataModule(LightningDataModule):
Trainer().fit(model, dm)
"""

name = 'mnist'
name = "binary_mnist"

def __init__(
self,
Expand Down Expand Up @@ -89,8 +89,8 @@ def prepare_data(self):
"""
Saves MNIST files to data_dir
"""
MNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
MNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())
BinaryMNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
BinaryMNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())

def train_dataloader(self, batch_size=32, transforms=None):
"""
Expand Down
6 changes: 3 additions & 3 deletions pl_bolts/models/self_supervised/simclr/simclr_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@

import pytorch_lightning as pl

from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
from pl_bolts.models.self_supervised.simclr.transforms import SimCLRFinetuneTransform
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.transforms.dataset_normalizations import (
cifar10_normalization,
imagenet_normalization,
stl10_normalization,
cifar10_normalization
)


def cli_main(): # pragma: no-cover
from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule, CIFAR10DataModule
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

pl.seed_everything(1234)

Expand Down
8 changes: 3 additions & 5 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
import os
from argparse import ArgumentParser
from typing import Callable, Optional

Expand All @@ -8,17 +7,16 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F

from pytorch_lightning.utilities import AMPType
from torch import nn
from torch.optim.optimizer import Optimizer

from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.transforms.dataset_normalizations import (
stl10_normalization,
cifar10_normalization,
imagenet_normalization
imagenet_normalization,
stl10_normalization,
)


Expand Down Expand Up @@ -359,8 +357,8 @@ def add_model_specific_args(parent_parser):

def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule, ImagenetDataModule

parser = ArgumentParser()

Expand Down

0 comments on commit fa5a944

Please sign in to comment.