Skip to content

Commit 6acb183

Browse files
committed
stdc bugfix for 1.7.5
1 parent ece85e9 commit 6acb183

File tree

5 files changed

+106
-12
lines changed

5 files changed

+106
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
class IllegalRangeForLossAttributeException(Exception):
2+
"""
3+
Exception raised illegal value (i.e not in range) for _Loss attribute.
4+
5+
Attributes:
6+
message -- explanation of the error
7+
"""
8+
9+
def __init__(self, range_vals: tuple, attr_name: str):
10+
self.message = attr_name + " must be in range " + str(range_vals)
11+
super().__init__(self.message)
12+
13+
14+
class RequiredLossComponentReductionException(Exception):
15+
"""
16+
Exception raised illegal reduction for _Loss component.
17+
18+
Attributes:
19+
message -- explanation of the error
20+
"""
21+
22+
def __init__(self, component_name: str, reduction: str, required_reduction: str):
23+
self.message = component_name + ".reduction must be " + required_reduction + ", got" + reduction
24+
super().__init__(self.message)

src/super_gradients/training/losses/ohem_ce_loss.py

+61-6
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,44 @@
11
import torch
22
from torch import nn
33
from torch.nn.modules.loss import _Loss
4+
from super_gradients.training.exceptions.loss_exceptions import IllegalRangeForLossAttributeException, RequiredLossComponentReductionException
45

56

6-
class OhemCELoss(_Loss):
7+
class OhemLoss(_Loss):
78
"""
8-
OhemCELoss - Online Hard Example Mining Cross Entropy Loss
9+
OhemLoss - Online Hard Example Mining Cross Entropy Loss
910
"""
11+
1012
def __init__(self,
1113
threshold: float,
1214
mining_percent: float = 0.1,
1315
ignore_lb: int = -100,
14-
num_pixels_exclude_ignored: bool = True):
16+
num_pixels_exclude_ignored: bool = True,
17+
criteria: _Loss = None):
1518
"""
1619
:param threshold: Sample below probability threshold, is considered hard.
1720
:param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the
1821
samples.
22+
:param ignore_lb: label index to be ignored in loss calculation.
23+
:param criteria: loss to mine the examples from.
24+
1925
i.e for num_pixels=100, ignore_pixels=30, mining_percent=0.1:
2026
num_pixels_exclude_ignored=False => num_mining = 100 * 0.1 = 10
2127
num_pixels_exclude_ignored=True => num_mining = (100 - 30) * 0.1 = 7
2228
"""
2329
super().__init__()
24-
assert 0 <= mining_percent <= 1, "mining percent should be a value from 0 to 1"
30+
31+
if mining_percent < 0 or mining_percent > 1:
32+
raise IllegalRangeForLossAttributeException((0, 1), "mining percent")
33+
2534
self.thresh = -torch.log(torch.tensor(threshold, dtype=torch.float))
2635
self.mining_percent = mining_percent
27-
self.ignore_lb = -100 if ignore_lb is None or ignore_lb < 0 else ignore_lb
36+
self.ignore_lb = ignore_lb
2837
self.num_pixels_exclude_ignored = num_pixels_exclude_ignored
29-
self.criteria = nn.CrossEntropyLoss(ignore_index=self.ignore_lb, reduction='none')
38+
39+
if criteria.reduction != 'none':
40+
raise RequiredLossComponentReductionException("criteria", criteria.reduction, 'none')
41+
self.criteria = criteria
3042

3143
def forward(self, logits, labels):
3244
loss = self.criteria(logits, labels).view(-1)
@@ -52,3 +64,46 @@ def forward(self, logits, labels):
5264
else:
5365
loss = loss[:num_mining]
5466
return torch.mean(loss)
67+
68+
69+
class OhemCELoss(OhemLoss):
70+
"""
71+
OhemLoss - Online Hard Example Mining Cross Entropy Loss
72+
"""
73+
74+
def __init__(self,
75+
threshold: float,
76+
mining_percent: float = 0.1,
77+
ignore_lb: int = -100,
78+
num_pixels_exclude_ignored: bool = True):
79+
ignore_lb = -100 if ignore_lb is None or ignore_lb < 0 else ignore_lb
80+
criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
81+
super(OhemCELoss, self).__init__(threshold=threshold,
82+
mining_percent=mining_percent,
83+
ignore_lb=ignore_lb,
84+
num_pixels_exclude_ignored=num_pixels_exclude_ignored,
85+
criteria=criteria)
86+
87+
88+
class OhemBCELoss(OhemLoss):
89+
"""
90+
OhemBCELoss - Online Hard Example Mining Binary Cross Entropy Loss
91+
"""
92+
93+
def __init__(self,
94+
threshold: float,
95+
mining_percent: float = 0.1,
96+
ignore_lb: int = -100,
97+
num_pixels_exclude_ignored: bool = True, ):
98+
super(OhemBCELoss, self).__init__(threshold=threshold,
99+
mining_percent=mining_percent,
100+
ignore_lb=ignore_lb,
101+
num_pixels_exclude_ignored=num_pixels_exclude_ignored,
102+
criteria=nn.BCEWithLogitsLoss(reduction='none'))
103+
104+
def forward(self, logits, labels):
105+
106+
# REMOVE SINGLE CLASS CHANNEL WHEN DEALING WITH BINARY DATA
107+
if logits.shape[1] == 1:
108+
logits = logits.squeeze(1)
109+
return super(OhemBCELoss, self).forward(logits, labels.float())

