Add Args to ModelWrapper to simplify common API
Dref360 committed May 18, 2024
1 parent 4171b7a commit b86d9ff
Showing 13 changed files with 244 additions and 379 deletions.
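The thrust of the change: hyperparameters that were previously repeated on every ModelWrapper call (criterion, optimizer, epochs, CUDA flags, and per-call loader options) now live in a single TrainingArgs object given to the wrapper once. A minimal before/after sketch, assuming the data-only call signatures implied by the diffs below; the toy model and datasets are placeholders, and whether TrainingArgs also carries batch_size is not visible in this excerpt:

    import torch
    from torch import nn
    from torch.optim import Adam
    from torch.utils.data import TensorDataset

    from baal.modelwrapper import ModelWrapper, TrainingArgs

    # Toy model and data, purely for illustration.
    model = nn.Linear(8, 2)
    train_set = TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))
    test_set = TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,)))

    # Old API (pre-commit): hyperparameters repeated on every call.
    # wrapper = ModelWrapper(model, criterion)
    # wrapper.train_and_test_on_datasets(train_set, test_set, optimizer,
    #                                    batch_size, epoch, use_cuda)

    # New API: configure once via TrainingArgs (fields as seen in this diff).
    args = TrainingArgs(
        criterion=nn.CrossEntropyLoss(),
        optimizer=Adam(model.parameters(), lr=1e-3),
        regularizer=None,  # assumption: None is accepted when no regularizer is used
        epoch=5,
        use_cuda=False,
    )
    wrapper = ModelWrapper(model, args)
    wrapper.train_and_test_on_datasets(train_set, test_set,
                                       return_best_weights=True, patience=None)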
baal/active/heuristics/heuristics_gpu.py (24 changes: 5 additions & 19 deletions)
@@ -60,13 +60,12 @@ class AbstractGPUHeuristic(ModelWrapper):
     def __init__(
         self,
         model: ModelWrapper,
-        criterion,
         shuffle_prop=0.0,
         threshold=None,
         reverse=False,
         reduction="none",
     ):
-        super().__init__(model, criterion)
+        super().__init__(model, model.args)
         self.shuffle_prop = shuffle_prop
         self.threshold = threshold
         self.reversed = reverse
@@ -94,25 +93,15 @@ def get_uncertainties(self, predictions):
     def predict_on_dataset(
         self,
         dataset: Dataset,
-        batch_size: int,
         iterations: int,
-        use_cuda: bool,
-        workers: int = 4,
-        collate_fn: Optional[Callable] = None,
         half=False,
         verbose=True,
     ):
-        return (
-            super()
-            .predict_on_dataset(
-                dataset, batch_size, iterations, use_cuda, workers, collate_fn, half, verbose
-            )
-            .reshape([-1])
-        )
+        return super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1])

-    def predict_on_batch(self, data, iterations=1, use_cuda=False):
+    def predict_on_batch(self, data, iterations=1):
         """Rank the predictions according to their uncertainties."""
-        return self.get_uncertainties(self.model.predict_on_batch(data, iterations, cuda=use_cuda))
+        return self.get_uncertainties(self.model.predict_on_batch(data, iterations))


@@ -121,12 +110,9 @@ class BALDGPUWrapper(AbstractGPUHeuristic):
     https://arxiv.org/abs/1703.02910
     """

-    def __init__(
-        self, model: ModelWrapper, criterion, shuffle_prop=0.0, threshold=None, reduction="none"
-    ):
+    def __init__(self, model: ModelWrapper, shuffle_prop=0.0, threshold=None, reduction="none"):
         super().__init__(
             model,
-            criterion=criterion,
             shuffle_prop=shuffle_prop,
             threshold=threshold,
             reverse=True,
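Under the new signatures, a GPU heuristic wraps an already-configured ModelWrapper and inherits its args instead of taking a criterion. A sketch of the intended call pattern, reusing the placeholder wrapper and test_set from the sketch above (BALD additionally assumes a stochastic model, e.g. one using MC-Dropout, which is not shown here):

    from baal.active.heuristics.heuristics_gpu import BALDGPUWrapper

    # No criterion argument anymore; the heuristic reuses wrapper.args.
    bald = BALDGPUWrapper(wrapper, shuffle_prop=0.0, reduction="none")

    # batch_size, use_cuda, workers and collate_fn are gone from the call;
    # they now come from the wrapper's stored configuration.
    uncertainties = bald.predict_on_dataset(test_set, iterations=20)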
baal/calibration/calibration.py (43 changes: 15 additions & 28 deletions)
@@ -7,6 +7,7 @@
 from torch.optim import Adam

 from baal import ModelWrapper
+from baal.modelwrapper import TrainingArgs
 from baal.utils.metrics import ECE, ECE_PerCLs

 log = structlog.get_logger("Calibrating...")
@@ -55,7 +56,17 @@ def __init__(
         self.mu = mu or reg_factor
         self.dirichlet_linear = nn.Linear(self.num_classes, self.num_classes)
         self.model = nn.Sequential(wrapper.model, self.dirichlet_linear)
-        self.wrapper = ModelWrapper(self.model, self.criterion)
+        self.optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)
+        self.wrapper = ModelWrapper(
+            self.model,
+            TrainingArgs(
+                criterion=self.criterion,
+                optimizer=self.optimizer,
+                regularizer=self.l2_reg,
+                epoch=5,
+                use_cuda=False,
+            ),
+        )

         self.wrapper.add_metric("ece", lambda: ECE())
         self.wrapper.add_metric("ece", lambda: ECE_PerCLs(num_classes))
@@ -75,8 +86,6 @@ def calibrate(
         self,
         train_set: Dataset,
         test_set: Dataset,
-        batch_size: int,
-        epoch: int,
         use_cuda: bool,
         double_fit: bool = False,
         **kwargs
@@ -88,8 +97,6 @@
         Args:
             train_set (Dataset): The training set.
             test_set (Dataset): The validation set.
-            batch_size (int): Batch size used.
-            epoch (int): Number of epochs to train the linear layer for.
             use_cuda (bool): If "True", will use GPU.
             double_fit (bool): If "True" would fit twice on the train set.
             kwargs (dict): Rest of parameters for baal.ModelWrapper.train_and_test_on_dataset().
@@ -106,36 +113,16 @@
         if use_cuda:
             self.dirichlet_linear.cuda()

-        optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)
-
         loss_history, weights = self.wrapper.train_and_test_on_datasets(
-            train_set,
-            test_set,
-            optimizer,
-            batch_size,
-            epoch,
-            use_cuda,
-            regularizer=self.l2_reg,
-            return_best_weights=True,
-            patience=None,
-            **kwargs
+            train_set, test_set, return_best_weights=True, patience=None, **kwargs
         )
         self.model.load_state_dict(weights)

         if double_fit:
             lr = self.lr / 10
-            optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
+            self.wrapper.args.optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
             loss_history, weights = self.wrapper.train_and_test_on_datasets(
-                train_set,
-                test_set,
-                optimizer,
-                batch_size,
-                epoch,
-                use_cuda,
-                regularizer=self.l2_reg,
-                return_best_weights=True,
-                patience=None,
-                **kwargs
+                train_set, test_set, return_best_weights=True, patience=None, **kwargs
             )
             self.model.load_state_dict(weights)
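Callers of calibrate() therefore drop batch_size and epoch; both now come from the TrainingArgs built in __init__ (where epoch is fixed at 5). A hedged usage sketch, reusing the placeholder wrapper and datasets from above; the DirichletCalibrator name and its constructor arguments follow baal's published calibration API rather than anything visible in this excerpt:

    from baal.calibration import DirichletCalibrator

    calibrator = DirichletCalibrator(wrapper, num_classes=2, lr=0.001, reg_factor=0.01)
    # batch_size and epoch are no longer passed here:
    calibrator.calibrate(train_set, test_set, use_cuda=False, double_fit=True)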
baal/ensemble.py (8 changes: 4 additions & 4 deletions)
@@ -5,7 +5,7 @@
 from torch import nn, Tensor

 from baal import ModelWrapper
-from baal.modelwrapper import _stack_preds
+from baal.modelwrapper import _stack_preds, TrainingArgs
 from baal.utils.cuda_utils import to_cuda


@@ -15,15 +15,15 @@ class EnsembleModelWrapper(ModelWrapper):
     Args:
         model (nn.Module): A Model.
-        criterion (Callable): Loss function
+        args (TrainingArgs): Argument for model
     Notes:
         If you're looking to use ensembles for non-deep models, see our sklearn tutorial:
         baal.readthedocs.io/en/latest/notebooks/sklearn_tutorial.html
     """

-    def __init__(self, model, criterion):
-        super().__init__(model, criterion)
+    def __init__(self, model, args: TrainingArgs):
+        super().__init__(model, args)
         self._weights = []

     def add_checkpoint(self):
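EnsembleModelWrapper follows the same pattern, taking TrainingArgs in place of a bare criterion. A minimal construction sketch; only the constructor change and add_checkpoint appear in this diff, so the commented training loop is an assumption:

    from torch import nn
    from torch.optim import SGD

    from baal.ensemble import EnsembleModelWrapper
    from baal.modelwrapper import TrainingArgs

    member = nn.Linear(8, 2)  # placeholder architecture
    ensemble = EnsembleModelWrapper(
        member,
        TrainingArgs(
            criterion=nn.CrossEntropyLoss(),
            optimizer=SGD(member.parameters(), lr=0.01),
            regularizer=None,  # assumed optional
            epoch=1,
            use_cuda=False,
        ),
    )

    # Assumed loop: retrain and snapshot each ensemble member.
    # for _ in range(5):
    #     ensemble.train_on_dataset(train_set)
    #     ensemble.add_checkpoint()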