Skip to content

Commit

Permalink
safe switch for equivalent weight calculation in the presence of 0-sample labels
Browse files Browse the repository at this point in the history
  • Loading branch information
chbeltz committed Nov 22, 2024
1 parent 98b2612 commit f137f8a
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions balanced_loss/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
fl_gamma=2,
samples_per_class=None,
class_balanced=False,
safe: bool = False
):
"""
Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
Expand All @@ -60,6 +61,7 @@ def __init__(
samples_per_class: A python list of size [num_classes].
Required if class_balance is True.
class_balanced: bool. Whether to use class balanced loss.
safe: bool. Whether to allow labels with no samples.
Returns:
Loss instance
"""
Expand All @@ -73,11 +75,12 @@ def __init__(
self.fl_gamma = fl_gamma
self.samples_per_class = samples_per_class
self.class_balanced = class_balanced
self.safe = safe

def forward(
self,
logits: torch.tensor,
labels: torch.tensor,
labels: torch.tensor
):
"""
Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
Expand All @@ -98,9 +101,15 @@ def forward(
if self.class_balanced:
effective_num = 1.0 - np.power(self.beta, self.samples_per_class)
# Avoid division by 0 error for test cases without all labels present.
effective_num[effective_num == 0] = 1
if self.safe:
effective_num_classes = np.sum(effective_num != 0)
effective_num[effective_num == 0] = np.inf

else:
effective_num_classes = num_classes

weights = (1.0 - self.beta) / np.array(effective_num)
weights = weights / np.sum(weights) * num_classes
weights = weights / np.sum(weights) * effective_num_classes
weights = torch.tensor(weights, device=logits.device).float()

if self.loss_type != "cross_entropy":
Expand Down

0 comments on commit f137f8a

Please sign in to comment.