Skip to content

Commit

Permalink
Add transforms to dataloader test
Browse files Browse the repository at this point in the history
  • Loading branch information
salomaestro committed Feb 4, 2025
1 parent 3f2e4e2 commit d6999aa
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tests/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def test_uspsdataset0_6():

import h5py
import numpy as np
from torchvision import transforms

# Create a temporary directory (deleted after the test)
with TemporaryDirectory() as tempdir:
Expand All @@ -20,7 +21,13 @@ def test_uspsdataset0_6():
f["train/data"] = np.random.rand(10, 16 * 16)
f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])

dataset = USPSDataset0_6(data_path=tempdir, train=True)
trans = transforms.Compose(
[
transforms.Resize((16, 16)), # At least for USPS
transforms.ToTensor(),
]
)
dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
assert len(dataset) == 10
data, target = dataset[0]
assert data.shape == (1, 16, 16)
Expand Down

0 comments on commit d6999aa

Please sign in to comment.