Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Args to ModelWrapper to simplify common API #294

Merged
merged 3 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ In conclusion, your script should be similar to this:
dataset = ActiveLearningDataset(your_dataset)
dataset.label_randomly(INITIAL_POOL) # label some data
model = MCDropoutModule(your_model)
model = ModelWrapper(model, your_criterion)
model = ModelWrapper(model, args=TrainingArgs(...))
active_loop = ActiveLearningLoop(dataset,
get_probabilities=model.predict_on_dataset,
heuristic=heuristics.BALD(),
iterations=20, # Number of MC sampling.
query_size=QUERY_SIZE) # Number of item to label.
for al_step in range(N_ALSTEP):
model.train_on_dataset(dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda)
metrics = model.test_on_dataset(test_dataset, BATCH_SIZE)
model.train_on_dataset(dataset)
metrics = model.test_on_dataset(test_dataset)
# Label the next most uncertain items.
if not active_loop.step():
# We're done!
Expand Down
6 changes: 4 additions & 2 deletions baal/active/dataset/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import warnings
from typing import Union, List, Optional, Any, TYPE_CHECKING, Protocol
from typing import Union, List, Optional, Any, TYPE_CHECKING, Protocol, Tuple

import numpy as np
from sklearn.utils import check_random_state
from torch.utils import data as torchdata

from baal.utils.equality import assert_not_none


