From 11adf9e1b30001f2c5a746b9ecc2192cbca4cd99 Mon Sep 17 00:00:00 2001
From: Dref360
Date: Sun, 21 Apr 2024 16:03:32 -0400
Subject: [PATCH 1/3] Add Stopping Criteria for loop

---
 baal/active/stopping_criteria.py | 63 ++++++++++++++++++++++++++++++++
 experiments/mlp_mcdropout.py     | 11 ++++--
 2 files changed, 70 insertions(+), 4 deletions(-)
 create mode 100644 baal/active/stopping_criteria.py

diff --git a/baal/active/stopping_criteria.py b/baal/active/stopping_criteria.py
new file mode 100644
index 00000000..c47a29a1
--- /dev/null
+++ b/baal/active/stopping_criteria.py
@@ -0,0 +1,63 @@
+from typing import List, Dict
+
+import numpy as np
+
+from baal import ActiveLearningDataset
+
+
+class StoppingCriterion:
+    def __init__(self, active_dataset: ActiveLearningDataset):
+        self._active_ds = active_dataset
+
+    def should_stop(self, metrics: Dict[str, float], uncertainty: List[float]) -> bool:
+        raise NotImplementedError
+
+
+class LabellingBudgetStoppingCriterion(StoppingCriterion):
+    """Stops when the labelling budget is exhausted."""
+
+    def __init__(self, active_dataset: ActiveLearningDataset, labelling_budget: int):
+        super().__init__(active_dataset)
+        self._start_length = len(active_dataset)
+        self.labelling_budget = labelling_budget
+
+    def should_stop(self, uncertainty: List[float]) -> bool:
+        return (len(self._active_ds) - self._start_length) >= self.labelling_budget
+
+
+class LowAverageUncertaintyStoppingCriterion(StoppingCriterion):
+    """Stops when the average uncertainty is on average below a threshold."""
+
+    def __init__(self, active_dataset: ActiveLearningDataset, avg_uncertainty_thresh: float):
+        super().__init__(active_dataset)
+        self.avg_uncertainty_thresh = avg_uncertainty_thresh
+
+    def should_stop(self, metrics: Dict[str, float], uncertainty: List[float]) -> bool:
+        return np.mean(uncertainty) < self.avg_uncertainty_thresh
+
+
+class EarlyStoppingCriterion(StoppingCriterion):
+    """Early stopping on a particular metrics.
+
+    Notes:
+        We don't have any mandatory dependency with an early stopping implementation.
+        So we have our own.
+    """
+
+    def __init__(
+        self,
+        active_dataset: ActiveLearningDataset,
+        metric_name: str,
+        patience: int = 10,
+        epsilon: float = 1e-4,
+    ):
+        super().__init__(active_dataset)
+        self.metric_name = metric_name
+        self.patience = patience
+        self.epsilon = epsilon
+        self._acc = []
+
+    def should_stop(self, metrics: Dict[str, float], uncertainty: List[float]) -> bool:
+        self._acc.append(metrics[self.metric_name])
+        near_threshold = np.isclose(np.array(self._acc), self._acc[-1], atol=self.epsilon)
+        return len(near_threshold) > self.patience and near_threshold[-self.patience].all()
diff --git a/experiments/mlp_mcdropout.py b/experiments/mlp_mcdropout.py
index 34f3d4a0..32ab922c 100644
--- a/experiments/mlp_mcdropout.py
+++ b/experiments/mlp_mcdropout.py
@@ -9,6 +9,7 @@
 from baal import ActiveLearningDataset, ModelWrapper
 from baal.active import ActiveLearningLoop
 from baal.active.heuristics import BALD
+from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
 from baal.bayesian.dropout import patch_module
 
 use_cuda = torch.cuda.is_available()
@@ -54,8 +55,11 @@
 
 # Following Gal 2016, we reset the weights at the beginning of each step.
 initial_weights = deepcopy(model.state_dict())
+stopping_criterion = LabellingBudgetStoppingCriterion(
+    active_dataset=al_dataset, labelling_budget=10
+)
 
-for step in range(100):
+while True:
     model.load_state_dict(initial_weights)
     train_loss = wrapper.train_on_dataset(
         al_dataset, optimizer=optimizer, batch_size=32, epoch=10, use_cuda=use_cuda
@@ -63,7 +67,6 @@
     test_loss = wrapper.test_on_dataset(test_ds, batch_size=32, use_cuda=use_cuda)
     pprint(wrapper.get_metrics())
 
-    flag = al_loop.step()
-    if not flag:
-        # We are done labelling! stopping
+    al_loop.step()
+    if stopping_criterion.should_stop():
         break

From d807e1b356689c2b96278e4131674bee446bc40a Mon Sep 17 00:00:00 2001
From: Dref360
Date: Sat, 11 May 2024 12:54:26 -0400
Subject: [PATCH 2/3] Changes according to review

---
 baal/active/stopping_criteria.py | 12 ++++++------
 experiments/mlp_mcdropout.py     |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/baal/active/stopping_criteria.py b/baal/active/stopping_criteria.py
index c47a29a1..ac9e6695 100644
--- a/baal/active/stopping_criteria.py
+++ b/baal/active/stopping_criteria.py
@@ -1,4 +1,4 @@
-from typing import List, Dict
+from typing import Iterable, Dict
 
 import numpy as np
 
@@ -9,7 +9,7 @@ class StoppingCriterion:
     def __init__(self, active_dataset: ActiveLearningDataset):
         self._active_ds = active_dataset
 
-    def should_stop(self, metrics: Dict[str, float], uncertainty: List[float]) -> bool:
+    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
         raise NotImplementedError
 
 
@@ -21,7 +21,7 @@ def __init__(self, active_dataset: ActiveLearningDataset, labelling_budget: int)
         self._start_length = len(active_dataset)
         self.labelling_budget = labelling_budget
 
-    def should_stop(self, uncertainty: List[float]) -> bool:
+    def should_stop(self, uncertainty: Iterable[float]) -> bool:
         return (len(self._active_ds) - self._start_length) >= self.labelling_budget
 
 
@@ -32,7 +32,7 @@ def __init__(self, active_dataset: ActiveLearningDataset, avg_uncertainty_thresh
         super().__init__(active_dataset)
         self.avg_uncertainty_thresh = avg_uncertainty_thresh
 
-    def should_stop(self, metrics: Dict[str, float], uncertainty: List[float]) -> bool:
+    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
         return np.mean(uncertainty) < self.avg_uncertainty_thresh
 
 
@@ -57,7 +57,7 @@ def __init__(
         self.epsilon = epsilon
         self._acc = []
 
-    def should_stop(self, metrics: Dict[str, float], uncertainty: List[float]) -> bool:
+    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
         self._acc.append(metrics[self.metric_name])
         near_threshold = np.isclose(np.array(self._acc), self._acc[-1], atol=self.epsilon)
-        return len(near_threshold) > self.patience and near_threshold[-self.patience].all()
+        return len(near_threshold) >= self.patience and near_threshold[-(self.patience + 1) :].all()
diff --git a/experiments/mlp_mcdropout.py b/experiments/mlp_mcdropout.py
index 32ab922c..b7690522 100644
--- a/experiments/mlp_mcdropout.py
+++ b/experiments/mlp_mcdropout.py
@@ -67,6 +67,6 @@
     test_loss = wrapper.test_on_dataset(test_ds, batch_size=32, use_cuda=use_cuda)
     pprint(wrapper.get_metrics())
 
-    al_loop.step()
-    if stopping_criterion.should_stop():
+    flag = al_loop.step()
+    if stopping_criterion.should_stop() or flag:
         break

From f752a18712975ed21381f9fdfc250e4a786356f9 Mon Sep 17 00:00:00 2001
From: Dref360
Date: Sat, 11 May 2024 13:02:51 -0400
Subject: [PATCH 3/3] Add tests

---
 tests/active/criterion_test.py | 53 ++++++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)
 create mode 100644 tests/active/criterion_test.py

diff --git a/tests/active/criterion_test.py b/tests/active/criterion_test.py
new file mode 100644
index 00000000..a586d0ba
--- /dev/null
+++ b/tests/active/criterion_test.py
@@ -0,0 +1,53 @@
+from baal.active.stopping_criteria import (
+    LabellingBudgetStoppingCriterion,
+    EarlyStoppingCriterion,
+    LowAverageUncertaintyStoppingCriterion,
+)
+from baal.active.dataset import ActiveNumpyArray
+import numpy as np
+
+
+def test_labelling_budget():
+    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
+    ds.label_randomly(10)
+    criterion = LabellingBudgetStoppingCriterion(ds, labelling_budget=50)
+    assert not criterion.should_stop([])
+
+    ds.label_randomly(10)
+    assert not criterion.should_stop([])
+
+    ds.label_randomly(40)
+    assert criterion.should_stop([])
+
+
+def test_early_stopping():
+    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
+    criterion = EarlyStoppingCriterion(ds, "test_loss", patience=5)
+
+    for i in range(10):
+        assert not criterion.should_stop(
+            metrics={"test_loss": 1 / (i + 1)}, uncertainty=[]
+        )
+
+    for _ in range(4):
+        assert not criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+    assert criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+
+    # test less than patience stability
+    criterion = EarlyStoppingCriterion(ds, "test_loss", patience=5)
+    for _ in range(4):
+        assert not criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+    assert criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+
+
+def test_low_average():
+    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
+    criterion = LowAverageUncertaintyStoppingCriterion(
+        active_dataset=ds, avg_uncertainty_thresh=0.1
+    )
+    assert not criterion.should_stop(
+        metrics={}, uncertainty=np.random.normal(0.5, scale=0.8, size=(100,))
+    )
+    assert criterion.should_stop(
+        metrics={}, uncertainty=np.random.normal(0.05, scale=0.01, size=(100,))
+    )
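Usage sketch (illustrative, not part of the patches): with the API exactly as introduced above, several criteria can be checked together at the end of each active-learning step. The fragment below would replace the loop in experiments/mlp_mcdropout.py; al_dataset, wrapper, al_loop, model, optimizer, test_ds, initial_weights and use_cuda are the objects already defined in that script, the labelling budget and patience values are arbitrary, and "test_loss" is assumed to be a key returned by wrapper.get_metrics(), as in the tests.

    from baal.active.stopping_criteria import (
        EarlyStoppingCriterion,
        LabellingBudgetStoppingCriterion,
    )

    # Stop after 100 new labels, or once the test loss has been stable for 5 steps.
    budget_criterion = LabellingBudgetStoppingCriterion(al_dataset, labelling_budget=100)
    early_stopping = EarlyStoppingCriterion(al_dataset, metric_name="test_loss", patience=5)

    while True:
        model.load_state_dict(initial_weights)
        wrapper.train_on_dataset(
            al_dataset, optimizer=optimizer, batch_size=32, epoch=10, use_cuda=use_cuda
        )
        wrapper.test_on_dataset(test_ds, batch_size=32, use_cuda=use_cuda)

        metrics = wrapper.get_metrics()
        keep_labelling = al_loop.step()  # False once there is nothing left to label
        # Per patch 2, LabellingBudgetStoppingCriterion.should_stop takes only the
        # uncertainty list; the other criteria also take the metrics dict.
        if (
            not keep_labelling
            or budget_criterion.should_stop([])
            or early_stopping.should_stop(metrics, [])
        ):
            break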