diff --git a/balanced_loss/losses.py b/balanced_loss/losses.py index 9e1bf50..3822c59 100644 --- a/balanced_loss/losses.py +++ b/balanced_loss/losses.py @@ -97,6 +97,8 @@ def forward( if self.class_balanced: effective_num = 1.0 - np.power(self.beta, self.samples_per_class) + # Avoid division by 0 error for test cases without all labels present. + effective_num[effective_num == 0] = 1 weights = (1.0 - self.beta) / np.array(effective_num) weights = weights / np.sum(weights) * num_classes weights = torch.tensor(weights, device=logits.device).float()