I am currently in the process of trying to understand the paper and code, so quite a few of the comments below are probably still wrong …
`logits_B_K_C`

- `B`: the batch put into the system
- `K`: the number of MC samples drawn for each element of the batch
- `C`: the number of classes (one output neuron per class, obviously)
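To make the shape concrete, here is a tiny sketch with made-up sizes (assuming, as the `entropy`/`mutual_information` helpers further down do, that the tensor holds normalized log-probabilities):

```python
import torch

B, K, C = 4, 10, 3  # hypothetical sizes: 4 pool samples, 10 MC samples, 3 classes
logits_B_K_C = torch.randn(B, K, C).log_softmax(dim=-1)

print(logits_B_K_C.shape)              # torch.Size([4, 10, 3])
print(logits_B_K_C.exp().sum(dim=-1))  # every (sample, MC draw) row sums to 1
```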
- Train a normal PyTorch model
- Determine the logits of the candidate samples for labeling
  - Replicate the samples of a batch `k` times (one copy for each MC sample)
  - Calculate the logits using a version of dropout that keeps the same mask for all elements in a batch, but a different mask for each of the `k` MC samples
    - When using dropout during inference, the trained model (without dropout) can be interpreted as a probability distribution over models. We sample from this distribution by running inference with dropout under a specific mask: randomly drawing a mask and using the model with that mask amounts to sampling one model.
  - This results in the aforementioned `logits_B_K_C` form
- Use `mutual_information(logits_B_K_C)` to get the mutual information between each element of the batch (using the `k` MC samples) and the model parameters. This results in `scores_B`.
  - This score is also used to initialize `a_BatchBALD` for the first sample to be selected into the acquisition bag.
- Select the `n` elements with the smallest mutual information (`scores_B` and `logits_B_K_C`) and continue with this subset (see the selection sketch after this list)
  - Not sure why this is done …
  - Maybe to restrict the search to samples that are promising in the first place?
- The main algorithm seems to be in `compute_multi_bald_batch` in `multi_bald.py`
  - Calculate `conditional_entropies_B`, which are the E_p(w)[H(y_i|w)]. Summing them together gives E_p(w)[H(y_1, …, y_n|w)], which is the right-hand side of Equation 8 used to calculate BatchBALD (a small sketch follows this list)
    - Calculate the conditional entropy with `conditional_entropy_from_logits_B_K_C`
      - Calculate probabilities from `logits_B_K_C` using `.log_softmax(2).exp_()`
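The selection step from the list above boils down to a `torch.topk` over `scores_B`. A hypothetical helper (the name `select_subset` is mine, and whether the repo really keeps the smallest or the largest scores is one of the things I still want to verify):

```python
import torch

def select_subset(scores_B: torch.Tensor, logits_B_K_C: torch.Tensor, n: int,
                  largest: bool = False):
    """Keep the n pool elements with the smallest (or, with largest=True,
    the largest) mutual-information scores, together with their logits."""
    subset_scores, subset_indices = torch.topk(scores_B, n, largest=largest)
    return subset_scores, subset_indices, logits_B_K_C[subset_indices]
```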
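And a rough, self-contained sketch of the conditional-entropy term from the list above (my own illustration, not the repo's `conditional_entropy_from_logits_B_K_C`; `chosen_indices` in the comment is a placeholder), again assuming `logits_B_K_C` holds log-probabilities:

```python
import torch

def conditional_entropies(logits_B_K_C: torch.Tensor) -> torch.Tensor:
    """E_p(w)[H(y_i | w)] for every pool element i."""
    probs_B_K_C = logits_B_K_C.exp()
    entropies_B_K = -(probs_B_K_C * logits_B_K_C).sum(dim=-1)  # H(y_i | w_k) per MC sample
    return entropies_B_K.mean(dim=1)                           # average over the K samples

# Summing over the points chosen so far gives E_p(w)[H(y_1, ..., y_n | w)],
# i.e. the conditional-entropy term that gets subtracted in the BatchBALD score:
# conditional_joint = conditional_entropies(logits_B_K_C[chosen_indices]).sum()
```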
The `forward` method of `BayesianModule`:

```python
# Returns B x k x output
def forward(self, input_B: torch.Tensor, k: int):
    BayesianModule.k = k

    # First do the deterministic part of the network that won't change for the k samples
    input_B = self.deterministic_forward_impl(input_B)

    # Blow up the output of the deterministic part to be able to process k samples at the same time
    mc_input_BK = BayesianModule.mc_tensor(input_B, k)

    # Send the replicated inputs through the non-deterministic part
    mc_output_BK = self.mc_forward_impl(mc_input_BK)

    # Reshape back to the [B, k, ...] output form
    mc_output_B_K = BayesianModule.unflatten_tensor(mc_output_BK, k)
    return mc_output_B_K
```
The flattening helper takes a tensor and flattens the first two dimensions, e.g. `torch.Size([2, 3, 4, 5])` becomes `torch.Size([6, 4, 5])`.

`unflatten_tensor` expands the first dimension of a tensor into two, with the size of the second dimension determined by `k`.

`mc_tensor` takes a tensor and repeats each element along the first dimension `k` times, leaving the other dimensions unchanged. E.g. `t` with `torch.Size([2, 3, 4])` becomes `s` with `torch.Size([6, 3, 4])` for `k=3`, and `s[0] == s[1] == s[2] == t[0]`.
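A minimal sketch of how these reshaping helpers could look (my own reconstruction from the descriptions above, not necessarily identical to the repo's implementations):

```python
import torch

def mc_tensor(input_B: torch.Tensor, k: int) -> torch.Tensor:
    # Repeat each batch element k times along the first dimension:
    # [B, ...] -> [B, k, ...] -> [B*k, ...]
    expanded = input_B.unsqueeze(1).expand(input_B.shape[0], k, *input_B.shape[1:])
    return expanded.reshape(-1, *input_B.shape[1:])

def unflatten_tensor(input_BK: torch.Tensor, k: int) -> torch.Tensor:
    # Split the first dimension back into (B, k): [B*k, ...] -> [B, k, ...]
    return input_BK.view(-1, k, *input_BK.shape[1:])

t = torch.randn(2, 3, 4)
s = mc_tensor(t, k=3)  # torch.Size([6, 3, 4])
assert torch.equal(s[0], t[0]) and torch.equal(s[2], t[0])
assert unflatten_tensor(s, k=3).shape == (2, 3, 3, 4)
```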
`BayesianNet` is a normal CNN implementation that uses `MCDropout2D` instead of normal dropout and implements the forward pass in `mc_forward_impl`. On `forward` in `BayesianModule`, the incoming tensor is blown up `k` times and then sent through `BayesianNet`.
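To make the deterministic/MC split concrete, here is a self-contained toy version of the same idea (my own example, not the repo's `BayesianNet`: the class name and layer sizes are made up, and plain `nn.Dropout` draws an independent mask per replicated row, unlike the consistent dropout described next):

```python
import torch
from torch import nn

class ToyBayesianNet(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.trunk = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU())
        self.dropout = nn.Dropout(p=0.5)   # stand-in for the consistent MCDropout2D
        self.head = nn.Linear(128, num_classes)

    def forward(self, input_B: torch.Tensor, k: int) -> torch.Tensor:
        feats_B = self.trunk(input_B)                                  # deterministic part, run once
        feats_BK = feats_B.repeat_interleave(k, dim=0)                 # blow up to B*k rows
        out_BK_C = self.head(self.dropout(feats_BK)).log_softmax(-1)   # stochastic part, per MC sample
        return out_BK_C.view(-1, k, out_BK_C.shape[-1])                # back to [B, k, C]

logits_B_K_C = ToyBayesianNet()(torch.randn(4, 1, 28, 28), k=10)
print(logits_B_K_C.shape)  # torch.Size([4, 10, 10])
```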
`MCDropout2D` is a dropout variant that keeps its dropout mask unless it is specifically instructed to forget it using `reset_mask`. It generates a different mask for each of the `k` MC samples but repeats the same mask over all elements in the batch, i.e. all elements of the batch are evaluated using the same `k` sampled models (via dropout).
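A minimal sketch of the underlying idea, one Bernoulli mask per MC sample shared across the batch (my own illustration, not the repo's `MCDropout2D`; it works on a simplified `[B, K, features]` layout):

```python
import torch
from torch import nn

class ToyConsistentDropout(nn.Module):
    """One dropout mask per MC sample k, shared by all elements of the batch."""
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p
        self.mask_K_F = None

    def reset_mask(self):
        self.mask_K_F = None

    def forward(self, input_B_K_F: torch.Tensor) -> torch.Tensor:
        # Input is [B, K, features]; the mask has no batch dimension, so the k-th
        # mask is reused for every batch element until reset_mask() is called.
        if self.mask_K_F is None:
            keep = torch.full(input_B_K_F.shape[1:], 1.0 - self.p)
            self.mask_K_F = torch.bernoulli(keep)
        return input_B_K_F * self.mask_K_F / (1.0 - self.p)

x_B_K_F = torch.randn(4, 10, 16)       # 4 batch elements, 10 MC samples, 16 features
y = ToyConsistentDropout()(x_B_K_F)    # all 4 elements see the same 10 masks
```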
The scoring helpers (`mutual_information` is what produces `scores_B` above):

```python
import math

import torch


def logit_mean(logits, dim: int, keepdim: bool = False):
    r"""Computes $\log \left ( \frac{1}{n} \sum_i p_i \right ) =
    \log \left ( \frac{1}{n} \sum_i e^{\log p_i} \right )$.

    We pass in logits (log-probabilities).
    """
    return torch.logsumexp(logits, dim=dim, keepdim=keepdim) - math.log(logits.shape[dim])


def entropy(logits, dim: int, keepdim: bool = False):
    # Entropy of a categorical distribution given its log-probabilities.
    return -torch.sum((torch.exp(logits) * logits).double(), dim=dim, keepdim=keepdim)


def mutual_information(logits_B_K_C):
    """Returns the mutual information for each element of the batch,
    estimated from the K MC samples."""
    sample_entropies_B_K = entropy(logits_B_K_C, dim=-1)
    entropy_mean_B = torch.mean(sample_entropies_B_K, dim=1)  # E_p(w)[H(y|w)]

    logits_mean_B_C = logit_mean(logits_B_K_C, dim=1)
    mean_entropy_B = entropy(logits_mean_B_C, dim=-1)         # H(E_p(w)[p(y|w)])

    mutual_info_B = mean_entropy_B - entropy_mean_B
    return mutual_info_B
```
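Hypothetical usage of these helpers with made-up shapes, just to show which dimensions they reduce:

```python
# 100 pool points, 20 MC samples, 10 classes; log_softmax gives log-probabilities.
logits_B_K_C = torch.randn(100, 20, 10).log_softmax(dim=-1)
scores_B = mutual_information(logits_B_K_C)  # one BALD score per pool element
print(scores_B.shape)                        # torch.Size([100])
```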