diff --git a/baal/active/dataset/pytorch_dataset.py b/baal/active/dataset/pytorch_dataset.py
index 2ba85fe6..311c07ae 100644
--- a/baal/active/dataset/pytorch_dataset.py
+++ b/baal/active/dataset/pytorch_dataset.py
@@ -1,12 +1,30 @@
 import warnings
 from copy import deepcopy
 from itertools import zip_longest
 from typing import Union, Optional, Callable, Any, Dict, List
 
 import numpy as np
 import torch.utils.data as torchdata
 
 from baal.active.dataset.base import SplittedDataset, Dataset
+from baal.utils.equality import deep_check
+
+STOCHASTIC_POOL_WARNING = """
+It seems that data augmentation is not disabled when iterating on the pool.
+You can disable it by overriding attributes using `pool_specifics`
+when instantiating ActiveLearningDataset.
+Example:
+```
+from torchvision.transforms import *
+train_transform = Compose([Resize((224, 224)), RandomHorizontalFlip(),
+                           RandomRotation(30), ToTensor()])
+test_transform = Compose([Resize((224, 224)), ToTensor()])
+dataset = CIFAR10(..., transform=train_transform)
+
+al_dataset = ActiveLearningDataset(dataset,
+                                   pool_specifics={'transform': test_transform})
+```
+"""
 
 
 def _identity(x):
@@ -57,6 +75,7 @@ def __init__(
         super().__init__(
             labelled=labelled_map, random_state=random_state, last_active_steps=last_active_steps
         )
+        self._warn_if_pool_stochastic()
 
     def check_dataset_can_label(self):
         """Check if a dataset can be labelled.
@@ -199,6 +218,14 @@ def load_state_dict(self, state_dict):
         self.labelled_map = state_dict["labelled"]
         self.random_state = state_dict["random_state"]
 
+    def _warn_if_pool_stochastic(self):
+        pool = self.pool
+        if len(pool) > 0 and not deep_check(pool[0], pool[0]):
+            warnings.warn(
+                STOCHASTIC_POOL_WARNING,
+                UserWarning,
+            )
+
 
 class ActiveLearningPool(torchdata.Dataset):
     """A dataset that represents the unlabelled pool for active learning.
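Reviewer note, not part of the patch: the check works by reading `pool[0]` twice and comparing the two results with `deep_check`, so any augmentation still active on the pool makes the reads differ and triggers the warning at construction time. A minimal sketch of the resulting behaviour, assuming baal's usual public import path; the CIFAR10 root, transform choices, and variable names are illustrative and simply mirror the example inside `STOCHASTIC_POOL_WARNING`:

```python
from torchvision.datasets import CIFAR10
from torchvision.transforms import (Compose, RandomHorizontalFlip,
                                    RandomRotation, Resize, ToTensor)

from baal.active import ActiveLearningDataset

train_transform = Compose([Resize((224, 224)), RandomHorizontalFlip(),
                           RandomRotation(30), ToTensor()])
test_transform = Compose([Resize((224, 224)), ToTensor()])

# RandomRotation draws a fresh angle on every read, so the two fetches of
# pool[0] differ and _warn_if_pool_stochastic emits STOCHASTIC_POOL_WARNING.
noisy = ActiveLearningDataset(
    CIFAR10('/tmp/cifar10', train=True, download=True,
            transform=train_transform))

# Overriding the pool's transform with a deterministic one makes both reads
# identical, so no warning is raised.
quiet = ActiveLearningDataset(
    CIFAR10('/tmp/cifar10', train=True, download=True,
            transform=train_transform),
    pool_specifics={'transform': test_transform},
)
```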
diff --git a/baal/utils/equality.py b/baal/utils/equality.py
new file mode 100644
index 00000000..86b71c63
--- /dev/null
+++ b/baal/utils/equality.py
@@ -0,0 +1,22 @@
+from typing import Sequence, Mapping
+
+import numpy as np
+import torch
+from torch import Tensor
+
+
+def deep_check(obj1, obj2) -> bool:
+    if type(obj1) != type(obj2):
+        return False
+    elif isinstance(obj1, str):
+        return bool(obj1 == obj2)
+    elif isinstance(obj1, Sequence):
+        return all(deep_check(i1, i2) for i1, i2 in zip(obj1, obj2))
+    elif isinstance(obj1, Mapping):
+        return all(deep_check(val1, obj2[key1]) for key1, val1 in obj1.items())
+    elif isinstance(obj1, Tensor):
+        return torch.equal(obj1, obj2)
+    elif isinstance(obj1, np.ndarray):
+        return bool((obj1 == obj2).all())
+    else:
+        return bool(obj1 == obj2)
diff --git a/notebooks/compatibility/nlp_classification.ipynb b/notebooks/compatibility/nlp_classification.ipynb
index e743c7a6..032b98d2 100644
--- a/notebooks/compatibility/nlp_classification.ipynb
+++ b/notebooks/compatibility/nlp_classification.ipynb
@@ -200,7 +200,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:15:36.980534Z [\u001b[32minfo ] Start Predict dataset=67249\n"
+      "[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:15:36.980534Z [\u001B[32minfo ] Start Predict dataset=67249\n"
      ]
     },
     {
@@ -394,7 +394,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:28:15.903378Z [\u001b[32minfo ] Start Predict dataset=67239\n"
+      "[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:28:15.903378Z [\u001B[32minfo ] Start Predict dataset=67239\n"
      ]
     },
     {
diff --git a/tests/active/dataset/dataset_test.py b/tests/active/dataset/dataset_test.py
index 4ecc314f..b5f1e441 100644
--- a/tests/active/dataset/dataset_test.py
+++ b/tests/active/dataset/dataset_test.py
@@ -136,6 +136,18 @@ def test_transform(self):
         with pytest.raises(ValueError) as e:
             ActiveLearningDataset(MyDataset(train_transform), pool_specifics={'whatever': 123}).pool
 
+        # Test warnings related to stochasticity of the pool.
+        with warnings.catch_warnings(record=True) as w:
+            _ = ActiveLearningDataset(MyDataset(Lambda(lambda k: np.random.rand())))
+            assert (len(w) > 0 and issubclass(w[-1].category, UserWarning)
+                    and "It seems that data augmentation is not disabled when iterating on the pool."
+                    in str(w[-1].message))
+
+        with warnings.catch_warnings(record=True) as w:
+            _ = ActiveLearningDataset(MyDataset(Lambda(lambda k: np.random.rand())),
+                                      pool_specifics={'transform': test_transform})
+            assert len(w) == 0
+
     def test_random(self):
         self.dataset.label_randomly(50)
         assert len(self.dataset) == 50
@@ -201,7 +213,7 @@ def __len__(self):
             return len(self.x)
 
         def __getitem__(self, item):
-            return self.x[item], self.y[item]
+            return self.x[item], self.label[item]
 
     with warnings.catch_warnings(record=True) as w:
         al = ActiveLearningDataset(DS())
diff --git a/tests/utils/test_equality.py b/tests/utils/test_equality.py
new file mode 100644
index 00000000..9ba86ad7
--- /dev/null
+++ b/tests/utils/test_equality.py
@@ -0,0 +1,33 @@
+from collections import namedtuple
+
+import numpy as np
+import torch
+
+from baal.utils.equality import deep_check
+
+Point = namedtuple('Point', 'x,y')
+
+
+def test_deep_check():
+    arr1, arr2 = np.random.rand(10), np.random.rand(10)
+    tensor1, tensor2 = torch.rand([10]), torch.rand([10])
+    s1, s2 = "string1", "string2"
+    p1, p2 = Point(x=1, y=2), Point(x=2, y=1)
+
+    assert not deep_check(arr1, arr2)
+    assert not deep_check(tensor1, tensor2)
+    assert not deep_check(s1, s2)
+    assert not deep_check(p1, p2)
+    assert not deep_check([arr1, tensor1], [arr2, tensor2])
+    assert not deep_check([arr1, tensor1], (arr1, tensor1))
+    assert not deep_check([arr1, tensor1], [tensor1, arr1])
+    assert not deep_check({'x': arr1, 'y': tensor1}, {'x': arr2, 'y': tensor2})
+    assert not deep_check({'x': arr1, 'y': tensor1}, {'x': tensor1, 'y': arr1})
+
+    assert deep_check(arr1, arr1)
+    assert deep_check(tensor1, tensor1)
+    assert deep_check(s1, s1)
+    assert deep_check(p1, p1)
+    assert deep_check([arr1, tensor1], [arr1, tensor1])
+    assert deep_check((arr1, tensor1), (arr1, tensor1))
+    assert deep_check({'x': arr1, 'y': tensor1}, {'x': arr1, 'y': tensor1})
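A closing note on the semantics these tests pin down: `deep_check` is type-strict (a list never equals a tuple, even with identical elements) and structural (sequences element by element, mappings key by key, `torch.equal` for tensors, element-wise `==` for arrays). Since its only call site compares a pool item with itself, `zip` truncating at the shorter sequence is harmless here. A standalone sketch with arbitrary values, using only `deep_check` as defined in this patch:

```python
import numpy as np
import torch

from baal.utils.equality import deep_check

arr = np.arange(4)
assert deep_check(arr, arr.copy())                    # element-wise equal ndarrays
assert not deep_check([arr], (arr,))                  # list vs tuple: type mismatch
assert not deep_check(torch.zeros(2), torch.ones(2))  # torch.equal is False
assert deep_check({'x': arr}, {'x': arr.copy()})      # mappings compared per key
```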