Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning when data augmentation is applied on the pool #229

Merged
merged 11 commits into from
Oct 2, 2022
31 changes: 30 additions & 1 deletion baal/active/dataset/pytorch_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
import warnings
from copy import deepcopy
from itertools import zip_longest
from typing import Union, Optional, Callable, Any, Dict, List
from typing import Union, Optional, Callable, Any, Dict, List, Sequence, Mapping

import numpy as np
import torch
import torch.utils.data as torchdata
from torch import Tensor

from baal.active.dataset.base import SplittedDataset, Dataset
from baal.utils.equality import deep_check

STOCHASTIC_POOL_WARNING = """
Data augmentation does not looks disabled when iterating on the pool.
Dref360 marked this conversation as resolved.
Show resolved Hide resolved
You can disable it by overriding attributes using `pool_specifics`
when instantiating ActiveLearningDataset.
Example:
```
from torchvision.transforms import *
train_transform = Compose([Resize((224, 224)), RandomHorizontalFlip(),
RandomRotation(30), ToTensor()])
test_transform = Compose([Resize((224, 224)),ToTensor()])
dataset = CIFAR10(..., transform=train_transform)

al_dataset = ActiveLearningDataset(dataset,
pool_specifics={'transform': test_transform})
```
"""


def _identity(x):
Expand Down Expand Up @@ -57,6 +77,7 @@ def __init__(
super().__init__(
labelled=labelled_map, random_state=random_state, last_active_steps=last_active_steps
)
self._warn_if_pool_stochastic()

def check_dataset_can_label(self):
"""Check if a dataset can be labelled.
Expand Down Expand Up @@ -199,6 +220,14 @@ def load_state_dict(self, state_dict):
self.labelled_map = state_dict["labelled"]
self.random_state = state_dict["random_state"]

def _warn_if_pool_stochastic(self):
pool = self.pool
if len(pool) > 0 and not deep_check(pool[0], pool[0]):
warnings.warn(
STOCHASTIC_POOL_WARNING,
UserWarning,
)


class ActiveLearningPool(torchdata.Dataset):
"""A dataset that represents the unlabelled pool for active learning.
Expand Down
22 changes: 22 additions & 0 deletions baal/utils/equality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Sequence, Mapping

import numpy as np
import torch
from torch import Tensor


def deep_check(obj1, obj2) -> bool:
if type(obj1) != type(obj2):
return False
elif isinstance(obj1, str):
return bool(obj1 == obj2)
elif isinstance(obj1, Sequence):
return all(deep_check(i1, i2) for i1, i2 in zip(obj1, obj2))
elif isinstance(obj1, Mapping):
return all(deep_check(val1, obj2[key1]) for key1, val1 in obj1.items())
elif isinstance(obj1, Tensor):
return torch.equal(obj1, obj2)
elif isinstance(obj1, np.ndarray):
return bool((obj1 == obj2).all())
else:
return bool(obj1 == obj2)
13 changes: 12 additions & 1 deletion tests/active/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ def test_transform(self):
with pytest.raises(ValueError) as e:
ActiveLearningDataset(MyDataset(train_transform), pool_specifics={'whatever': 123}).pool

# Test warnings related to stochasticity of the pool.
with warnings.catch_warnings(record=True) as w:
_ = ActiveLearningDataset(MyDataset(Lambda(lambda k: np.random.rand())))
assert (len(w) > 0 and issubclass(w[-1].category, UserWarning)
and "Data augmentation does not looks disabled" in str(w[-1].message))

with warnings.catch_warnings(record=True) as w:
_ = ActiveLearningDataset(MyDataset(Lambda(lambda k: np.random.rand())),
pool_specifics={'transform': test_transform})
assert len(w) == 0

def test_random(self):
self.dataset.label_randomly(50)
assert len(self.dataset) == 50
Expand Down Expand Up @@ -201,7 +212,7 @@ def __len__(self):
return len(self.x)

def __getitem__(self, item):
return self.x[item], self.y[item]
return self.x[item], self.label[item]

with warnings.catch_warnings(record=True) as w:
al = ActiveLearningDataset(DS())
Expand Down
33 changes: 33 additions & 0 deletions tests/utils/test_equality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from collections import namedtuple

import numpy as np
import torch

from baal.utils.equality import deep_check

Point = namedtuple('Point', 'x,y')


def test_deep_check():
arr1, arr2 = np.random.rand(10), np.random.rand(10)
tensor1, tensor2 = torch.rand([10]), torch.rand([10])
s1, s2 = "string1", "string2"
p1, p2 = Point(x=1, y=2), Point(x=2, y=1)

assert not deep_check(arr1, arr2)
assert not deep_check(tensor1, tensor2)
assert not deep_check(s1, s2)
assert not deep_check(p1, p2)
assert not deep_check([arr1, tensor1], [arr2, tensor2])
assert not deep_check([arr1, tensor1], (arr1, tensor1))
assert not deep_check([arr1, tensor1], [tensor1, arr1])
assert not deep_check({'x': arr1, 'y': tensor1}, {'x': arr2, 'y': tensor2})
assert not deep_check({'x': arr1, 'y': tensor1}, {'x': tensor1, 'y': arr1})

assert deep_check(arr1, arr1)
assert deep_check(tensor1, tensor1)
assert deep_check(s1, s1)
assert deep_check(p1, p1)
assert deep_check([arr1, tensor1], [arr1, tensor1])
assert deep_check((arr1, tensor1), (arr1, tensor1))
assert deep_check({'x': arr1, 'y': tensor1}, {'x': arr1, 'y': tensor1})