Skip to content

Commit

Permalink
#205 Add stochastic heuristics from Kirsch et al.
Browse files Browse the repository at this point in the history
  • Loading branch information
fr.branchaud-charron committed Apr 21, 2022
1 parent 38afad8 commit 32b9034
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 49 deletions.
137 changes: 137 additions & 0 deletions baal/active/heuristics/stochastics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import types

import numpy as np
import structlog
from scipy.special import softmax
from scipy.stats import rankdata

from baal.active.heuristics import AbstractHeuristic, Sequence

# Module-level structlog logger, named after this module.
log = structlog.get_logger(__name__)
# Small positive constant; no code visible in this module references it.
# NOTE(review): presumably meant as a numerical-stability epsilon -- confirm or remove.
EPSILON = 1e-8


class StochasticHeuristic(AbstractHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size):
        """Heuristic that is stochastic to improve diversity.

        Common acquisition functions are heavily impacted by duplicates. When using a
        `top-k` approach where the most uncertain examples are selected, the
        acquisition function can select many duplicates.
        Techniques such as BADGE (Ash et al., 2019) or BatchBALD (Kirsch et al., 2019)
        are common solutions to this problem, but they are quite expensive.
        Stochastic acquisitions are cheap to compute and get similar performances.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
        """
        # TODO handle reverse
        super().__init__(reverse=False)
        self._bh = base_heuristic
        self.query_size = query_size

    def get_ranks(self, predictions):
        """Sample `query_size` indices from the uncertainty distribution.

        Args:
            predictions: Predictions (or a generator of predictions) forwarded to
                the base heuristic.

        Returns:
            Tuple `(indices, distribution)`: the indices sampled without replacement
            and the normalized probabilities they were sampled from.

        Raises:
            ValueError: Propagated from `np.random.choice` when `query_size` is
                larger than the number of scored items.
        """
        # Get the raw uncertainty from the base heuristic.
        scores = self.get_scores(predictions)
        # Create the distribution to sample from.
        distributions = self._make_distribution(scores)
        # Force normalization for np.random.choice.
        # The upper bound is passed explicitly: older NumPy releases require both
        # bounds and raise a TypeError on `np.clip(x, 0)`.
        distributions = np.clip(distributions, 0, None)
        # Out-of-place division so integer inputs are promoted to float instead of
        # failing; an all-zero vector yields NaNs that the fallback below replaces.
        with np.errstate(divide="ignore", invalid="ignore"):
            distributions = distributions / distributions.sum()

        # TODO Seed?
        if (distributions > 0).sum() < self.query_size:
            # Sampling without replacement needs at least `query_size` non-zero
            # probabilities; otherwise fall back to uniform sampling.
            log.warning("Not enough values, return random")
            distributions = np.ones_like(distributions) / len(distributions)
        return (
            np.random.choice(len(distributions), self.query_size, replace=False, p=distributions),
            distributions,
        )

    def get_scores(self, predictions):
        """Compute the raw uncertainty scores with the base heuristic.

        Args:
            predictions: Predictions, or a generator yielding predictions.

        Returns:
            The uncertainties computed by the base heuristic. When the base
            heuristic returns a `Sequence`, its items are concatenated into a
            single array.
        """
        if isinstance(predictions, types.GeneratorType):
            scores = self._bh.get_uncertainties_generator(predictions)
        else:
            scores = self._bh.get_uncertainties(predictions)
        if isinstance(scores, Sequence):
            scores = np.concatenate(scores)
        return scores

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        """Turn raw scores into sampling weights; subclasses must implement this.

        Args:
            scores: Uncertainty scores as returned by `get_scores`.

        Returns:
            Sampling weights, one per score.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class PowerSampling(StochasticHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
        """Sample proportionally to the temperature-scaled uncertainty scores.

        Apart from temperature scaling and normalization, the uncertainty
        distribution is used as is. This stochastic heuristic assumes that the
        uncertainty distribution is positive and that items with near-zero
        uncertainty are uninformative.
        Empirically worked the best in the paper.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
            temperature: Value to temper the uncertainty distribution before sampling.
        """
        super().__init__(base_heuristic=base_heuristic, query_size=query_size)
        self.temperature = temperature

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        """Raise the scores to the power `1 / temperature` and normalize them.

        Args:
            scores: Uncertainty scores.

        Returns:
            The tempered scores divided by their sum.
        """
        exponent = 1 / self.temperature
        tempered = scores**exponent
        return tempered / tempered.sum()


class GibbsSampling(StochasticHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
        """Samples from the uncertainty distribution after applying softmax.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
            temperature: Value to temper the uncertainty distribution before sampling.
        """
        super().__init__(base_heuristic=base_heuristic, query_size=query_size)
        self.temperature = temperature

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        """Divide the scores by the temperature and apply a softmax.

        The input array is left untouched.

        Args:
            scores: Uncertainty scores.

        Returns:
            Softmax of `scores / temperature`, taken over all entries.
        """
        # Divide out-of-place: `scores /= temperature` would overwrite the caller's
        # array and raises a casting error for integer arrays.
        tempered = np.asarray(scores) / self.temperature
        # tempered dimensions is [N]
        return softmax(tempered)


class RankBasedSampling(StochasticHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
        """Sample according to the rank of each item's uncertainty.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
            temperature: Value to temper the uncertainty distribution before sampling.
        """
        super().__init__(base_heuristic=base_heuristic, query_size=query_size)
        self.temperature = temperature

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        """Weight every item by `rank ** (-1 / temperature)` and normalize.

        Args:
            scores: Uncertainty scores.

        Returns:
            Rank-based weights divided by their sum.
        """
        # Ranks are computed on the negated scores.
        ranks = rankdata(-scores)
        exponent = -1 / self.temperature
        unnormalized = ranks**exponent
        return unnormalized / unnormalized.sum()
56 changes: 7 additions & 49 deletions tests/active/heuristic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,64 +17,22 @@
Precomputed,
CombineHeuristics,
)
from tests.test_utils import make_fake_dist, make_3d_fake_dist, make_5d_fake_dist

# Number of fake distributions stacked along the iteration axis by the 3-D helper.
N_ITERATIONS = 50
# Side length of the square fake image (IMG_SIZE x IMG_SIZE) built by the 5-D helper.
IMG_SIZE = 3
# Number of classes; passed as `dims` when the module-level fake distributions are built.
N_CLASS = 10


def chunks(l, n):
    """Lazily split `l` into consecutive slices of length `n`.

    The final slice is shorter when `len(l)` is not a multiple of `n`.
    """
    yield from (l[start : start + n] for start in range(0, len(l), n))


def _make_3d_fake_dist(means, stds, dims=10):
    """Build fake distributions repeated over `N_ITERATIONS` stochastic iterations.

    Args:
        means: List of means
        stds: List of standard deviations
        dims: Dimensions of the distributions

    Returns:
        Array of shape [n_sample, n_class, n_iter].
    """
    draws = [_make_fake_dist(means, stds, dims=dims) for _ in range(N_ITERATIONS)]
    stacked = np.stack(draws)  # [n_iter, n_sample, n_class]
    # Send the iteration axis to the back: [n_sample, n_class, n_iter]
    return np.moveaxis(stacked, 0, -1)


def _make_5d_fake_dist(means, stds, dims=10):
    """Build fake per-pixel distributions for an `IMG_SIZE` x `IMG_SIZE` image.

    Args:
        means: List of means
        stds: List of standard deviations
        dims: Dimensions of the distributions

    Returns:
        Array of shape [n_sample, n_class, H, W, iter].
    """
    n_pixels = IMG_SIZE ** 2
    per_pixel = [_make_3d_fake_dist(means, stds, dims=dims) for _ in range(n_pixels)]
    flat = np.stack(per_pixel, axis=-1)  # [n_sample, n_class, n_iter, H * W]
    n_sample, n_class, n_iter, _ = flat.shape
    image = flat.reshape(n_sample, n_class, n_iter, IMG_SIZE, IMG_SIZE)
    # Send the iteration axis to the back: [n_sample, n_class, H, W, iter]
    return np.moveaxis(image, 2, -1)


def _make_fake_dist(means, stds, dims=10):
    """Build fake discrete distributions from rounded normal draws.

    For every (mean, std) pair, 100 values are drawn one at a time from a normal
    distribution, clipped to [0, dims - 1], rounded to a bin index and counted;
    the counts are then divided by the number of draws.

    Args:
        means: List of means
        stds: List of standard deviations
        dims: Dimensions of the distributions

    Returns:
        Array with one distribution per (mean, std) pair.
    """
    n_trials = 100
    rows = []
    for mean, spread in zip(means, stds):
        counts = np.zeros([dims])
        for _ in range(n_trials):
            draw = np.random.normal(mean, spread, 1)
            bucket = np.round(np.clip(draw, 0, dims - 1)).astype(int).item()
            counts[bucket] += 1
        rows.append(counts / n_trials)
    return np.array(rows)


# Module-level fake predictions, each built from three (mean, std) pairs over N_CLASS bins.
# 2-D: one distribution per (mean, std) pair.
distribution_2d = _make_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
# 3-D: the same, with a trailing stochastic-iteration axis.
distributions_3d = _make_3d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
# 5-D: image-shaped, [n_sample, n_class, H, W, iter].
distributions_5d = _make_5d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)



# Module-level fake predictions, each built from three (mean, std) pairs over N_CLASS bins.
# 2-D: one distribution per (mean, std) pair.
distribution_2d = make_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
# 3-D: the same, with a trailing stochastic-iteration axis.
distributions_3d = make_3d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
# 5-D: image-shaped, [n_sample, n_class, H, W, iter].
distributions_5d = make_5d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)


@pytest.mark.parametrize(
Expand Down
25 changes: 25 additions & 0 deletions tests/active/stochastic_heuristic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
import pytest
from scipy.stats import entropy

from baal.active.heuristics import BALD, Entropy
from baal.active.heuristics.stochastics import GibbsSampling, RankBasedSampling, PowerSampling
from tests.test_utils import make_fake_dist


@pytest.fixture
def sampled_predictions():
    """Stack ten independent fake-distribution draws into one array of shape (10, 3, 20)."""
    n_draws = 10
    means, stds = [1, 2, 2], [1, 3, 3]
    return np.stack([make_fake_dist(means, stds, dims=20) for _ in range(n_draws)])


@pytest.mark.parametrize("stochastic_heuristic", [GibbsSampling, RankBasedSampling, PowerSampling])
@pytest.mark.parametrize("base_heuristic", [BALD, Entropy])
def test_stochastic_heuristic(stochastic_heuristic, base_heuristic, sampled_predictions):
    """A higher temperature must give a flatter sampling distribution.

    Every stochastic heuristic tempers the scores with `1 / temperature`, so the
    distribution built at temperature 10 is closer to uniform, and therefore has a
    higher entropy, than the one built at temperature 1.
    """
    heur_temp_1 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=1.0)
    heur_temp_10 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=10.0)

    scores = heur_temp_1.get_scores(sampled_predictions)

    # Give each heuristic its own copy of the scores so that an implementation
    # modifying its argument in place cannot influence the other distribution.
    dist_temp_1 = heur_temp_1._make_distribution(np.copy(scores))
    dist_temp_10 = heur_temp_10._make_distribution(np.copy(scores))

    assert entropy(dist_temp_1) < entropy(dist_temp_10)
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ def fn(module: nn.Module, input_shape):
pred1 = module(inp).detach().cpu().numpy()
return all(np.allclose(pred1, module(inp).detach().cpu().numpy()) for _ in range(5))
return fn


@pytest.fixture
def sampled_predictions():
    """Return standard-normal random values of shape (100, 10, 20)."""
    shape = (100, 10, 20)
    return np.random.randn(*shape)
47 changes: 47 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

# Number of fake distributions stacked along the iteration axis by `make_3d_fake_dist`.
N_ITERATIONS = 50
# Side length of the square fake image (IMG_SIZE x IMG_SIZE) built by `make_5d_fake_dist`.
IMG_SIZE = 3


def make_3d_fake_dist(means, stds, dims=10):
    """Create fake distributions repeated over `N_ITERATIONS` stochastic iterations.

    Args:
        means: List of means
        stds: List of standard deviations
        dims: Dimensions of the distributions

    Returns:
        Array of shape [n_sample, n_class, n_iter].
    """
    iterations = [make_fake_dist(means, stds, dims=dims) for _ in range(N_ITERATIONS)]
    by_iteration = np.stack(iterations)  # [n_iter, n_sample, n_class]
    # Put the iteration axis last: [n_sample, n_class, n_iter]
    return np.moveaxis(by_iteration, 0, -1)


def make_5d_fake_dist(means, stds, dims=10):
    """Create fake per-pixel distributions for an `IMG_SIZE` x `IMG_SIZE` image.

    Args:
        means: List of means
        stds: List of standard deviations
        dims: Dimensions of the distributions

    Returns:
        Array of shape [n_sample, n_class, H, W, iter].
    """
    pixel_count = IMG_SIZE ** 2
    pixels = [make_3d_fake_dist(means, stds, dims=dims) for _ in range(pixel_count)]
    flat = np.stack(pixels, axis=-1)  # [n_sample, n_class, n_iter, H * W]
    n_sample, n_class, n_iter, _ = flat.shape
    image = flat.reshape(n_sample, n_class, n_iter, IMG_SIZE, IMG_SIZE)
    # Put the iteration axis last: [n_sample, n_class, H, W, iter]
    return np.moveaxis(image, 2, -1)


def make_fake_dist(means, stds, dims=10):
    """Create fake discrete distributions from rounded normal draws.

    For every (mean, std) pair, 100 values are drawn one at a time from a normal
    distribution, clipped to [0, dims - 1], rounded to a bin index and tallied;
    the tallies are then divided by the number of draws.

    Args:
        means: List of means
        stds: List of standard deviations
        dims: Dimensions of the distributions

    Returns:
        Array with one distribution per (mean, std) pair.
    """
    n_trials = 100
    histograms = []
    for center, scale in zip(means, stds):
        histogram = np.zeros([dims])
        for _ in range(n_trials):
            sample = np.random.normal(center, scale, 1)
            bin_index = np.round(np.clip(sample, 0, dims - 1)).astype(int).item()
            histogram[bin_index] += 1
        histograms.append(histogram / n_trials)
    return np.array(histograms)

0 comments on commit 32b9034

Please sign in to comment.