1
1
import torch
2
2
from torch import nn
3
3
from torch .nn .modules .loss import _Loss
4
+ from super_gradients .training .exceptions .loss_exceptions import IllegalRangeForLossAttributeException , RequiredLossComponentReductionException
4
5
5
6
6
- class OhemCELoss (_Loss ):
7
+ class OhemLoss (_Loss ):
7
8
"""
8
- OhemCELoss - Online Hard Example Mining Cross Entropy Loss
9
+ OhemLoss - Online Hard Example Mining Cross Entropy Loss
9
10
"""
11
+
10
12
def __init__ (self ,
11
13
threshold : float ,
12
14
mining_percent : float = 0.1 ,
13
15
ignore_lb : int = - 100 ,
14
- num_pixels_exclude_ignored : bool = True ):
16
+ num_pixels_exclude_ignored : bool = True ,
17
+ criteria : _Loss = None ):
15
18
"""
16
19
:param threshold: Sample below probability threshold, is considered hard.
17
20
:param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the
18
21
samples.
22
+ :param ignore_lb: label index to be ignored in loss calculation.
23
+ :param criteria: loss to mine the examples from.
24
+
19
25
i.e for num_pixels=100, ignore_pixels=30, mining_percent=0.1:
20
26
num_pixels_exclude_ignored=False => num_mining = 100 * 0.1 = 10
21
27
num_pixels_exclude_ignored=True => num_mining = (100 - 30) * 0.1 = 7
22
28
"""
23
29
super ().__init__ ()
24
- assert 0 <= mining_percent <= 1 , "mining percent should be a value from 0 to 1"
30
+
31
+ if mining_percent < 0 or mining_percent > 1 :
32
+ raise IllegalRangeForLossAttributeException ((0 , 1 ), "mining percent" )
33
+
25
34
self .thresh = - torch .log (torch .tensor (threshold , dtype = torch .float ))
26
35
self .mining_percent = mining_percent
27
- self .ignore_lb = - 100 if ignore_lb is None or ignore_lb < 0 else ignore_lb
36
+ self .ignore_lb = ignore_lb
28
37
self .num_pixels_exclude_ignored = num_pixels_exclude_ignored
29
- self .criteria = nn .CrossEntropyLoss (ignore_index = self .ignore_lb , reduction = 'none' )
38
+
39
+ if criteria .reduction != 'none' :
40
+ raise RequiredLossComponentReductionException ("criteria" , criteria .reduction , 'none' )
41
+ self .criteria = criteria
30
42
31
43
def forward (self , logits , labels ):
32
44
loss = self .criteria (logits , labels ).view (- 1 )
@@ -52,3 +64,46 @@ def forward(self, logits, labels):
52
64
else :
53
65
loss = loss [:num_mining ]
54
66
return torch .mean (loss )
67
+
68
+
69
+ class OhemCELoss (OhemLoss ):
70
+ """
71
+ OhemLoss - Online Hard Example Mining Cross Entropy Loss
72
+ """
73
+
74
+ def __init__ (self ,
75
+ threshold : float ,
76
+ mining_percent : float = 0.1 ,
77
+ ignore_lb : int = - 100 ,
78
+ num_pixels_exclude_ignored : bool = True ):
79
+ ignore_lb = - 100 if ignore_lb is None or ignore_lb < 0 else ignore_lb
80
+ criteria = nn .CrossEntropyLoss (ignore_index = ignore_lb , reduction = 'none' )
81
+ super (OhemCELoss , self ).__init__ (threshold = threshold ,
82
+ mining_percent = mining_percent ,
83
+ ignore_lb = ignore_lb ,
84
+ num_pixels_exclude_ignored = num_pixels_exclude_ignored ,
85
+ criteria = criteria )
86
+
87
+
88
+ class OhemBCELoss (OhemLoss ):
89
+ """
90
+ OhemBCELoss - Online Hard Example Mining Binary Cross Entropy Loss
91
+ """
92
+
93
+ def __init__ (self ,
94
+ threshold : float ,
95
+ mining_percent : float = 0.1 ,
96
+ ignore_lb : int = - 100 ,
97
+ num_pixels_exclude_ignored : bool = True , ):
98
+ super (OhemBCELoss , self ).__init__ (threshold = threshold ,
99
+ mining_percent = mining_percent ,
100
+ ignore_lb = ignore_lb ,
101
+ num_pixels_exclude_ignored = num_pixels_exclude_ignored ,
102
+ criteria = nn .BCEWithLogitsLoss (reduction = 'none' ))
103
+
104
+ def forward (self , logits , labels ):
105
+
106
+ # REMOVE SINGLE CLASS CHANNEL WHEN DEALING WITH BINARY DATA
107
+ if logits .shape [1 ] == 1 :
108
+ logits = logits .squeeze (1 )
109
+ return super (OhemBCELoss , self ).forward (logits , labels .float ())
0 commit comments