-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: otaj <[email protected]>
- Loading branch information
Showing
10 changed files
with
202 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import Tuple, Union | ||
|
||
from pytorch_lightning.utilities import exceptions | ||
from torch.utils.data import Dataset | ||
|
||
from pl_bolts.datasets.base_dataset import DataModel, TArrays | ||
|
||
|
||
class ArrayDataset(Dataset):
    """Dataset wrapping tensors, lists, numpy arrays.

    Any number of ARRAYS can be inputted into the dataset. The ARRAYS are transformed on each `__getitem__`. When
    transforming, please refrain from changing the shape of ARRAYS in the first dimension.

    Attributes:
        data_models: Sequence of data models.

    Raises:
        MisconfigurationException: if no data model is given, or if there is a shape mismatch between arrays in
            the first dimension.

    Example:
        >>> from pl_bolts.datasets import ArrayDataset, DataModel
        >>> from pl_bolts.datasets.utils import to_tensor

        >>> features = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3]], transform=to_tensor)
        >>> target = DataModel(data=[1, 0, 0], transform=to_tensor)

        >>> ds = ArrayDataset(features, target)
        >>> len(ds)
        3
    """

    def __init__(self, *data_models: DataModel) -> None:
        """Initialises class and checks if arrays are the same shape in the first dimension.

        Args:
            *data_models: One or more data models; each must expose ``data`` and ``process``.

        Raises:
            MisconfigurationException: if no data model is given, or if the data models differ in length.
        """
        self.data_models = data_models

        # Without this guard an empty call falls through to the size check below and is reported as a
        # "shape mismatch", which hides the real problem from the caller.
        if not data_models:
            raise exceptions.MisconfigurationException("At least one data model is required")

        if not self._equal_size():
            raise exceptions.MisconfigurationException("Shape mismatch between arrays in the first dimension")

    def __len__(self) -> int:
        """Returns the number of samples, taken from the first data model."""
        # All data models have the same length (enforced in ``__init__``), so the first one is representative.
        return len(self.data_models[0].data)

    def __getitem__(self, idx: int) -> Tuple[Union[TArrays, float], ...]:
        """Returns the processed element at ``idx`` of every data model, in the order they were passed in."""
        return tuple(data_model.process(data_model.data[idx]) for data_model in self.data_models)

    def _equal_size(self) -> bool:
        """Checks the size of the data_models are equal in the first dimension.

        Returns:
            bool: True if size of data_models are equal in the first dimension. False, if not.
        """
        # Collapsing the lengths into a set leaves exactly one element when they all agree.
        return len({len(data_model.data) for data_model in self.data_models}) == 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from typing import Sequence, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
# Recursive alias for the array-like inputs accepted by the datasets: a torch tensor, a numpy array,
# a flat sequence of floats, or a sequence whose items are themselves ``TArrays``.
# NOTE(review): the ``type: ignore`` presumably silences type checkers that reject recursive aliases - confirm.
TArrays = Union[torch.Tensor, np.ndarray, Sequence[float], Sequence["TArrays"]]  # type: ignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
from pytorch_lightning.utilities import exceptions | ||
|
||
from pl_bolts.datasets import ArrayDataset, DataModel | ||
from pl_bolts.datasets.utils import to_tensor | ||
|
||
|
||
class TestArrayDataset:
    """Tests for ``ArrayDataset`` built from a mix of transformed and untransformed data models."""

    @pytest.fixture
    def array_dataset(self):
        """Dataset of four samples: two data models using ``to_tensor`` and two without a transform."""
        tensor_features = DataModel(
            data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]],
            transform=to_tensor,
        )
        tensor_target = DataModel(data=[1, 0, 0, 1], transform=to_tensor)

        raw_features = DataModel(data=np.array([[2, 1, -5, 1], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]]))
        raw_target = DataModel(data=[1, 0, 1, 1])
        return ArrayDataset(tensor_features, tensor_target, raw_features, raw_target)

    def test_len(self, array_dataset):
        """The dataset length equals the number of samples in each data model."""
        assert len(array_dataset) == 4

    def test_getitem_with_transforms(self, array_dataset):
        """Every sample has four entries, with transforms applied only where configured."""
        # One row per sample: (tensor features, tensor target, numpy features, raw target).
        expected_rows = [
            ([1, 0, -1, 2], 1, [2, 1, -5, 1], 1),
            ([1, 0, -2, -1], 0, [1, 0, -2, -1], 0),
            ([2, 5, 0, 3], 0, [2, 5, 0, 3], 1),
            ([-7, 1, 2, 2], 1, [-7, 1, 2, 2], 1),
        ]

        # Check the arity of all samples before inspecting any values.
        for idx in range(len(expected_rows)):
            assert len(array_dataset[idx]) == 4

        for idx, (want_features, want_target, want_array, want_label) in enumerate(expected_rows):
            sample = array_dataset[idx]
            torch.testing.assert_close(sample[0], torch.tensor(want_features))
            torch.testing.assert_close(sample[1], torch.tensor(want_target))
            np.testing.assert_array_equal(sample[2], np.array(want_array))
            assert sample[3] == want_label

    def test__equal_size_true(self, array_dataset):
        """Data models of equal length pass the size check."""
        sizes_match = array_dataset._equal_size()
        assert sizes_match is True

    def test__equal_size_false(self):
        """Data models of different lengths are rejected at construction time."""
        single_row_features = DataModel(data=[[1, 0, 1]])
        three_targets = DataModel([1, 0, 1])
        with pytest.raises(exceptions.MisconfigurationException):
            ArrayDataset(single_row_features, three_targets)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from pl_bolts.datasets.base_dataset import DataModel | ||
from pl_bolts.datasets.utils import to_tensor | ||
|
||
|
||
class TestDataModel:
    """Tests for ``DataModel.process`` with and without a transform."""

    @pytest.fixture
    def data(self):
        """Two-row integer array shared by the tests."""
        rows = [[1, 0, 0, 1], [0, 1, 1, 0]]
        return np.array(rows)

    def test_process_transform_is_none(self, data):
        """Without a transform, ``process`` hands each row back unchanged."""
        model = DataModel(data=data)
        for row in data:
            np.testing.assert_array_equal(model.process(row), row)

    def test_process_transform_is_not_none(self, data):
        """With ``to_tensor`` as the transform, ``process`` yields the matching tensor."""
        model = DataModel(data=data, transform=to_tensor)
        expected_tensors = (torch.tensor([1, 0, 0, 1]), torch.tensor([0, 1, 1, 0]))
        for row, expected in zip(data, expected_tensors):
            torch.testing.assert_close(model.process(row), expected)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import numpy as np | ||
import torch.testing | ||
|
||
from pl_bolts.datasets.utils import to_tensor | ||
|
||
|
||
class TestToTensor:
    """Tests that ``to_tensor`` matches ``torch.tensor`` for several input kinds."""

    def test_to_tensor_list(self):
        """A flat list of ints converts to the equivalent tensor."""
        values = [1, 2, 3]
        expected = torch.tensor(values)
        torch.testing.assert_close(to_tensor(values), expected)

    def test_to_tensor_array(self):
        """A one-dimensional numpy array converts to the equivalent tensor."""
        values = np.array([1, 2, 3])
        expected = torch.tensor(values)
        torch.testing.assert_close(to_tensor(values), expected)

    def test_to_tensor_sequence_(self):
        """A nested list of floats converts to the equivalent tensor."""
        values = [[1.0, 2.0, 3.0]]
        expected = torch.tensor(values)
        torch.testing.assert_close(to_tensor(values), expected)