From de02b381de70cf38686ec583a1f8c4d971448900 Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Wed, 24 Aug 2022 18:49:15 +0100 Subject: [PATCH 1/8] created ArrayDataset --- pl_bolts/datamodules/__init__.py | 3 +- pl_bolts/datamodules/sklearn_datamodule.py | 83 ---------------------- pl_bolts/datasets/__init__.py | 5 +- pl_bolts/datasets/array_dataset.py | 52 ++++++++++++++ pl_bolts/datasets/base_dataset.py | 32 ++++++++- pl_bolts/datasets/utils.py | 16 +++++ pl_bolts/utils/types.py | 6 ++ tests/datasets/test_array_dataset.py | 66 +++++++++++++++++ 8 files changed, 176 insertions(+), 87 deletions(-) create mode 100644 pl_bolts/datasets/array_dataset.py create mode 100644 pl_bolts/utils/types.py create mode 100644 tests/datasets/test_array_dataset.py diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 15515b2562..ab33730152 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -9,7 +9,7 @@ from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule from pl_bolts.datamodules.kitti_datamodule import KittiDataModule from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule -from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset +from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset from pl_bolts.datamodules.sr_datamodule import TVTDataModule from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule from pl_bolts.datamodules.stl10_datamodule import STL10DataModule @@ -31,7 +31,6 @@ "MNISTDataModule", "SklearnDataModule", "SklearnDataset", - "TensorDataset", "TVTDataModule", "SSLImagenetDataModule", "STL10DataModule", diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 56ba224b41..f86370517b 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -2,9 +2,7 @@ from typing import Any, 
Tuple import numpy as np -import torch from pytorch_lightning import LightningDataModule -from torch import Tensor from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _SKLEARN_AVAILABLE @@ -65,50 +63,6 @@ def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: return x, y -@under_review() -class TensorDataset(Dataset): - """Prepare PyTorch tensor dataset for data loaders. - - Example: - >>> from pl_bolts.datamodules import TensorDataset - ... - >>> X = torch.rand(10, 3) - >>> y = torch.rand(10) - >>> dataset = TensorDataset(X, y) - >>> len(dataset) - 10 - """ - - def __init__(self, X: Tensor, y: Tensor, X_transform: Any = None, y_transform: Any = None) -> None: - """ - Args: - X: PyTorch tensor - y: PyTorch tensor - X_transform: Any transform that works with PyTorch tensors - y_transform: Any transform that works with PyTorch tensors - """ - super().__init__() - self.X = X - self.Y = y - self.X_transform = X_transform - self.y_transform = y_transform - - def __len__(self) -> int: - return len(self.X) - - def __getitem__(self, idx) -> Tuple[Tensor, Tensor]: - x = self.X[idx].float() - y = self.Y[idx] - - if self.X_transform: - x = self.X_transform(x) - - if self.y_transform: - y = self.y_transform(y) - - return x, y - - @under_review() class SklearnDataModule(LightningDataModule): """Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as @@ -241,40 +195,3 @@ def test_dataloader(self) -> DataLoader: pin_memory=self.pin_memory, ) return loader - - -# TODO: this seems to be wrong, something missing here, another inherit class? -# class TensorDataModule(SklearnDataModule): -# """ -# Automatically generates the train, validation and test splits for a PyTorch tensor dataset. They are set up as -# dataloaders for convenience. Optionally, you can pass in your own validation and test splits. -# -# Example: -# -# >>> from pl_bolts.datamodules import TensorDataModule -# >>> import torch -# ... 
-# >>> # create dataset -# >>> X = torch.rand(100, 3) -# >>> y = torch.rand(100) -# >>> loaders = TensorDataModule(X, y) -# ... -# >>> # train set -# >>> train_loader = loaders.train_dataloader(batch_size=10) -# >>> len(train_loader.dataset) -# 70 -# >>> len(train_loader) -# 7 -# >>> # validation set -# >>> val_loader = loaders.val_dataloader(batch_size=10) -# >>> len(val_loader.dataset) -# 20 -# >>> len(val_loader) -# 2 -# >>> # test set -# >>> test_loader = loaders.test_dataloader(batch_size=10) -# >>> len(test_loader.dataset) -# 10 -# >>> len(test_loader) -# 1 -# """ diff --git a/pl_bolts/datasets/__init__.py b/pl_bolts/datasets/__init__.py index c5c845c988..218299482d 100644 --- a/pl_bolts/datasets/__init__.py +++ b/pl_bolts/datasets/__init__.py @@ -1,6 +1,7 @@ import urllib -from pl_bolts.datasets.base_dataset import LightDataset +from pl_bolts.datasets.array_dataset import ArrayDataset +from pl_bolts.datasets.base_dataset import DataModel, LightDataset from pl_bolts.datasets.cifar10_dataset import CIFAR10, TrialCIFAR10 from pl_bolts.datasets.concat_dataset import ConcatDataset from pl_bolts.datasets.dummy_dataset import ( @@ -17,6 +18,8 @@ from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin __all__ = [ + "ArrayDataset", + "DataModel", "LightDataset", "CIFAR10", "TrialCIFAR10", diff --git a/pl_bolts/datasets/array_dataset.py b/pl_bolts/datasets/array_dataset.py new file mode 100644 index 0000000000..c4babc2d3a --- /dev/null +++ b/pl_bolts/datasets/array_dataset.py @@ -0,0 +1,52 @@ +from typing import Tuple + +from pytorch_lightning.utilities import exceptions +from torch.utils.data import Dataset + +from pl_bolts.datasets.base_dataset import ARRAYS, DataModel + + +class ArrayDataset(Dataset): + """Dataset wrapping tensors, lists, numpy arrays. + + Any number of ARRAYS can be inputted into the dataset. The ARRAYS are transformed on each `__getitem__`. 
When + transforming, please refrain from changing the shape of ARRAYS in the first dimension. + + Attributes: + data_models: Sequence of data models. + + Raises: + MisconfigurationException: if there is a shape mismatch between arrays in the first dimension. + + Example: + >>> from pl_bolts.datasets import ArrayDataset, DataModel + >>> from pl_bolts.datasets.utils import to_tensor + + >>> features = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3]], transform=to_tensor) + >>> target = DataModel(data=[1, 0, 0], transform=to_tensor) + + >>> ds = ArrayDataset(features, target) + >>> len(ds) + 3 + """ + + def __init__(self, *data_models: DataModel) -> None: + """Initialises class and checks if arrays are the same shape in the first dimension.""" + self.data_models = data_models + + if not self._equal_size(): + raise exceptions.MisconfigurationException("Shape mismatch between arrays in the first dimension") + + def __len__(self) -> int: + return len(self.data_models[0].data) + + def __getitem__(self, idx: int) -> Tuple[ARRAYS, ...]: + return tuple(data_model.process(data_model.data[idx]) for data_model in self.data_models) + + def _equal_size(self) -> bool: + """Checks that the sizes of the data_models are equal in the first dimension. + + Returns: + bool: True if size of data_models are equal in the first dimension. False, if not. 
+ """ + return all(len(data_model.data) == len(self.data_models[0].data) for data_model in self.data_models) diff --git a/pl_bolts/datasets/base_dataset.py b/pl_bolts/datasets/base_dataset.py index f03ae80d2b..2a0ef9986a 100644 --- a/pl_bolts/datasets/base_dataset.py +++ b/pl_bolts/datasets/base_dataset.py @@ -2,13 +2,15 @@ import os import urllib.request from abc import ABC -from typing import Sequence, Tuple +from dataclasses import dataclass +from typing import Callable, Optional, Sequence, Tuple from urllib.error import HTTPError from torch import Tensor from torch.utils.data import Dataset from pl_bolts.utils.stability import under_review +from pl_bolts.utils.types import ARRAYS @under_review() @@ -58,3 +60,31 @@ def _download_from_url(self, base_url: str, data_folder: str, file_name: str): urllib.request.urlretrieve(url, fpath) except HTTPError as err: raise RuntimeError(f"Failed download from {url}") from err + + +@dataclass +class DataModel: + """Data model dataclass. + + Ties together data and callable transforms. + + Attributes: + data: Sequence of indexables. + transform: Callable to transform data. + """ + + data: ARRAYS + transform: Optional[Callable] = None + + def process(self, subset: ARRAYS) -> ARRAYS: + """Transforms a subset of data. + + Args: + subset: Sequence of indexables. + + Returns: + data: Transformed data if transform is not None. 
+ """ + if self.transform is not None: + subset = self.transform(subset) + return subset diff --git a/pl_bolts/datasets/utils.py b/pl_bolts/datasets/utils.py index a53b98fe0d..2a4ef11d3b 100644 --- a/pl_bolts/datasets/utils.py +++ b/pl_bolts/datasets/utils.py @@ -1,3 +1,6 @@ +from typing import List + +import torch from torch.utils.data.dataset import random_split from pl_bolts.datasets.sr_celeba_dataset import SRCelebA @@ -39,3 +42,16 @@ def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str): dataset_test = dataset_cls(scale_factor, root=data_dir, split="test", download=True) return (dataset_train, dataset_val, dataset_test) + + +def to_tensor(integers: List[int]) -> torch.Tensor: + """Takes a list of integers and returns a tensor. + + This function serves as a use case for the ArrayDataset. + + Args: + integers: List of integers + Returns: + A tensor of the integers + """ + return torch.tensor(integers) diff --git a/pl_bolts/utils/types.py b/pl_bolts/utils/types.py new file mode 100644 index 0000000000..c291a12c22 --- /dev/null +++ b/pl_bolts/utils/types.py @@ -0,0 +1,6 @@ +from typing import List, Union + +import numpy as np +import torch + +ARRAYS = Union[torch.Tensor, np.ndarray, List[Union[float, int]], List[List[Union[float, int]]]] diff --git a/tests/datasets/test_array_dataset.py b/tests/datasets/test_array_dataset.py new file mode 100644 index 0000000000..2c713a0009 --- /dev/null +++ b/tests/datasets/test_array_dataset.py @@ -0,0 +1,66 @@ +import numpy as np +import pytest +import torch +from pytorch_lightning.utilities import exceptions + +from pl_bolts.datasets import ArrayDataset, DataModel +from pl_bolts.datasets.utils import to_tensor + + +def add_one(integers: np.ndarray) -> np.ndarray: + output = [] + for data in integers: + output.append(data + 1) + return np.array(output) + + +def add_one_to_tensor(integers: np.ndarray) -> torch.Tensor: + integers = add_one(integers) + return to_tensor(integers) + + +class TestArrayDataset: + 
@pytest.fixture + def array_dataset(self): + features_1 = DataModel( + data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]], transform=add_one_to_tensor + ) + target_1 = DataModel(data=[1, 0, 0, 1], transform=to_tensor) + + features_2 = DataModel(data=np.array([[2, 1, -5, 1], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]])) + target_2 = DataModel(data=[1, 0, 1, 1]) + return ArrayDataset(features_1, target_1, features_2, target_2) + + def test_len(self, array_dataset): + assert len(array_dataset) == 4 + + def test_getitem_with_transforms(self, array_dataset): + assert len(array_dataset[0]) == 4 + assert len(array_dataset[1]) == 4 + assert len(array_dataset[2]) == 4 + assert len(array_dataset[3]) == 4 + torch.testing.assert_close(array_dataset[0][0], torch.tensor([2, 1, 0, 3])) + torch.testing.assert_close(array_dataset[0][1], torch.tensor(1)) + np.testing.assert_array_equal(array_dataset[0][2], np.array([2, 1, -5, 1])) + assert array_dataset[0][3] == 1 + torch.testing.assert_close(array_dataset[1][0], torch.tensor([2, 1, -1, -0])) + torch.testing.assert_close(array_dataset[1][1], torch.tensor(0)) + np.testing.assert_array_equal(array_dataset[1][2], np.array([1, 0, -2, -1])) + assert array_dataset[1][3] == 0 + torch.testing.assert_close(array_dataset[2][0], torch.tensor([3, 6, 1, 4])) + torch.testing.assert_close(array_dataset[2][1], torch.tensor(0)) + np.testing.assert_array_equal(array_dataset[2][2], np.array([2, 5, 0, 3])) + assert array_dataset[2][3] == 1 + torch.testing.assert_close(array_dataset[3][0], torch.tensor([-6, 2, 3, 3])) + torch.testing.assert_close(array_dataset[3][1], torch.tensor(1)) + np.testing.assert_array_equal(array_dataset[3][2], np.array([-7, 1, 2, 2])) + assert array_dataset[3][3] == 1 + + def test__equal_size_true(self, array_dataset): + assert array_dataset._equal_size() is True + + def test__equal_size_false(self): + features = DataModel(data=[[1, 0, 1]]) + target = DataModel([1, 0, 1]) + with 
pytest.raises(exceptions.MisconfigurationException): + ArrayDataset(features, target) From 3d33b9db1f31eb0491fdd7606bd75b981958c0f8 Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Wed, 24 Aug 2022 22:41:45 +0100 Subject: [PATCH 2/8] added tests for DataModel --- tests/datasets/test_base_dataset.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/datasets/test_base_dataset.py diff --git a/tests/datasets/test_base_dataset.py b/tests/datasets/test_base_dataset.py new file mode 100644 index 0000000000..a641eca833 --- /dev/null +++ b/tests/datasets/test_base_dataset.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest + +from pl_bolts.datasets.base_dataset import DataModel + + +def add_two(integers: np.ndarray) -> np.ndarray: + output = [] + for data in integers: + output.append(data + 2) + return np.array(output) + + +class TestDataModel: + @pytest.fixture + def data(self): + return np.array([[1, 0, 0, 1], [0, 1, 1, 0]]) + + def test_process_transform_is_none(self, data): + dm = DataModel(data=data) + np.testing.assert_array_equal(dm.process(data[0]), data[0]) + np.testing.assert_array_equal(dm.process(data[1]), data[1]) + + def test_process_transform_is_not_none(self, data): + dm = DataModel(data=data, transform=add_two) + np.testing.assert_array_equal(dm.process(data[0]), np.array([3, 2, 2, 3])) + np.testing.assert_array_equal(dm.process(data[1]), np.array([2, 3, 3, 2])) From 88b14961eccf8c3a0aa017de27300fc9cb19486d Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Thu, 25 Aug 2022 22:06:32 +0100 Subject: [PATCH 3/8] Renamed to TArrays --- pl_bolts/datasets/array_dataset.py | 6 +++--- pl_bolts/datasets/base_dataset.py | 10 +++++----- pl_bolts/datasets/utils.py | 14 +++++++------- pl_bolts/utils/types.py | 4 ++-- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pl_bolts/datasets/array_dataset.py b/pl_bolts/datasets/array_dataset.py index c4babc2d3a..6795add38b 100644 --- a/pl_bolts/datasets/array_dataset.py 
+++ b/pl_bolts/datasets/array_dataset.py @@ -3,7 +3,7 @@ from pytorch_lightning.utilities import exceptions from torch.utils.data import Dataset -from pl_bolts.datasets.base_dataset import ARRAYS, DataModel +from pl_bolts.datasets.base_dataset import DataModel, TArrays class ArrayDataset(Dataset): @@ -40,7 +40,7 @@ def __init__(self, *data_models: DataModel) -> None: def __len__(self) -> int: return len(self.data_models[0].data) - def __getitem__(self, idx: int) -> Tuple[ARRAYS, ...]: + def __getitem__(self, idx: int) -> Tuple[TArrays, ...]: return tuple(data_model.process(data_model.data[idx]) for data_model in self.data_models) def _equal_size(self) -> bool: @@ -49,4 +49,4 @@ def _equal_size(self) -> bool: Returns: bool: True if size of data_models are equal in the first dimension. False, if not. """ - return all(len(data_model.data) == len(self.data_models[0].data) for data_model in self.data_models) + return len({len(data_model.data) for data_model in self.data_models}) == 1 diff --git a/pl_bolts/datasets/base_dataset.py b/pl_bolts/datasets/base_dataset.py index 2a0ef9986a..62a7811c8b 100644 --- a/pl_bolts/datasets/base_dataset.py +++ b/pl_bolts/datasets/base_dataset.py @@ -10,7 +10,7 @@ from torch.utils.data import Dataset from pl_bolts.utils.stability import under_review -from pl_bolts.utils.types import ARRAYS +from pl_bolts.utils.types import TArrays @under_review() @@ -70,13 +70,13 @@ class DataModel: Attributes: data: Sequence of indexables. - transform: Callable to transform data. + transform: Callable to transform data. The transform is called on a subset of data. """ - data: ARRAYS - transform: Optional[Callable] = None + data: TArrays + transform: Optional[Callable[[TArrays], TArrays]] = None - def process(self, subset: ARRAYS) -> ARRAYS: + def process(self, subset: TArrays) -> TArrays: """Transforms a subset of data. 
Args: diff --git a/pl_bolts/datasets/utils.py b/pl_bolts/datasets/utils.py index 2a4ef11d3b..3c0214bc21 100644 --- a/pl_bolts/datasets/utils.py +++ b/pl_bolts/datasets/utils.py @@ -1,5 +1,3 @@ -from typing import List - import torch from torch.utils.data.dataset import random_split @@ -7,6 +5,7 @@ from pl_bolts.datasets.sr_mnist_dataset import SRMNIST from pl_bolts.datasets.sr_stl10_dataset import SRSTL10 from pl_bolts.utils.stability import under_review +from pl_bolts.utils.types import TArrays @under_review() @@ -44,14 +43,15 @@ def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str): return (dataset_train, dataset_val, dataset_test) -def to_tensor(integers: List[int]) -> torch.Tensor: - """Takes a list of integers and returns a tensor. +def to_tensor(arrays: TArrays) -> torch.Tensor: + """Takes a sequence of type `TArrays` and returns a tensor. This function serves as a use case for the ArrayDataset. Args: - integers: List of integers + arrays: Sequence of type `TArrays` + Returns: - A tensor of the integers + Tensor of the integers """ - return torch.tensor(integers) + return torch.tensor(arrays) diff --git a/pl_bolts/utils/types.py b/pl_bolts/utils/types.py index c291a12c22..ba46e17a07 100644 --- a/pl_bolts/utils/types.py +++ b/pl_bolts/utils/types.py @@ -1,6 +1,6 @@ -from typing import List, Union +from typing import Sequence, Union import numpy as np import torch -ARRAYS = Union[torch.Tensor, np.ndarray, List[Union[float, int]], List[List[Union[float, int]]]] +TArrays = Union[torch.Tensor, np.ndarray, Sequence[Union[float, int]], Sequence["TArrays"]] From 46ebb251c6c5ea5c04c224b871ddec56fea74da3 Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Thu, 25 Aug 2022 22:06:57 +0100 Subject: [PATCH 4/8] removed auxiliary functions --- tests/datasets/test_array_dataset.py | 24 +++++------------------- tests/datasets/test_base_dataset.py | 15 +++++---------- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git 
a/tests/datasets/test_array_dataset.py b/tests/datasets/test_array_dataset.py index 2c713a0009..9f1b0dba26 100644 --- a/tests/datasets/test_array_dataset.py +++ b/tests/datasets/test_array_dataset.py @@ -7,24 +7,10 @@ from pl_bolts.datasets.utils import to_tensor -def add_one(integers: np.ndarray) -> np.ndarray: - output = [] - for data in integers: - output.append(data + 1) - return np.array(output) - - -def add_one_to_tensor(integers: np.ndarray) -> torch.Tensor: - integers = add_one(integers) - return to_tensor(integers) - - class TestArrayDataset: @pytest.fixture def array_dataset(self): - features_1 = DataModel( - data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]], transform=add_one_to_tensor - ) + features_1 = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]], transform=to_tensor) target_1 = DataModel(data=[1, 0, 0, 1], transform=to_tensor) features_2 = DataModel(data=np.array([[2, 1, -5, 1], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]])) @@ -39,19 +25,19 @@ def test_getitem_with_transforms(self, array_dataset): assert len(array_dataset[1]) == 4 assert len(array_dataset[2]) == 4 assert len(array_dataset[3]) == 4 - torch.testing.assert_close(array_dataset[0][0], torch.tensor([2, 1, 0, 3])) + torch.testing.assert_close(array_dataset[0][0], torch.tensor([1, 0, -1, 2])) torch.testing.assert_close(array_dataset[0][1], torch.tensor(1)) np.testing.assert_array_equal(array_dataset[0][2], np.array([2, 1, -5, 1])) assert array_dataset[0][3] == 1 - torch.testing.assert_close(array_dataset[1][0], torch.tensor([2, 1, -1, -0])) + torch.testing.assert_close(array_dataset[1][0], torch.tensor([1, 0, -2, -1])) torch.testing.assert_close(array_dataset[1][1], torch.tensor(0)) np.testing.assert_array_equal(array_dataset[1][2], np.array([1, 0, -2, -1])) assert array_dataset[1][3] == 0 - torch.testing.assert_close(array_dataset[2][0], torch.tensor([3, 6, 1, 4])) + torch.testing.assert_close(array_dataset[2][0], torch.tensor([2, 5, 0, 3])) 
torch.testing.assert_close(array_dataset[2][1], torch.tensor(0)) np.testing.assert_array_equal(array_dataset[2][2], np.array([2, 5, 0, 3])) assert array_dataset[2][3] == 1 - torch.testing.assert_close(array_dataset[3][0], torch.tensor([-6, 2, 3, 3])) + torch.testing.assert_close(array_dataset[3][0], torch.tensor([-7, 1, 2, 2])) torch.testing.assert_close(array_dataset[3][1], torch.tensor(1)) np.testing.assert_array_equal(array_dataset[3][2], np.array([-7, 1, 2, 2])) assert array_dataset[3][3] == 1 diff --git a/tests/datasets/test_base_dataset.py b/tests/datasets/test_base_dataset.py index a641eca833..d794ce8035 100644 --- a/tests/datasets/test_base_dataset.py +++ b/tests/datasets/test_base_dataset.py @@ -1,14 +1,9 @@ import numpy as np import pytest +import torch from pl_bolts.datasets.base_dataset import DataModel - - -def add_two(integers: np.ndarray) -> np.ndarray: - output = [] - for data in integers: - output.append(data + 2) - return np.array(output) +from pl_bolts.datasets.utils import to_tensor class TestDataModel: @@ -22,6 +17,6 @@ def test_process_transform_is_none(self, data): np.testing.assert_array_equal(dm.process(data[1]), data[1]) def test_process_transform_is_not_none(self, data): - dm = DataModel(data=data, transform=add_two) - np.testing.assert_array_equal(dm.process(data[0]), np.array([3, 2, 2, 3])) - np.testing.assert_array_equal(dm.process(data[1]), np.array([2, 3, 3, 2])) + dm = DataModel(data=data, transform=to_tensor) + torch.testing.assert_close(dm.process(data[0]), torch.tensor([1, 0, 0, 1])) + torch.testing.assert_close(dm.process(data[1]), torch.tensor([0, 1, 1, 0])) From b0ebc0a61802093b89407a1126632de6ab296fca Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Thu, 25 Aug 2022 22:07:12 +0100 Subject: [PATCH 5/8] added to_tensor test --- tests/datasets/test_utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 tests/datasets/test_utils.py diff --git a/tests/datasets/test_utils.py 
b/tests/datasets/test_utils.py new file mode 100644 index 0000000000..a517afe668 --- /dev/null +++ b/tests/datasets/test_utils.py @@ -0,0 +1,22 @@ +import numpy as np +import torch.testing + +from pl_bolts.datasets.utils import to_tensor + + +class TestToTensor: + def test_to_tensor_list(self): + _list = [1, 2, 3] + torch.testing.assert_close(to_tensor(_list), torch.tensor(_list)) + + def test_to_tensor_array(self): + _array = np.array([1, 2, 3]) + torch.testing.assert_close(to_tensor(_array), torch.tensor(_array)) + + def test_to_tensor_sequence(self): + _sequence = [1.0, 2.0, 3.0] + torch.testing.assert_close(to_tensor(_sequence), torch.tensor(_sequence)) + + def test_to_tensor_sequence_(self): + _sequence = [[1.0, 2.0, 3.0]] + torch.testing.assert_close(to_tensor(_sequence), torch.tensor(_sequence)) From bbee971836829ade39b09e08fd9fb3fbc724ea2e Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Thu, 25 Aug 2022 22:14:36 +0100 Subject: [PATCH 6/8] removed duplicate test --- tests/datasets/test_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index a517afe668..8f47578544 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -13,10 +13,6 @@ def test_to_tensor_array(self): _array = np.array([1, 2, 3]) torch.testing.assert_close(to_tensor(_array), torch.tensor(_array)) - def test_to_tensor_sequence(self): - _sequence = [1.0, 2.0, 3.0] - torch.testing.assert_close(to_tensor(_sequence), torch.tensor(_sequence)) - def test_to_tensor_sequence_(self): _sequence = [[1.0, 2.0, 3.0]] torch.testing.assert_close(to_tensor(_sequence), torch.tensor(_sequence)) From 05837be1405beb916a157d127254cf871ae3975d Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Fri, 26 Aug 2022 10:27:35 +0100 Subject: [PATCH 7/8] added float type --- pl_bolts/datasets/array_dataset.py | 4 ++-- pl_bolts/datasets/base_dataset.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git 
a/pl_bolts/datasets/array_dataset.py b/pl_bolts/datasets/array_dataset.py index 6795add38b..8251d3d9a5 100644 --- a/pl_bolts/datasets/array_dataset.py +++ b/pl_bolts/datasets/array_dataset.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Union from pytorch_lightning.utilities import exceptions from torch.utils.data import Dataset @@ -40,7 +40,7 @@ def __init__(self, *data_models: DataModel) -> None: def __len__(self) -> int: return len(self.data_models[0].data) - def __getitem__(self, idx: int) -> Tuple[TArrays, ...]: + def __getitem__(self, idx: int) -> Tuple[Union[TArrays, float], ...]: return tuple(data_model.process(data_model.data[idx]) for data_model in self.data_models) def _equal_size(self) -> bool: diff --git a/pl_bolts/datasets/base_dataset.py b/pl_bolts/datasets/base_dataset.py index 62a7811c8b..a3ccf50cbb 100644 --- a/pl_bolts/datasets/base_dataset.py +++ b/pl_bolts/datasets/base_dataset.py @@ -3,7 +3,7 @@ import urllib.request from abc import ABC from dataclasses import dataclass -from typing import Callable, Optional, Sequence, Tuple +from typing import Callable, Optional, Sequence, Tuple, Union from urllib.error import HTTPError from torch import Tensor @@ -76,7 +76,7 @@ class DataModel: data: TArrays transform: Optional[Callable[[TArrays], TArrays]] = None - def process(self, subset: TArrays) -> TArrays: + def process(self, subset: Union[TArrays, float]) -> Union[TArrays, float]: """Transforms a subset of data. 
Args: From e5f188844d3634dbab5a5e057c755e6917f92fcb Mon Sep 17 00:00:00 2001 From: Cellan Hall Date: Fri, 26 Aug 2022 10:29:13 +0100 Subject: [PATCH 8/8] added float type --- pl_bolts/utils/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/utils/types.py b/pl_bolts/utils/types.py index ba46e17a07..9f85bb6fad 100644 --- a/pl_bolts/utils/types.py +++ b/pl_bolts/utils/types.py @@ -3,4 +3,4 @@ import numpy as np import torch -TArrays = Union[torch.Tensor, np.ndarray, Sequence[Union[float, int]], Sequence["TArrays"]] +TArrays = Union[torch.Tensor, np.ndarray, Sequence[float], Sequence["TArrays"]] # type: ignore