diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index d69cab2..9f58ae4 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -7,6 +7,7 @@ def test_uspsdataset0_6(): import h5py import numpy as np + from torchvision import transforms # Create a temporary directory (deleted after the test) with TemporaryDirectory() as tempdir: @@ -20,7 +21,13 @@ def test_uspsdataset0_6(): f["train/data"] = np.random.rand(10, 16 * 16) f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0]) - dataset = USPSDataset0_6(data_path=tempdir, train=True) + trans = transforms.Compose( + [ + transforms.Resize((16, 16)), # At least for USPS + transforms.ToTensor(), + ] + ) + dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans) assert len(dataset) == 10 data, target = dataset[0] assert data.shape == (1, 16, 16)