Skip to content

Commit

Permalink
cpc stl-10 finetune fix (#173)
Browse files Browse the repository at this point in the history
* enable dp tests

* enable dp tests

* enable dp tests

* enable dp tests

* enable dp tests

* enable dp tests

* enable dp tests

* enable dp tests

* enable dp tests

* enable dp tests
  • Loading branch information
williamFalcon authored Sep 1, 2020
1 parent cf38307 commit ee39028
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
self.batch_size = batch_size
self.seed = seed
self.num_unlabeled_samples = 100000 - unlabeled_val_split
self.num_labeled_samples = 5000 - train_val_split
self.labeled_val_split = 200

@property
def num_classes(self):
Expand Down Expand Up @@ -240,7 +240,7 @@ def train_dataloader_labeled(self):
dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(dataset,
[train_length - self.num_labeled_samples, self.num_labeled_samples],
[train_length - self.labeled_val_split, self.labeled_val_split],
generator=torch.Generator().manual_seed(self.seed))
loader = DataLoader(
dataset_train,
Expand All @@ -259,7 +259,7 @@ def val_dataloader_labeled(self):
transform=transforms)
labeled_length = len(dataset)
_, labeled_val = random_split(dataset,
[labeled_length - self.num_labeled_samples, self.num_labeled_samples],
[labeled_length - self.labeled_val_split, self.labeled_val_split],
generator=torch.Generator().manual_seed(self.seed))

loader = DataLoader(
Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/models/self_supervised/cpc/cpc_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def cli_main(): # pragma: no-cover

elif args.dataset == 'stl10':
dm = STL10DataModule.from_argparse_args(args)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed
dm.train_dataloader = dm.train_dataloader_labeled
dm.val_dataloader = dm.val_dataloader_labeled
dm.train_transforms = CPCTrainTransformsSTL10()
dm.val_transforms = CPCEvalTransformsSTL10()
dm.test_transforms = CPCEvalTransformsSTL10()
Expand Down

0 comments on commit ee39028

Please sign in to comment.