From f137f8a89106829f2a509de5ac9df5f713ea72ed Mon Sep 17 00:00:00 2001
From: Christopher Beltz
Date: Fri, 22 Nov 2024 18:23:46 +0100
Subject: [PATCH] safe switch for equivalent weight calculation in the
 presence of 0-sample labels

---
 balanced_loss/losses.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/balanced_loss/losses.py b/balanced_loss/losses.py
index 3822c59..5d7bc83 100644
--- a/balanced_loss/losses.py
+++ b/balanced_loss/losses.py
@@ -44,6 +44,7 @@ def __init__(
         fl_gamma=2,
         samples_per_class=None,
         class_balanced=False,
+        safe: bool = False
     ):
         """
         Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
@@ -60,6 +61,7 @@
             samples_per_class: A python list of size [num_classes].
                 Required if class_balance is True.
             class_balanced: bool. Whether to use class balanced loss.
+            safe: bool. Whether to allow labels with no samples.
         Returns:
             Loss instance
         """
@@ -73,11 +75,12 @@
         self.fl_gamma = fl_gamma
         self.samples_per_class = samples_per_class
         self.class_balanced = class_balanced
+        self.safe = safe
 
     def forward(
         self,
         logits: torch.tensor,
-        labels: torch.tensor,
+        labels: torch.tensor
     ):
         """
         Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
@@ -98,9 +101,15 @@
         if self.class_balanced:
             effective_num = 1.0 - np.power(self.beta, self.samples_per_class)
             # Avoid division by 0 error for test cases without all labels present.
-            effective_num[effective_num == 0] = 1
+            if self.safe:
+                effective_num_classes = np.sum(effective_num != 0)
+                effective_num[effective_num == 0] = np.inf
+
+            else:
+                effective_num_classes = num_classes
+
             weights = (1.0 - self.beta) / np.array(effective_num)
-            weights = weights / np.sum(weights) * num_classes
+            weights = weights / np.sum(weights) * effective_num_classes
             weights = torch.tensor(weights, device=logits.device).float()
 
             if self.loss_type != "cross_entropy":