Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Vision DataModules #400

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d5d4dda
Add BaseDataModule
Nov 24, 2020
4876be0
Add pre-commit hooks
Nov 24, 2020
9476033
Refactor cifar10_datamodule
Nov 24, 2020
0651131
Move torchvision warning
Nov 24, 2020
e1f6238
Refactor binary_mnist_datamodule
Nov 24, 2020
5b9e2fd
Refactor fashion_mnist_datamodule
Nov 24, 2020
9ab0640
Fix errors
Nov 24, 2020
b7840bf
Remove VisionDataset type hint so CI base testing does not fail (torc…
Nov 24, 2020
4395be3
Implement Nate's suggestions
Nov 25, 2020
e82b243
Remove train and eval batch size because it brakes a lot of tests
Nov 25, 2020
8e9ae04
Properly add transforms to train and val dataset
Nov 25, 2020
790d6e0
Add num_samples property to cifar10 dm
Nov 25, 2020
7c9b3ce
Add tesats and docs
Nov 25, 2020
d86f432
Fix flake8 and codafactor issue
Nov 25, 2020
e5e69e4
Update changelog
Nov 27, 2020
1d821c2
Fix isort
Dec 9, 2020
a7d6bd4
Add typing
Dec 15, 2020
1d9fa44
Rename to VisionDataModule
Dec 15, 2020
6b9bdcb
Remove transform_lib type annotation
Dec 15, 2020
8ae4907
suggestions
Borda Dec 16, 2020
4de2ea9
Apply suggestions from code review
Borda Dec 16, 2020
716dedf
Apply suggestions from code review
Borda Dec 17, 2020
54114c5
Add flags from #388 to API
Dec 17, 2020
be6fb25
Make tests work
Dec 17, 2020
a55ad63
Merge branch 'master' into feature/395_refactor-vision-dms
chris-clem Dec 17, 2020
25382a2
Move _TORCHVISION_AVAILABLE check
Dec 17, 2020
4902c1c
Update changelog
Dec 17, 2020
2d29a1f
Merge branch 'master' into feature/395_refactor-vision-dms
Dec 17, 2020
26eafdf
Merge remote-tracking branch 'origin/feature/395_refactor-vision-dms'…
Dec 17, 2020
875f13a
Fix CI base testing
Dec 17, 2020
aac84ab
Fix CI base testing
Dec 17, 2020
1d5f0b5
Merge remote-tracking branch 'origin/feature/395_refactor-vision-dms'…
Dec 17, 2020
4cf20b6
Apply suggestions from code review
akihironitta Dec 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add num_samples property to cifar10 dm
  • Loading branch information
Christoph Clement authored and Borda committed Dec 16, 2020
commit 790d6e0919728a34bf33fb7792b24402f831cc2a
10 changes: 5 additions & 5 deletions pl_bolts/datamodules/base_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,18 @@ def _split_dataset(self, dataset, train=True):

if train:
return dataset_train
else:
return dataset_val
return dataset_val

def _get_splits(self, len_dataset):
if isinstance(self.val_split, int):
train_len = len_dataset - self.val_split
return [train_len, self.val_split]

splits = [train_len, self.val_split]
elif isinstance(self.val_split, float):
val_len = int(self.val_split * len_dataset)
train_len = len_dataset - val_len
return [train_len, val_len]
splits = [train_len, val_len]

return splits

def default_transforms(self):
return transform_lib.ToTensor()
Expand Down
12 changes: 11 additions & 1 deletion pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,17 @@ def __init__(
*args,
**kwargs,
)
self.num_samples = 60000 - val_split

@property
def num_samples(self):
len_dataset = 60000
if isinstance(self.val_split, int):
train_len = len_dataset - self.val_split
elif isinstance(self.val_split, float):
val_len = int(self.val_split * len_dataset)
train_len = len_dataset - val_len

return train_len

@property
def num_classes(self):
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_classic_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_logistic_regression_model(tmpdir, datadir):

model = LogisticRegression(input_dim=28 * 28, num_classes=10, learning_rate=0.001)
model.prepare_data = dm.prepare_data
model.setup = dm.setup
model.train_dataloader = dm.train_dataloader
model.val_dataloader = dm.val_dataloader
model.test_dataloader = dm.test_dataloader
Expand Down