class SizeableDataset(torchdata.Dataset):
def __len__(self):
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(
if last_active_steps == 0 or last_active_steps < -1:
raise ValueError("last_active_steps must be > 0 or -1 when disabled.")
self.last_active_steps = last_active_steps
self._indices_cache = (-1, None)
self._indices_cache: Tuple[int, List[int]] = (-1, [])

def get_indices_for_active_step(self) -> List[int]:
"""Returns the indices required for the active step.
Expand Down
28 changes: 4 additions & 24 deletions baal/active/heuristics/heuristics_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,12 @@ class AbstractGPUHeuristic(ModelWrapper):
def __init__(
self,
model: ModelWrapper,
criterion,
shuffle_prop=0.0,
threshold=None,
reverse=False,
reduction="none",
):
super().__init__(model, criterion)
super().__init__(model, model.args)
self.shuffle_prop = shuffle_prop
self.threshold = threshold
self.reversed = reverse
Expand Down Expand Up @@ -102,32 +101,15 @@ def get_uncertainties(self, predictions):
def predict_on_dataset(
self,
dataset: Dataset,
batch_size: int,
iterations: int,
use_cuda: bool,
workers: int = 4,
collate_fn: Optional[Callable] = None,
half=False,
verbose=True,
):
return (
super()
.predict_on_dataset(
dataset,
batch_size,
iterations,
use_cuda,
workers,
collate_fn,
half,
verbose,
)
.reshape([-1])
)
return super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1])

def predict_on_batch(self, data, iterations=1, use_cuda=False):
def predict_on_batch(self, data, iterations=1):
"""Rank the predictions according to their uncertainties."""
return self.get_uncertainties(self.model.predict_on_batch(data, iterations, cuda=use_cuda))
return self.get_uncertainties(self.model.predict_on_batch(data, iterations))


class BALDGPUWrapper(AbstractGPUHeuristic):
Expand All @@ -139,14 +121,12 @@ class BALDGPUWrapper(AbstractGPUHeuristic):
def __init__(
self,
model: ModelWrapper,
criterion,
shuffle_prop=0.0,
threshold=None,
reduction="none",
):
super().__init__(
model,
criterion=criterion,
shuffle_prop=shuffle_prop,
threshold=threshold,
reverse=True,
Expand Down
13 changes: 8 additions & 5 deletions baal/active/stopping_criteria.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Dict
from typing import Iterable, Dict, List

import numpy as np

Expand All @@ -21,7 +21,7 @@ def __init__(self, active_dataset: ActiveLearningDataset, labelling_budget: int)
self._start_length = len(active_dataset)
self.labelling_budget = labelling_budget

def should_stop(self, uncertainty: Iterable[float]) -> bool:
def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
return (len(self._active_ds) - self._start_length) >= self.labelling_budget


Expand All @@ -33,7 +33,8 @@ def __init__(self, active_dataset: ActiveLearningDataset, avg_uncertainty_thresh
self.avg_uncertainty_thresh = avg_uncertainty_thresh

def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
return np.mean(uncertainty) < self.avg_uncertainty_thresh
arr = np.array(uncertainty)
return bool(np.mean(arr) < self.avg_uncertainty_thresh)


class EarlyStoppingCriterion(StoppingCriterion):
Expand All @@ -55,9 +56,11 @@ def __init__(
self.metric_name = metric_name
self.patience = patience
self.epsilon = epsilon
self._acc = []
self._acc: List[float] = []

def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
self._acc.append(metrics[self.metric_name])
near_threshold = np.isclose(np.array(self._acc), self._acc[-1], atol=self.epsilon)
return len(near_threshold) >= self.patience and near_threshold[-(self.patience + 1) :].all()
return len(near_threshold) >= self.patience and bool(
near_threshold[-(self.patience + 1) :].all()
)
48 changes: 19 additions & 29 deletions baal/calibration/calibration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from typing import Optional

import structlog
import torch
Expand All @@ -7,6 +8,7 @@
from torch.optim import Adam

from baal import ModelWrapper
from baal.modelwrapper import TrainingArgs
from baal.utils.metrics import ECE, ECE_PerCLs

log = structlog.get_logger("Calibrating...")
Expand Down Expand Up @@ -37,6 +39,7 @@ class DirichletCalibrator(object):
reg_factor (float): Regularization factor for the linear layer weights.
mu (float): Regularization factor for the linear layer biases.
If not given, will be initialized by "l".
training_duration (int): How long to train calibration layer.

"""

Expand All @@ -46,7 +49,8 @@ def __init__(
num_classes: int,
lr: float,
reg_factor: float,
mu: float = None,
mu: Optional[float] = None,
training_duration: int = 5,
):
self.num_classes = num_classes
self.criterion = nn.CrossEntropyLoss()
Expand All @@ -55,7 +59,17 @@ def __init__(
self.mu = mu or reg_factor
self.dirichlet_linear = nn.Linear(self.num_classes, self.num_classes)
self.model = nn.Sequential(wrapper.model, self.dirichlet_linear)
self.wrapper = ModelWrapper(self.model, self.criterion)
self.optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)
self.wrapper = ModelWrapper(
self.model,
TrainingArgs(
criterion=self.criterion,
optimizer=self.optimizer,
regularizer=self.l2_reg,
epoch=training_duration,
use_cuda=wrapper.args.use_cuda,
),
)

self.wrapper.add_metric("ece", lambda: ECE())
self.wrapper.add_metric("ece", lambda: ECE_PerCLs(num_classes))
Expand All @@ -75,8 +89,6 @@ def calibrate(
self,
train_set: Dataset,
test_set: Dataset,
batch_size: int,
epoch: int,
use_cuda: bool,
double_fit: bool = False,
**kwargs
Expand All @@ -88,8 +100,6 @@ def calibrate(
Args:
train_set (Dataset): The training set.
test_set (Dataset): The validation set.
batch_size (int): Batch size used.
epoch (int): Number of epochs to train the linear layer for.
use_cuda (bool): If "True", will use GPU.
double_fit (bool): If "True" would fit twice on the train set.
kwargs (dict): Rest of parameters for baal.ModelWrapper.train_and_test_on_dataset().
Expand All @@ -106,36 +116,16 @@ def calibrate(
if use_cuda:
self.dirichlet_linear.cuda()

optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)

loss_history, weights = self.wrapper.train_and_test_on_datasets(
train_set,
test_set,
optimizer,
batch_size,
epoch,
use_cuda,
regularizer=self.l2_reg,
return_best_weights=True,
patience=None,
**kwargs
train_set, test_set, return_best_weights=True, patience=None, **kwargs
)
self.model.load_state_dict(weights)

if double_fit:
lr = self.lr / 10
optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
self.wrapper.args.optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
loss_history, weights = self.wrapper.train_and_test_on_datasets(
train_set,
test_set,
optimizer,
batch_size,
epoch,
use_cuda,
regularizer=self.l2_reg,
return_best_weights=True,
patience=None,
**kwargs
train_set, test_set, return_best_weights=True, patience=None, **kwargs
)
self.model.load_state_dict(weights)

Expand Down
10 changes: 5 additions & 5 deletions baal/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import nn, Tensor

from baal import ModelWrapper
from baal.modelwrapper import _stack_preds
from baal.modelwrapper import _stack_preds, TrainingArgs
from baal.utils.cuda_utils import to_cuda


Expand All @@ -15,16 +15,16 @@ class EnsembleModelWrapper(ModelWrapper):

Args:
model (nn.Module): A Model.
criterion (Callable): Loss function
args (TrainingArgs): Argument for model

Notes:
If you're looking to use ensembles for non-deep models, see our sklearn tutorial:
baal.readthedocs.io/en/latest/notebooks/sklearn_tutorial.html
"""

def __init__(self, model, criterion):
super().__init__(model, criterion)
self._weights = []
def __init__(self, model, args: TrainingArgs):
super().__init__(model, args)
self._weights: List[Dict] = []

def add_checkpoint(self):
"""
Expand Down
Loading
Loading