src/super_gradients/training/losses/stdc_loss.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn.functional as F
44
from super_gradients.training.utils.segmentation_utils import to_one_hot
55
from torch.nn.modules.loss import _Loss
6-
from super_gradients.training.losses.ohem_ce_loss import OhemCELoss
6+
from super_gradients.training.losses.ohem_ce_loss import OhemCELoss, OhemBCELoss, OhemLoss
77
from super_gradients.training.losses.dice_loss import BinaryDiceLoss
88
from typing import Union, Tuple
99

@@ -46,8 +46,10 @@ def __init__(self,
4646
def forward(self, gt_masks: torch.Tensor):
4747
if self.device is None:
4848
self._set_kernels_to_device(gt_masks.device)
49-
50-
one_hot = to_one_hot(gt_masks, self.num_classes, self.ignore_label).float()
49+
if self.num_classes > 1:
50+
one_hot = to_one_hot(gt_masks, self.num_classes, self.ignore_label).float()
51+
else:
52+
one_hot = gt_masks.unsqueeze(1).float()
5153
# create binary detail maps using filters withs strides of 1, 2 and 4.
5254
boundary_targets = F.conv2d(one_hot, self.laplacian_kernel, stride=1, padding=1, groups=self.num_classes)
5355
boundary_targets_x2 = F.conv2d(one_hot, self.laplacian_kernel, stride=2, padding=1, groups=self.num_classes)
@@ -123,7 +125,8 @@ def __init__(self,
123125
mining_percent: float = 0.1,
124126
detail_threshold: float = 1.,
125127
learnable_fusing_kernel: bool = True,
126-
ignore_index: int = None):
128+
ignore_index: int = None,
129+
ohem_criteria: OhemLoss = None):
127130
"""
128131
:param threshold: Online hard-mining probability threshold.
129132
:param num_aux_heads: num of auxiliary heads.
@@ -133,6 +136,8 @@ def __init__(self,
133136
:param mining_percent: mining percentage.
134137
:param detail_threshold: detail threshold to create binary details features in DetailLoss.
135138
:param learnable_fusing_kernel: whether DetailAggregateModule params are learnable or not.
139+
:param ohem_criteria: OhemLoss criterion component of STDC. When none is given, it will be derrived according
140+
to num_classes (i.e OhemCELoss if num_classes > 1 and OhemBCELoss otherwise).
136141
"""
137142
super().__init__()
138143

@@ -151,7 +156,11 @@ def __init__(self,
151156
learnable_fusing_kernel=learnable_fusing_kernel)
152157
self.detail_loss = DetailLoss(weights=detail_weights)
153158

154-
self.ce_ohem = OhemCELoss(threshold=threshold, mining_percent=mining_percent, ignore_lb=ignore_index)
159+
if ohem_criteria is None:
160+
ohem_criteria = OhemCELoss(threshold=threshold, mining_percent=mining_percent, ignore_lb=ignore_index) if num_classes > 1 else OhemBCELoss(threshold=threshold, mining_percent=mining_percent)
161+
162+
self.ce_ohem = ohem_criteria
163+
self.num_classes = num_classes
155164

156165
def forward(self, preds: Tuple[torch.Tensor], target: torch.Tensor):
157166
"""

src/super_gradients/training/metrics/segmentation_metrics.py

+3
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def __init__(self, dist_sync_on_step=True, ignore_index=None):
119119
self.component_names = ["target_IOU", "background_IOU", "mean_IOU"]
120120

121121
def update(self, preds, target: torch.Tensor):
122+
# WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
123+
if isinstance(preds, tuple):
124+
preds = preds[0]
122125
super().update(preds=torch.sigmoid(preds), target=target.long())
123126

124127
def compute(self):

src/super_gradients/training/utils/callbacks.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,10 @@ def __init__(self, phase: Phase, freq: int, batch_idx: int = 0, last_img_idx_in_
494494

495495
def __call__(self, context: PhaseContext):
496496
if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
497-
preds = context.preds.clone()
497+
if isinstance(context.preds, tuple):
498+
preds = context.preds[0].clone()
499+
else:
500+
preds = context.preds.clone()
498501
batch_imgs = BinarySegmentationVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx)
499502
batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
500503
batch_imgs = np.stack(batch_imgs)

0 commit comments

Comments
 (0)