Warning

I am currently in the process of trying to understand the paper and code. Quite a few of the comments are probably wrong …

Infos

  • logits_B_K_C
    B
    The batch of samples put into the system
    K
    The number of MC samples drawn for each element of the batch
    C
    The number of classes (one output neuron per class)
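
A minimal shape sketch (the values B=4, K=10, C=3 are made up for illustration):

import torch

B, K, C = 4, 10, 3  # batch size, MC samples per element, number of classes
logits_B_K_C = torch.randn(B, K, C)
assert logits_B_K_C.shape == torch.Size([4, 10, 3])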

General process

  • Train a normal PyTorch model
  • Determine logits of candidate samples for labeling
    • Replicate each sample of the batch k times (one copy per MC sample)
    • Calculate the logits using a version of dropout that keeps the same mask for all elements in a batch, but a different mask for each of the k MC samples
      • When using dropout during inference, we can interpret the trained model (without dropout) as a probability distribution over models. We can sample from this distribution by applying dropout at inference time with a specific mask: randomly sampling a mask and running the model with that mask amounts to sampling a model (see the MCDropout sketch below).
    • This results in the aforementioned logits_B_K_C form
  • Use mutual_information(logits_B_K_C) to get the mutual information between each element of the batch (estimated via the k MC samples) and the model parameters. Results in scores_B; a sketch of this scoring step follows this list.
    • This score is also used to initialize a_BatchBALD for the first sample selected into the acquisition bag.
  • Select the n elements with the largest mutual information (using scores_B and logits_B_K_C). Continue with this subset
    • Not sure why this is done …
    • Maybe to restrict the search to samples that are promising in the first place?
  • The main algorithm seems to be in compute_multi_bald_batch in multi_bald.py
  • Calculate conditional_entropies_B, which are the E_p(w)[H(y_i|w)]. Summing them up gives E_p(w)[H(y_1, …, y_n|w)], which is the right-hand side of Equation 8 used to calculate BatchBALD
  • Calculate conditional entropy with conditional_entropy_from_logits_B_K_C
  • Calculate probabilities from logits_B_K_C using .log_softmax(2).exp()
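
A minimal sketch of the scoring step described above. model and get_candidate_batch are hypothetical stand-ins, k and n are illustrative values, and mutual_information is the function listed at the end of this document:

import torch

k = 20                                   # MC samples per candidate (illustrative)
n = 100                                  # size of the promising subset (illustrative)
input_B = get_candidate_batch()          # hypothetical helper returning unlabeled candidates
logits_B_K_C = model(input_B, k)         # BayesianModule.forward, shape (B, k, C)
log_probs_B_K_C = logits_B_K_C.log_softmax(dim=2)  # normalize to log-probabilities

scores_B = mutual_information(log_probs_B_K_C)  # one BALD score per candidate
subset_scores, subset_indices = torch.topk(scores_B, n)  # keep the most promising candidates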

Classes

BayesianModule

forward

# Returns a tensor of shape B x k x output
def forward(self, input_B: torch.Tensor, k: int):
    BayesianModule.k = k

    # First do the deterministic part of the network that won't change for the k samples
    input_B = self.deterministic_forward_impl(input_B)
    # Blow up output of deterministic forward part to be able to process k samples at the same time
    mc_input_BK = BayesianModule.mc_tensor(input_B, k)
    # Send the k deterministic inputs through the non-deterministic part
    mc_output_BK = self.mc_forward_impl(mc_input_BK)
    # Bring tensor back to correct output
    mc_output_B_K = BayesianModule.unflatten_tensor(mc_output_BK, k)
    return mc_output_B_K
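
A usage sketch, assuming BayesianNet is the CNN described below and MNIST-shaped inputs:

model = BayesianNet()
input_B = torch.randn(32, 1, 28, 28)  # e.g. a batch of 32 MNIST-sized images
logits_B_K_C = model(input_B, k=10)   # shape (32, 10, num_classes)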

flatten_tensor(mc_input: torch.Tensor)

Takes a tensor and flattens the first two dimensions, e.g. torch.Size([2, 3, 4, 5]) becomes torch.Size([6, 4, 5])
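
A sketch of what this amounts to in plain PyTorch (my reading, not necessarily the repository's exact code):

def flatten_tensor(mc_input: torch.Tensor) -> torch.Tensor:
    # Merge the first two dimensions: (B, k, ...) -> (B * k, ...)
    return mc_input.flatten(0, 1)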

unflatten_tensor(input: torch.Tensor, k: int)

Expands the first dimension of a tensor into two, with the size of the second dimension determined by k
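
Again as a sketch:

def unflatten_tensor(input: torch.Tensor, k: int) -> torch.Tensor:
    # Split the first dimension back up: (B * k, ...) -> (B, k, ...)
    return input.reshape(-1, k, *input.shape[1:])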

mc_tensor(input: torch.Tensor, k: int)

Takes a tensor and repeats it k times along the first dimension (all other dimensions unchanged), with the copies of each element placed next to each other. E.g. t with torch.Size([2, 3, 4]) becomes s with torch.Size([6, 3, 4]) for k=3, where s[0] == s[1] == s[2] == t[0]
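
One way to write this (a sketch; unsqueeze + expand replicates without copying before the final flatten):

def mc_tensor(input: torch.Tensor, k: int) -> torch.Tensor:
    # (B, ...) -> (B, k, ...) -> (B * k, ...), copies of one element adjacent
    return input.unsqueeze(1).expand(-1, k, *input.shape[1:]).flatten(0, 1)

This should be equivalent to input.repeat_interleave(k, dim=0).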

BayesianNet(BayesianModule)

Normal CNN implementation that uses MCDropout2d instead of normal dropout and implements the forward pass in mc_forward_impl. On forward in BayesianModule, the incoming tensor is blown up k times and then sent through BayesianNet; a sketch follows.
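
A minimal sketch of such a subclass; the layer sizes and the MNIST-shaped input are assumptions for illustration:

import torch
import torch.nn as nn

class BayesianNet(BayesianModule):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(1, 32, kernel_size=5)
        self.dropout = MCDropout2d()  # the consistent-mask dropout described below
        self.fc = nn.Linear(32 * 24 * 24, num_classes)

    def deterministic_forward_impl(self, input_B: torch.Tensor) -> torch.Tensor:
        return input_B  # nothing dropout-free to share in this tiny example

    def mc_forward_impl(self, mc_input_BK: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.dropout(self.conv(mc_input_BK)))
        return self.fc(x.flatten(1))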

MCDropout(torch.nn.Module)

Dropout that keeps its dropout mask unless it is specifically instructed to forget it via reset_mask. Generates a different mask for each of the k MC samples but repeats the same mask over all the elements in the batch, i.e. all elements of the batch are evaluated using the same k sampled models (sampled via dropout).
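
A sketch of the masking idea (not the repository's exact implementation): draw one Bernoulli mask per MC sample and broadcast it over the batch dimension.

import torch
import torch.nn as nn

class MCDropout(nn.Module):
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p
        self.mask = None  # shape (1, k, *feature_dims), broadcast over B

    def reset_mask(self):
        self.mask = None

    def forward(self, input_BK: torch.Tensor) -> torch.Tensor:
        k = BayesianModule.k  # set in BayesianModule.forward above
        # View (B * k, ...) as (B, k, ...) so one mask per MC sample broadcasts over B
        input_B_K = input_BK.reshape(-1, k, *input_BK.shape[1:])
        if self.mask is None:
            keep = torch.full((1, k, *input_BK.shape[1:]), 1 - self.p,
                              device=input_BK.device)
            self.mask = torch.bernoulli(keep) / (1 - self.p)  # inverted dropout scaling
        return (input_B_K * self.mask).flatten(0, 1)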

Mutual Information

import math

import torch

def logit_mean(logits, dim: int, keepdim: bool = False):
    r"""Computes $\log \left ( \frac{1}{n} \sum_i p_i \right ) =
    \log \left ( \frac{1}{n} \sum_i e^{\log p_i} \right )$.

    We pass in logits (log-probabilities).
    """
    return torch.logsumexp(logits, dim=dim, keepdim=keepdim) - math.log(logits.shape[dim])

def entropy(logits, dim: int, keepdim: bool = False):
    # Expects normalized log-probabilities (e.g. after log_softmax), so that
    # exp(logits) * logits is p * log p and the negative sum is the entropy.
    return -torch.sum((torch.exp(logits) * logits).double(), dim=dim, keepdim=keepdim)

def mutual_information(logits_B_K_C):
    """Returns the mutual information for each element of the batch,
    determined by the K MC samples:
    I[y; w] = H[E_w[p(y|w)]] - E_w[H[y|w]]."""
    # E_w[H[y|w]]: entropy per MC sample, averaged over the K samples
    sample_entropies_B_K = entropy(logits_B_K_C, dim=-1)
    entropy_mean_B = torch.mean(sample_entropies_B_K, dim=1)

    # H[E_w[p(y|w)]]: entropy of the MC-averaged predictive distribution
    logits_mean_B_C = logit_mean(logits_B_K_C, dim=1)
    mean_entropy_B = entropy(logits_mean_B_C, dim=-1)

    mutual_info_B = mean_entropy_B - entropy_mean_B
    return mutual_info_B
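
A quick usage sketch (shapes are illustrative; note that the inputs must be normalized log-probabilities):

B, K, C = 4, 10, 3
logits_B_K_C = torch.randn(B, K, C).log_softmax(dim=-1)
scores_B = mutual_information(logits_B_K_C)
print(scores_B.shape)  # torch.Size([4])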