Skip to content

Commit

Permalink
Add warning when data augmentation is applied on the pool (#229)
Browse files Browse the repository at this point in the history
Co-authored-by: Parmida Atighehchian <[email protected]>
  • Loading branch information
Dref360 and parmidaatg authored Oct 2, 2022
1 parent deafffa commit 19fbbb4
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 4 deletions.
31 changes: 30 additions & 1 deletion baal/active/dataset/pytorch_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
import warnings
from copy import deepcopy
from itertools import zip_longest
from typing import Union, Optional, Callable, Any, Dict, List
from typing import Union, Optional, Callable, Any, Dict, List, Sequence, Mapping

import numpy as np
import torch
import torch.utils.data as torchdata
from torch import Tensor

from baal.active.dataset.base import SplittedDataset, Dataset
from baal.utils.equality import deep_check

# Text of the UserWarning raised by `_warn_if_pool_stochastic` when reading
# the first pool item twice yields two different results. It tells the user
# how to override dataset attributes for the pool through `pool_specifics`.
STOCHASTIC_POOL_WARNING = """
It seems that data augmentation is not disabled when iterating on the pool.
You can disable it by overriding attributes using `pool_specifics`
when instantiating ActiveLearningDataset.
Example:
```
from torchvision.transforms import *
train_transform = Compose([Resize((224, 224)), RandomHorizontalFlip(),
RandomRotation(30), ToTensor()])
test_transform = Compose([Resize((224, 224)),ToTensor()])
dataset = CIFAR10(..., transform=train_transform)
al_dataset = ActiveLearningDataset(dataset,
pool_specifics={'transform': test_transform})
```
"""


def _identity(x):
Expand Down Expand Up @@ -57,6 +77,7 @@ def __init__(
super().__init__(
labelled=labelled_map, random_state=random_state, last_active_steps=last_active_steps
)
self._warn_if_pool_stochastic()

def check_dataset_can_label(self):
"""Check if a dataset can be labelled.
Expand Down Expand Up @@ -199,6 +220,14 @@ def load_state_dict(self, state_dict):
self.labelled_map = state_dict["labelled"]
self.random_state = state_dict["random_state"]

def _warn_if_pool_stochastic(self):
    """Warn the user when reading the pool does not look deterministic.

    The first item of ``self.pool`` is read twice and the two results are
    compared with ``deep_check``. If they differ, a ``UserWarning`` carrying
    ``STOCHASTIC_POOL_WARNING`` is emitted. An empty pool is left unchecked.
    """
    unlabelled = self.pool
    if len(unlabelled) <= 0:
        # Nothing to sample from, so nothing to compare.
        return

    # Two independent reads of the same index: any difference between them
    # means fetching an item is not deterministic.
    first_read = unlabelled[0]
    second_read = unlabelled[0]
    if deep_check(first_read, second_read):
        return

    warnings.warn(STOCHASTIC_POOL_WARNING, UserWarning)


class ActiveLearningPool(torchdata.Dataset):
"""A dataset that represents the unlabelled pool for active learning.
Expand Down
22 changes: 22 additions & 0 deletions baal/utils/equality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Sequence, Mapping

import numpy as np
import torch
from torch import Tensor


def deep_check(obj1, obj2) -> bool:
    """Recursively check whether two objects hold the same values.

    Args:
        obj1: First object to compare.
        obj2: Second object to compare.

    Returns:
        True if both objects have exactly the same type and equal content,
        False otherwise. Sequences must have the same length and pairwise
        equal items, mappings must have the same set of keys and equal
        values, tensors are compared with ``torch.equal`` and numpy arrays
        with ``np.array_equal`` (same shape and same elements). Any other
        object falls back to ``==``.
    """
    # Exact type match is required: a list never equals a tuple.
    if type(obj1) is not type(obj2):
        return False
    # Strings are Sequences of strings; handle them first to avoid
    # recursing forever on single characters.
    if isinstance(obj1, str):
        return bool(obj1 == obj2)
    if isinstance(obj1, Sequence):
        # `zip` stops at the shortest input, so lengths are checked first.
        return len(obj1) == len(obj2) and all(
            deep_check(item1, item2) for item1, item2 in zip(obj1, obj2)
        )
    if isinstance(obj1, Mapping):
        # Comparing key views first rejects extra keys in `obj2` and
        # guarantees `obj2[key]` cannot raise KeyError below.
        return obj1.keys() == obj2.keys() and all(
            deep_check(value1, obj2[key]) for key, value1 in obj1.items()
        )
    if isinstance(obj1, Tensor):
        return torch.equal(obj1, obj2)
    if isinstance(obj1, np.ndarray):
        # `array_equal` compares shapes before elements, so arrays that
        # merely broadcast against each other are not reported as equal.
        return bool(np.array_equal(obj1, obj2))
    return bool(obj1 == obj2)
4 changes: 2 additions & 2 deletions notebooks/compatibility/nlp_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:15:36.980534Z [\u001b[32minfo ] Start Predict dataset=67249\n"
"[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:15:36.980534Z [\u001B[32minfo ] Start Predict dataset=67249\n"
]
},
{
Expand Down Expand Up @@ -394,7 +394,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:28:15.903378Z [\u001b[32minfo ] Start Predict dataset=67239\n"
"[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:28:15.903378Z [\u001B[32minfo ] Start Predict dataset=67239\n"
]
},
{
Expand Down
13 changes: 12 additions & 1 deletion tests/active/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ def test_transform(self):
with pytest.raises(ValueError) as e:
ActiveLearningDataset(MyDataset(train_transform), pool_specifics={'whatever': 123}).pool

# Test warnings related to stochasticity of the pool.
with warnings.catch_warnings(record=True) as w:
_ = ActiveLearningDataset(MyDataset(Lambda(lambda k: np.random.rand())))
assert (len(w) > 0 and issubclass(w[-1].category, UserWarning)
and "It seems that data augmentation is not disabled when iterating on the pool." in str(w[-1].message))

with warnings.catch_warnings(record=True) as w:
_ = ActiveLearningDataset(MyDataset(Lambda(lambda k: np.random.rand())),
pool_specifics={'transform': test_transform})
assert len(w) == 0

def test_random(self):
self.dataset.label_randomly(50)
assert len(self.dataset) == 50
Expand Down Expand Up @@ -201,7 +212,7 @@ def __len__(self):
return len(self.x)

def __getitem__(self, item):
return self.x[item], self.y[item]
return self.x[item], self.label[item]

with warnings.catch_warnings(record=True) as w:
al = ActiveLearningDataset(DS())
Expand Down
33 changes: 33 additions & 0 deletions tests/utils/test_equality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from collections import namedtuple

import numpy as np
import torch

from baal.utils.equality import deep_check

Point = namedtuple('Point', 'x,y')


def test_deep_check():
    """Check `deep_check` on arrays, tensors, strings, namedtuples and containers."""
    arr_a, arr_b = np.random.rand(10), np.random.rand(10)
    ten_a, ten_b = torch.rand([10]), torch.rand([10])
    str_a, str_b = "string1", "string2"
    pt_a, pt_b = Point(x=1, y=2), Point(x=2, y=1)

    # Pairs that differ in value, container type, or element order.
    mismatched = [
        (arr_a, arr_b),
        (ten_a, ten_b),
        (str_a, str_b),
        (pt_a, pt_b),
        ([arr_a, ten_a], [arr_b, ten_b]),
        ([arr_a, ten_a], (arr_a, ten_a)),
        ([arr_a, ten_a], [ten_a, arr_a]),
        ({'x': arr_a, 'y': ten_a}, {'x': arr_b, 'y': ten_b}),
        ({'x': arr_a, 'y': ten_a}, {'x': ten_a, 'y': arr_a}),
    ]
    for left, right in mismatched:
        assert not deep_check(left, right)

    # Pairs built from the very same values must compare equal.
    matching = [
        (arr_a, arr_a),
        (ten_a, ten_a),
        (str_a, str_a),
        (pt_a, pt_a),
        ([arr_a, ten_a], [arr_a, ten_a]),
        ((arr_a, ten_a), (arr_a, ten_a)),
        ({'x': arr_a, 'y': ten_a}, {'x': arr_a, 'y': ten_a}),
    ]
    for left, right in matching:
        assert deep_check(left, right)

0 comments on commit 19fbbb4

Please sign in to comment.