Skip to content

Commit

Permalink
Fix division by zero error for labels with no samples
Browse files Browse the repository at this point in the history
  • Loading branch information
chbeltz committed Nov 22, 2024
1 parent 1c71f91 commit 98b2612
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions balanced_loss/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def forward(

if self.class_balanced:
effective_num = 1.0 - np.power(self.beta, self.samples_per_class)
# Avoid division by 0 error for test cases without all labels present.
effective_num[effective_num == 0] = 1
weights = (1.0 - self.beta) / np.array(effective_num)
weights = weights / np.sum(weights) * num_classes
weights = torch.tensor(weights, device=logits.device).float()
Expand Down

0 comments on commit 98b2612

Please sign in to comment.